package com.linkedin.feathr.common;

import com.google.common.collect.Maps;
import com.linkedin.feathr.common.tensor.DimensionType;
import com.linkedin.feathr.common.tensor.LOLTensorData;
import com.linkedin.feathr.common.tensor.Primitive;
import com.linkedin.feathr.common.tensor.ReadableTuple;
import com.linkedin.feathr.common.tensor.Representable;
import com.linkedin.feathr.common.tensor.SimpleWriteableTuple;
import com.linkedin.feathr.common.tensor.StandaloneReadableTuple;
import com.linkedin.feathr.common.tensor.TensorData;
import com.linkedin.feathr.common.tensor.TensorIterator;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.common.tensor.TensorTypes;
import com.linkedin.feathr.common.tensorbuilder.TensorBuilder;
import com.linkedin.feathr.common.tensorbuilder.TensorBuilderFactory;
import com.linkedin.feathr.common.tensorbuilder.UniversalTensorBuilderFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
 * Static helpers for converting between {@link TensorData} and plain Java collections, populating
 * {@link TensorBuilder}s, and rendering tensors as debug strings.
 */
public final class TensorUtils {
    /** Default cap for debug-string length (not referenced in this class; presumably used with getDebugString). */
    public static final int DEFAULT_MAX_STRING_LEN = 10240;
    private static final String SEPARATOR = ",";
    private static final String NEXT_LINE = "\n";
    private static final String VALUE_DIM_NAME = "Value";
    private static final String EXCEED_MAX_LIMIT = "...";

    private TensorUtils() {
    }

    /**
     * Renders a tensor as CSV-like text: a header line of dimension names followed by "Value", then one
     * line per row with the dimension values and the value.
     *
     * <p>The length cap is checked before each row is appended, so the result may exceed
     * {@code maxStringLength} by up to one row; when the cap is hit, "..." is appended and rendering stops.
     *
     * @param tensorType type describing the dimensions and value of {@code tensorData}
     * @param tensorData the tensor to render
     * @param maxStringLength soft cap on the length of the returned string
     * @return the rendered text
     */
    public static String getDebugString(TensorType tensorType, TensorData tensorData, int maxStringLength) {
        StringBuilder sb = new StringBuilder();
        Iterator<String> names = tensorType.getDimensionNames().iterator();
        while (names.hasNext()) {
            sb.append(names.next()).append(SEPARATOR);
        }
        sb.append(VALUE_DIM_NAME);

        // One column per dimension plus the value column.
        int numColumns = tensorType.getDimensionTypes().size() + 1;
        TensorIterator it = tensorData.iterator();
        it.start();
        while (it.isValid()) {
            sb.append(NEXT_LINE);
            if (sb.length() >= maxStringLength) {
                sb.append(EXCEED_MAX_LIMIT);
                break;
            }
            sb.append(String.join(SEPARATOR, convertToStrings(tensorType, it, numColumns)));
            it.next();
        }
        return sb.toString();
    }

    /**
     * Builds a tensor from a nested map using {@link UniversalTensorBuilderFactory#INSTANCE}.
     *
     * @see #convertNestedMapToTensor(Map, TensorType, TensorBuilderFactory)
     */
    public static TensorData convertNestedMapToTensor(Map<String, Object> map, TensorType tensorType) {
        return convertNestedMapToTensor(map, tensorType, UniversalTensorBuilderFactory.INSTANCE);
    }

    /**
     * Builds a tensor from a nested map in which each nesting level corresponds to one dimension.
     *
     * <p>Keys at depth {@code d} are converted with the {@code d}-th dimension type. Leaf values
     * (non-map entries) must sit at the last dimension depth and are converted with the tensor's value type.
     *
     * @param map nested map of dimension keys to sub-maps or leaf values
     * @param tensorType the type of the tensor to build
     * @param tensorBuilderFactory factory supplying the builder for {@code tensorType}
     * @return the built tensor
     * @throws IllegalArgumentException if the nesting depth does not match the number of dimensions
     */
    public static TensorData convertNestedMapToTensor(Map<String, Object> map, TensorType tensorType, TensorBuilderFactory tensorBuilderFactory) {
        TensorBuilder<?> tensorBuilder = tensorBuilderFactory.getTensorBuilder(tensorType);
        // Only the top-level entry count is passed here; for nested maps the actual row count may be larger.
        tensorBuilder.start(map.size());
        populateTensorBuilder(map, tensorType, tensorBuilder, 0, new SimpleWriteableTuple(tensorType.getColumnTypes()));
        return tensorBuilder.build();
    }

