package com.linkedin.feathr.common.tensor;

import com.linkedin.feathr.common.tensor.scalar.ScalarTensor;
import com.linkedin.feathr.common.tensorbuilder.BulkTensorBuilder;
import com.linkedin.feathr.common.tensorbuilder.DenseTensorBuilderFactory;
import com.linkedin.feathr.common.tensorbuilder.TensorBuilder;
import com.linkedin.feathr.common.tensorbuilder.UniversalTensorBuilderFactory;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/linkedin/feathr/common/tensor/Tensors.class */
public final class Tensors {
    /** Static utility holder; not instantiable. */
    private Tensors() {
    }

    /**
     * Obtains a dense bulk builder for {@code tensorType} and checks that {@code length} values are compatible
     * with it.
     *
     * <p>If the builder has a fixed cardinality, {@code length} must equal it exactly. In every case,
     * {@code length} must be a multiple of the builder's static cardinality.
     *
     * @param tensorType the type of the tensor to build
     * @param length the number of values the caller intends to pass to the builder
     * @return the bulk builder returned by {@link DenseTensorBuilderFactory#INSTANCE}
     * @throws IllegalArgumentException if {@code length} is incompatible with the builder's cardinality
     */
    private static BulkTensorBuilder getBulkBuilder(TensorType tensorType, int length) {
        BulkTensorBuilder bulkTensorBuilder = DenseTensorBuilderFactory.INSTANCE.getBulkTensorBuilder(tensorType);
        int staticCardinality = bulkTensorBuilder.getStaticCardinality();
        if (!bulkTensorBuilder.hasVariableCardinality() && staticCardinality != length) {
            throw new IllegalArgumentException("The number of values " + length + " is not equal to the size of the type " + staticCardinality + ".");
        }
        // Guard the modulus: if the builder ever reports a static cardinality of zero, `length % 0` would throw
        // ArithmeticException. The only length that is a "multiple" of zero is zero itself.
        boolean isMultiple = staticCardinality == 0 ? length == 0 : length % staticCardinality == 0;
        if (!isMultiple) {
            throw new IllegalArgumentException("The number of values " + length + " is not a multiple of the static size of the type " + staticCardinality + ".");
        }
        return bulkTensorBuilder;
    }

    /**
     * Wraps a single value as a scalar (zero-dimensional) tensor.
     *
     * @param tensorType the tensor type; must declare no dimensions
     * @param value the scalar value, interpreted using the representation of the type's value type
     * @return the scalar tensor produced by {@link ScalarTensor#wrap}
     * @throws IllegalArgumentException if {@code tensorType} has one or more dimensions
     */
    public static TensorData asScalarTensor(TensorType tensorType, Object value) {
        if (!tensorType.getDimensionTypes().isEmpty()) {
            throw new IllegalArgumentException("Scalar tensors cannot have dimensions.");
        }
        return ScalarTensor.wrap(value, tensorType.getValueType().getRepresentation());
    }

    /**
     * Builds a dense tensor of {@code tensorType} from a {@code float} array.
     *
     * <p>Whether the array is copied or referenced is up to the underlying {@link BulkTensorBuilder}.
     *
     * @param tensorType the tensor type
     * @param values the tensor values
     * @return the built tensor
     * @throws IllegalArgumentException if the array length is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType tensorType, float[] values) {
        return getBulkBuilder(tensorType, values.length).build(values);
    }

    /**
     * Builds a dense tensor of {@code tensorType} from an {@code int} array.
     *
     * @param tensorType the tensor type
     * @param values the tensor values
     * @return the built tensor
     * @throws IllegalArgumentException if the array length is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType tensorType, int[] values) {
        return getBulkBuilder(tensorType, values.length).build(values);
    }

    /**
     * Builds a dense tensor of {@code tensorType} from a {@code long} array.
     *
     * @param tensorType the tensor type
     * @param values the tensor values
     * @return the built tensor
     * @throws IllegalArgumentException if the array length is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType tensorType, long[] values) {
        return getBulkBuilder(tensorType, values.length).build(values);
    }

    /**
     * Builds a dense tensor of {@code tensorType} from a {@code double} array.
     *
     * @param tensorType the tensor type
     * @param values the tensor values
     * @return the built tensor
     * @throws IllegalArgumentException if the array length is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType tensorType, double[] values) {
        return getBulkBuilder(tensorType, values.length).build(values);
    }

    /**
     * Builds a dense tensor of {@code tensorType} from a list of values.
     *
     * @param tensorType the tensor type
     * @param values the tensor values; which element types are accepted is decided by the bulk builder
     * @return the built tensor
     * @throws IllegalArgumentException if the list size is incompatible with the type's cardinality
     */
    public static TensorData asDenseTensor(TensorType tensorType, List<?> values) {
        return getBulkBuilder(tensorType, values.size()).build(values);
    }

    /**
     * Builds a one-dimensional sparse tensor from a set.
     *
     * <p>Each element is written to column 0 and the constant {@code 1} (an {@code Integer}) to column 1. This
     * presumably means dimension key and value respectively. Conversion of {@code 1} to the type's value type is
     * left to the {@link TensorBuilder}; confirm it matches the declared value type.
     *
     * @param tensorType the tensor type; must have exactly one dimension
     * @param set the elements to include
     * @return the built tensor
     * @throws IllegalArgumentException if {@code tensorType} does not have exactly one dimension
     */
    public static TensorData asSparseTensor(TensorType tensorType, Set<?> set) {
        if (tensorType.getDimensionTypes().size() != 1) {
            throw new IllegalArgumentException("Only one-dimensional tensors can represent sets.");
        }
        TensorBuilder<?> tensorBuilder = UniversalTensorBuilderFactory.INSTANCE.getTensorBuilder(tensorType);
        tensorBuilder.start(set.size());
        for (Object element : set) {
            tensorBuilder.setValue(0, element);
            tensorBuilder.setValue(1, 1);
            tensorBuilder.append();
        }
        return tensorBuilder.build();
    }

    /**
     * Builds a one-dimensional sparse tensor from a map.
     *
     * <p>Each entry's key is written to column 0 and its value to column 1. This presumably means dimension key
     * and tensor value respectively.
     *
     * @param tensorType the tensor type; must have exactly one dimension
     * @param map the entries to include
     * @return the built tensor
     * @throws IllegalArgumentException if {@code tensorType} does not have exactly one dimension
     */
    public static TensorData asSparseTensor(TensorType tensorType, Map<?, ?> map) {
        if (tensorType.getDimensionTypes().size() != 1) {
            throw new IllegalArgumentException("Only one-dimensional tensors can represent maps.");
        }
        TensorBuilder<?> tensorBuilder = UniversalTensorBuilderFactory.INSTANCE.getTensorBuilder(tensorType);
        tensorBuilder.start(map.size());
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            tensorBuilder.setValue(0, entry.getKey());
            tensorBuilder.setValue(1, entry.getValue());
            tensorBuilder.append();
        }
        return tensorBuilder.build();
    }
}
