package com.linkedin.feathr.common.tensor;

import com.linkedin.feathr.common.types.PrimitiveType;
import com.linkedin.feathr.common.types.ValueType;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/* loaded from: input_file:com/linkedin/feathr/common/tensor/TensorType.class */
/**
 * Describes the type of a tensor: its {@link TensorCategory}, the {@link ValueType} of its values,
 * and the type and name of each of its dimensions.
 *
 * <p>The dimension-type and dimension-name lists are defensively copied at construction and
 * exposed as unmodifiable lists, so later changes to the lists passed in by a caller cannot alter
 * this type's equality, hash code or cached column types.
 */
public final class TensorType implements Serializable {
    /** A sparse tensor type with {@link PrimitiveType#FLOAT} values and no dimensions. */
    public static final TensorType EMPTY = new TensorType(PrimitiveType.FLOAT, (List<DimensionType>) Collections.emptyList(), (List<String>) Collections.emptyList());

    private final TensorCategory _tensorCategory;
    private final ValueType _valueType;
    private final List<DimensionType> _dimensionTypes;
    private final List<String> _dimensionNames;

    // Lazily built cache for getColumnTypes(). It is volatile so a fully populated array is
    // safely published. Racing threads may each build it; the contents are derived only from the
    // final fields above, so the last writer wins harmlessly.
    // NOTE(review): this field is not transient, so it is serialized with the instance once
    // populated; confirm Representable implements Serializable. (Modifiers are left unchanged on
    // purpose: adding `transient` would change the default serialVersionUID.)
    private volatile Representable[] _columnTypes;

    /**
     * Creates a {@link TensorCategory#SPARSE} tensor type whose dimension names are taken from
     * {@link DimensionType#getName()} of each dimension type.
     *
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the type of each dimension, in order; must not be {@code null}
     */
    public TensorType(ValueType valueType, List<DimensionType> dimensionTypes) {
        this(valueType, dimensionTypes, (List<String>) null);
    }

    /**
     * Creates a {@link TensorCategory#SPARSE} tensor type with explicit dimension names.
     *
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the type of each dimension, in order; must not be {@code null}
     * @param dimensionNames the name of each dimension, or {@code null} to use
     *     {@link DimensionType#getName()} of each dimension type
     * @throws IllegalArgumentException if both lists are given and their sizes differ
     */
    public TensorType(ValueType valueType, List<DimensionType> dimensionTypes, List<String> dimensionNames) {
        this(TensorCategory.SPARSE, valueType, dimensionTypes, dimensionNames);
    }

    /**
     * Creates a tensor type.
     *
     * @param tensorCategory the tensor's category
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the type of each dimension, in order; must not be {@code null}.
     *     The list is copied.
     * @param dimensionNames the name of each dimension, or {@code null} to use
     *     {@link DimensionType#getName()} of each dimension type. The list is copied.
     * @throws NullPointerException if {@code dimensionTypes} is {@code null}
     * @throws IllegalArgumentException if {@code dimensionNames} is given and its size differs
     *     from that of {@code dimensionTypes}
     */
    public TensorType(TensorCategory tensorCategory, ValueType valueType, List<DimensionType> dimensionTypes, List<String> dimensionNames) {
        Objects.requireNonNull(dimensionTypes, "dimensionTypes");
        List<String> names;
        if (dimensionNames == null) {
            names = new ArrayList<>(dimensionTypes.size());
            for (DimensionType dimensionType : dimensionTypes) {
                names.add(dimensionType.getName());
            }
        } else {
            if (dimensionTypes.size() != dimensionNames.size()) {
                throw new IllegalArgumentException("The numbers of dimension types " + dimensionTypes + " and names " + dimensionNames + " have to be equal.");
            }
            names = new ArrayList<>(dimensionNames);
        }
        this._tensorCategory = tensorCategory;
        this._valueType = valueType;
        // Copy so that later mutation of the caller's list cannot change this (hashed) value.
        this._dimensionTypes = Collections.unmodifiableList(new ArrayList<>(dimensionTypes));
        this._dimensionNames = Collections.unmodifiableList(names);
    }

    /**
     * Creates a tensor type whose dimension names are taken from {@link DimensionType#getName()}
     * of each dimension type.
     *
     * @param tensorCategory the tensor's category
     * @param valueType the type of the tensor's values
     * @param dimensionTypes the type of each dimension, in order; must not be {@code null}
     */
    public TensorType(TensorCategory tensorCategory, ValueType valueType, List<DimensionType> dimensionTypes) {
        this(tensorCategory, valueType, dimensionTypes, null);
    }