    /**
     * Converts a tensor to a map from row key to float value.
     *
     * <p>Number values are narrowed with {@link Number#floatValue()}, booleans map to 1.0/0.0, and strings
     * are parsed with {@link Float#parseFloat(String)}.
     *
     * @param tensorData the tensor to convert; may be {@code null}
     * @return the converted map, or {@code null} if {@code tensorData} is {@code null}
     * @throws IllegalArgumentException if a value is of an unsupported type, is null, or is an unparseable string
     */
    public static Map<ReadableTuple, Float> convertTensorToMap(TensorData tensorData) {
        if (tensorData == null) {
            return null;
        }
        Map<ReadableTuple, Object> genericValues = convertTensorToMapWithGenericValues(tensorData);
        Representable[] types = tensorData.getTypes();
        if (types[types.length - 1] == Primitive.FLOAT) {
            // FLOAT-typed values presumably already come back as Float, so the map is reused without copying.
            @SuppressWarnings("unchecked")
            Map<ReadableTuple, Float> floatValues = (Map<ReadableTuple, Float>) (Map<?, ?>) genericValues;
            return floatValues;
        }
        Map<ReadableTuple, Float> result = Maps.newHashMapWithExpectedSize(genericValues.size());
        for (Map.Entry<ReadableTuple, Object> entry : genericValues.entrySet()) {
            result.put(entry.getKey(), toFloat(entry.getValue()));
        }
        return result;
    }

    /** Coerces a primitive-like value (Number, Boolean, or numeric String) to a Float. */
    private static Float toFloat(Object value) {
        if (value instanceof Number) {
            return ((Number) value).floatValue();
        }
        if (value instanceof Boolean) {
            return ((Boolean) value) ? 1.0f : 0.0f;
        }
        if (value instanceof String) {
            try {
                return Float.parseFloat((String) value);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException(String.format("String value %s can not be formatted to a float", value), e);
            }
        }
        // A null value reports "null" instead of failing with an NPE while building the message.
        throw new IllegalArgumentException("Expecting Primitive value but received " + (value == null ? null : value.getClass()));
    }

    /**
     * Converts a tensor to a map from row key (the dimension columns) to the value object produced by the
     * value column's representation.
     *
     * @param tensorData the tensor to convert; may be {@code null}
     * @return the converted map, or {@code null} if {@code tensorData} is {@code null}
     */
    public static Map<ReadableTuple, Object> convertTensorToMapWithGenericValues(TensorData tensorData) {
        if (tensorData == null) {
            return null;
        }
        Map<ReadableTuple, Object> result = Maps.newHashMapWithExpectedSize(tensorData.estimatedCardinality());
        TensorIterator it = tensorData.iterator();
        Representable[] types = it.getTypes();
        int valueColumn = types.length - 1;
        Primitive valueRepresentation = types[valueColumn].getRepresentation();
        it.start();
        while (it.isValid()) {
            // NOTE(review): the iterator is reused across rows; the `true` flag presumably makes
            // StandaloneReadableTuple snapshot the current row. Confirm against its implementation.
            result.put(new StandaloneReadableTuple(it, true), valueRepresentation.toObject(it, valueColumn));
            it.next();
        }
        return result;
    }

    /**
     * Recursively walks {@code map}, writing the key at nesting depth {@code i} into column {@code i} of
     * {@code simpleWriteableTuple}. It appends a row to {@code tensorBuilder} for every leaf value.
     *
     * @throws IllegalArgumentException if a map appears at the last dimension or a leaf appears too early
     */
    private static void populateTensorBuilder(Map<String, Object> map, TensorType tensorType, TensorBuilder tensorBuilder, int i, SimpleWriteableTuple simpleWriteableTuple) {
        // Columns = dimensions + 1 value column, so leaves are expected at depth length - 2.
        int length = tensorType.getColumnTypes().length;
        map.forEach((key, value) -> {
            tensorType.getDimensionTypes().get(i).setDimensionValue(simpleWriteableTuple, i, key);
            if (value instanceof Map) {
                if (i + 2 >= length) {
                    throw new IllegalArgumentException(String.format("Expected only %d columns, but found more", length));
                }
                @SuppressWarnings("unchecked")
                Map<String, Object> child = (Map<String, Object>) value;
                populateTensorBuilder(child, tensorType, tensorBuilder, i + 1, simpleWriteableTuple);
            } else {
                if (i + 2 != length) {
                    throw new IllegalArgumentException(String.format("Value %s is at depth %d but tensorType suggests it should be at %d", value, i, length - 2));
                }
                tensorType.getValueType().getRepresentation().from(value, simpleWriteableTuple, i + 1);
                setRow(tensorBuilder, simpleWriteableTuple);
            }
        });
    }

    /** Copies every column of {@code simpleWriteableTuple} into {@code tensorBuilder} and appends the row. */
    private static void setRow(TensorBuilder tensorBuilder, SimpleWriteableTuple simpleWriteableTuple) {
        Representable[] types = tensorBuilder.getTypes();
        for (int column = 0; column < simpleWriteableTuple.getTypes().length; column++) {
            types[column].getRepresentation().copy(simpleWriteableTuple, column, tensorBuilder, column);
        }
        tensorBuilder.append();
    }

