package com.linkedin.feathr.common.tensor;

import com.google.common.collect.Lists;
import com.linkedin.feathr.common.types.PrimitiveType;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:com/linkedin/feathr/common/tensor/TensorTypes.class */
/**
 * Static helpers for building {@link TensorType} instances, either by parsing the textual
 * form (for example {@code TENSOR<DENSE>[INT(3)][STRING]:FLOAT}) or by inspecting the column
 * types of existing tensor data.
 *
 * <p>This class is not instantiable.
 */
public final class TensorTypes {
    /**
     * Primitive names allowed inside a {@code [..]} dimension. Shared by both patterns below so
     * that every dimension accepted by {@link #TENSOR_TYPE} is also captured by
     * {@link #CAPTURING_DIMENSION_TYPE}; previously {@code BYTES} was accepted by the former but
     * silently skipped by the latter, yielding a type with too few dimensions.
     */
    private static final String DIMENSION_BASES = "INT|LONG|STRING|BYTES";

    /** Matches a single {@code [BASE]} or {@code [BASE(shape)]} dimension and captures its parts. */
    private static final Pattern CAPTURING_DIMENSION_TYPE =
            Pattern.compile("\\[(?<base>" + DIMENSION_BASES + ")(?:\\((?<shape>\\d+)\\))?]");

    /** Matches a complete tensor type string: category, zero or more dimensions, value type. */
    private static final Pattern TENSOR_TYPE = Pattern.compile(
            "TENSOR<(?<category>SPARSE|DENSE|RAGGED)>"
                    + "(?<dimensions>(?:\\[(?:" + DIMENSION_BASES + ")(?:\\(\\d+\\))?])*)"
                    + ":(?<value>INT|LONG|FLOAT|DOUBLE|STRING|BOOLEAN|BYTES)");

    private TensorTypes() {
    }

    /**
     * Parses the textual representation of a tensor type.
     *
     * <p>The expected syntax is {@code TENSOR<CATEGORY>[DIM]...[DIM]:VALUE}, where
     * {@code CATEGORY} is one of {@code SPARSE}, {@code DENSE}, {@code RAGGED}; each {@code DIM}
     * is a dimension primitive optionally followed by a decimal shape in parentheses; and
     * {@code VALUE} is the value primitive. The whole input must match.
     *
     * @param typeString the string to parse
     * @return the parsed tensor type
     * @throws IllegalArgumentException if {@code typeString} does not match the syntax above
     *     (this includes {@link NumberFormatException} for a shape too large for an {@code int})
     */
    public static TensorType parseTensorType(String typeString) {
        Matcher matcher = TENSOR_TYPE.matcher(typeString);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Not a valid tensor type: " + typeString);
        }
        TensorCategory category = TensorCategory.valueOf(matcher.group("category"));
        PrimitiveType valueType = new PrimitiveType(Primitive.valueOf(matcher.group("value")));
        List<DimensionType> dimensionTypes = parseDimensions(matcher.group("dimensions"));
        // The fourth constructor argument is passed as null, exactly as before.
        // NOTE(review): presumably this is the optional list of dimension names - confirm
        // against the TensorType constructor.
        return new TensorType(category, valueType, dimensionTypes, null);
    }

    /**
     * Builds a tensor type from an array of column types in which the last element describes
     * the value and every preceding element describes one dimension.
     *
     * @param isSparse {@code true} to produce a {@code SPARSE} type, {@code false} for
     *     {@code DENSE}
     * @param types the column types, dimensions first and value last; must contain at least
     *     one element
     * @return the corresponding tensor type
     * @throws IllegalArgumentException if {@code types} is empty
     */
    public static TensorType fromRepresentables(boolean isSparse, Representable[] types) {
        if (types.length == 0) {
            // Without this guard the lookup of the value column below would fail with an
            // uninformative ArrayIndexOutOfBoundsException for index -1.
            throw new IllegalArgumentException(
                    "At least one representable (the value type) is required to build a tensor type");
        }
        int dimensionCount = types.length - 1;
        PrimitiveType valueType = new PrimitiveType(types[dimensionCount].getRepresentation());
        List<DimensionType> dimensionTypes = Lists.newArrayListWithCapacity(dimensionCount);
        for (int i = 0; i < dimensionCount; i++) {
            dimensionTypes.add(new PrimitiveDimensionType(types[i].getRepresentation()));
        }
        TensorCategory category = isSparse ? TensorCategory.SPARSE : TensorCategory.DENSE;
        return new TensorType(category, valueType, dimensionTypes);
    }

    /**
     * Derives the tensor type of the given data from its column types.
     *
     * <p>Data that is not a {@link DenseTensor} is classified as sparse.
     *
     * @param tensorData the data to describe
     * @return the tensor type built from {@code tensorData.getTypes()}
     */
    public static TensorType fromTensorData(TensorData tensorData) {
        boolean isSparse = !(tensorData instanceof DenseTensor);
        return fromRepresentables(isSparse, tensorData.getTypes());
    }

    /**
     * Converts the concatenated {@code [..]} dimension section of a tensor type string into
     * dimension types, in order of appearance.
     *
     * @param dimensions the dimension section, possibly empty or {@code null}
     * @return a new mutable list with one entry per dimension; empty if there are none
     */
    private static List<DimensionType> parseDimensions(String dimensions) {
        List<DimensionType> dimensionTypes = new ArrayList<>();
        if (dimensions == null) {
            return dimensionTypes;
        }
        Matcher matcher = CAPTURING_DIMENSION_TYPE.matcher(dimensions);
        while (matcher.find()) {
            PrimitiveDimensionType dimensionType =
                    new PrimitiveDimensionType(Primitive.valueOf(matcher.group("base")));
            // The shape group is absent (null) when the dimension has no "(n)" suffix.
            String shape = matcher.group("shape");
            dimensionTypes.add(
                    shape == null ? dimensionType : dimensionType.withShape(Integer.parseInt(shape)));
        }
        return dimensionTypes;
    }
}