    /** Returns the tensor's category. */
    public TensorCategory getTensorCategory() {
        return this._tensorCategory;
    }

    /** Returns the type of the tensor's values. */
    public ValueType getValueType() {
        return this._valueType;
    }

    /** Returns the unmodifiable, ordered list of dimension types. */
    public List<DimensionType> getDimensionTypes() {
        return this._dimensionTypes;
    }

    /** Returns the unmodifiable, ordered list of dimension names (same size as the dimension types). */
    public List<String> getDimensionNames() {
        return this._dimensionNames;
    }

    /**
     * Returns the representation of every column of the tensor: one entry per dimension (in
     * dimension order), followed by the value's representation as the last entry.
     *
     * <p>The array is computed once and cached; the same instance is returned to every caller, so
     * callers must not modify it.
     */
    public Representable[] getColumnTypes() {
        Representable[] columnTypes = this._columnTypes; // single volatile read
        if (columnTypes == null) {
            int dimensionCount = this._dimensionTypes.size();
            columnTypes = new Representable[dimensionCount + 1];
            for (int i = 0; i < dimensionCount; i++) {
                columnTypes[i] = this._dimensionTypes.get(i).getRepresentation();
            }
            columnTypes[dimensionCount] = this._valueType.getRepresentation();
            this._columnTypes = columnTypes;
        }
        return columnTypes;
    }

    /**
     * Writes the given dimension values into {@code writeableTuple}, delegating each value to the
     * matching {@link DimensionType#setDimensionValue} at the same column index.
     *
     * @param writeableTuple the tuple to write into; must not be {@code null}
     * @param objArr one value per dimension, in dimension order; must not be {@code null}
     * @throws IllegalArgumentException if {@code objArr.length} differs from the number of dimensions
     */
    public void setDimensions(WriteableTuple writeableTuple, Object[] objArr) {
        Objects.requireNonNull(writeableTuple);
        Objects.requireNonNull(objArr);
        int dimensionCount = this._dimensionTypes.size();
        if (objArr.length != dimensionCount) {
            throw new IllegalArgumentException("Wrong number of dimensions. Got " + objArr.length + ", expected " + dimensionCount);
        }
        for (int i = 0; i < objArr.length; i++) {
            this._dimensionTypes.get(i).setDimensionValue(writeableTuple, i, objArr[i]);
        }
    }

    /** Returns a fresh array holding {@link DimensionType#getShape()} of each dimension, in order. */
    public int[] getShape() {
        int size = this._dimensionTypes.size();
        int[] shape = new int[size];
        for (int i = 0; i < size; i++) {
            shape[i] = this._dimensionTypes.get(i).getShape();
        }
        return shape;
    }

    /**
     * Returns the number of elements of a dense tensor of this type: the product of all dimension
     * sizes, or {@code -1} if any dimension's shape is {@code -1}. A type with no dimensions has a
     * dense size of {@code 1}.
     *
     * @throws ArithmeticException if the product of the dimension sizes overflows an {@code int}
     */
    public int getDenseSize() {
        int[] shape = getShape();
        boolean hasZero = false;
        // -1 takes precedence over everything, and a zero-sized dimension makes the product 0;
        // check both before multiplying so neither case can be preempted by an overflow.
        for (int dimensionSize : shape) {
            if (dimensionSize == -1) {
                return -1;
            }
            if (dimensionSize == 0) {
                hasZero = true;
            }
        }
        if (hasZero) {
            return 0;
        }
        int size = 1;
        for (int dimensionSize : shape) {
            size = Math.multiplyExact(size, dimensionSize);
        }
        return size;
    }

    /**
     * Two tensor types are equal when their category, value type, dimension names and dimension
     * types are all equal. The cached column types are not compared.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        TensorType other = (TensorType) obj;
        return Objects.equals(this._tensorCategory, other._tensorCategory)
                && Objects.equals(this._valueType, other._valueType)
                && Objects.equals(this._dimensionNames, other._dimensionNames)
                && Objects.equals(this._dimensionTypes, other._dimensionTypes);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this._tensorCategory, this._valueType, this._dimensionNames, this._dimensionTypes);
    }

    /** Renders the type as {@code TENSOR<category>[dim1][dim2]...:valueType}. */
    @Override
    public String toString() {
        String dimensions = getDimensionTypes().stream()
                .map(dimensionType -> "[" + dimensionType.toString() + "]")
                .collect(Collectors.joining());
        return "TENSOR<" + getTensorCategory() + ">" + dimensions + ":" + getValueType();
    }
}