    /**
     * Converts a tensor to list-of-lists form: one list per dimension column plus one list of values.
     *
     * @param tensorData the tensor to convert
     * @return the equivalent {@link LOLTensorData}
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static LOLTensorData convertToLOLTensor(TensorData tensorData) {
        Representable[] types = tensorData.getTypes();
        int valueColumn = types.length - 1;
        int cardinality = tensorData.estimatedCardinality();
        Primitive valueRepresentation = types[valueColumn].getRepresentation();

        List<List<Object>> dimensionColumns = new ArrayList<>(valueColumn);
        for (int d = 0; d < valueColumn; d++) {
            dimensionColumns.add(new ArrayList<>(cardinality));
        }
        List<Object> values = new ArrayList<>(cardinality);

        TensorIterator it = tensorData.iterator();
        // Reset before iterating, as every other iterator loop in this class does.
        it.start();
        while (it.isValid()) {
            for (int d = 0; d < valueColumn; d++) {
                dimensionColumns.get(d).add(it.getValue(d));
            }
            values.add(valueRepresentation.toObject(it, valueColumn));
            it.next();
        }
        // Raw cast keeps this compatible with whatever List parameterization LOLTensorData's constructor declares.
        return new LOLTensorData(types, (List) dimensionColumns, values);
    }

    /**
     * Renders the first {@code i} columns of a row as strings. Columns 0..dims-1 are dimensions; column
     * {@code dims} (only when {@code i == dims + 1}) is the value.
     *
     * @param tensorType type describing the row's columns
     * @param readableTuple the row to render
     * @param i number of columns to render, at most dims + 1
     * @return an array of length {@code i}
     * @throws IllegalArgumentException if {@code i} exceeds the number of dimensions plus one
     */
    public static String[] convertToStrings(TensorType tensorType, ReadableTuple readableTuple, int i) {
        List<DimensionType> dimensionTypes = tensorType.getDimensionTypes();
        int numDimensions = dimensionTypes.size();
        if (i > numDimensions + 1) {
            throw new IllegalArgumentException("Number of columns in the output is greater than number of dims & value");
        }
        boolean includeValue = i > numDimensions;
        int dimensionsToRender = includeValue ? numDimensions : i;
        String[] result = new String[i];
        for (int column = 0; column < dimensionsToRender; column++) {
            result[column] = dimensionTypes.get(column).getDimensionValue(readableTuple, column).toString();
        }
        if (includeValue) {
            result[numDimensions] = tensorType.getValueType().getRepresentation().toString(readableTuple, numDimensions);
        }
        return result;
    }

    /**
     * Adapts a key generator over stringified dimension values into one over {@link ReadableTuple}s.
     * Only the dimension columns, not the value, are passed to {@code function}.
     */
    public static <K> Function<ReadableTuple, K> wrapKeyGen(TensorType tensorType, Function<String[], K> function) {
        int numDimensions = tensorType.getDimensionTypes().size();
        return tuple -> function.apply(convertToStrings(tensorType, tuple, numDimensions));
    }

    /**
     * Returns the shape of each dimension of {@code tensorType}, in dimension order.
     *
     * @param tensorType the tensor type
     * @return one entry per dimension, as reported by {@link DimensionType#getShape()}
     */
    public static long[] getShape(TensorType tensorType) {
        List<DimensionType> dimensionTypes = tensorType.getDimensionTypes();
        long[] shape = new long[dimensionTypes.size()];
        for (int d = 0; d < shape.length; d++) {
            shape[d] = dimensionTypes.get(d).getShape();
        }
        return shape;
    }

    /**
     * Appends each row of {@code objArr} to {@code tensorBuilder}, converting column {@code c} with the
     * representation of {@code representableArr[c]}, and then builds the tensor.
     *
     * @throws IllegalArgumentException if any row's length differs from {@code representableArr.length}
     */
    public static TensorData populateTensor(Representable[] representableArr, Object[][] objArr, TensorBuilder tensorBuilder) {
        for (Object[] row : objArr) {
            // Validate once per row, before writing, so empty or short rows are rejected too.
            if (row.length != representableArr.length) {
                throw new IllegalArgumentException(String.format("data[i] length should be equal to columnType lengthFound data[i].length = %s and columnType.length = %s", row.length, representableArr.length));
            }
            for (int column = 0; column < row.length; column++) {
                representableArr[column].getRepresentation().from(row[column], tensorBuilder, column);
            }
            tensorBuilder.append();
        }
        return tensorBuilder.build();
    }

    /**
     * Parses a tensor type from its string form.
     *
     * @deprecated use {@link TensorTypes#parseTensorType(String)} directly.
     */
    @Deprecated
    public static TensorType parseTensorType(String str) {
        return TensorTypes.parseTensorType(str);
    }

    /**
     * Divides {@code i} by {@code i2}, requiring exact division. 0/0 is defined as 0.
     *
     * @throws IllegalArgumentException if {@code i2} is zero and {@code i} is not, or if the division has a remainder
     */
    public static int safeRatio(int i, int i2) {
        if (i2 == 0) {
            if (i == 0) {
                return 0;
            }
            throw new IllegalArgumentException("Dividing a non-zero " + i + " by zero.");
        }
        if (i % i2 != 0) {
            throw new IllegalArgumentException("Integer division has a non-zero remainder " + i + "/" + i2 + ".");
        }
        return i / i2;
    }
}
