package hivemall.smile.utils;

import hivemall.math.matrix.ColumnMajorMatrix;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.MatrixUtils;
import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d;
import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.VectorProcedure;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.data.Attribute.NominalAttribute;
import hivemall.smile.data.Attribute.NumericAttribute;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import smile.data.Attribute;
import smile.math.Math;
import smile.sort.QuickSort;

/* loaded from: input_file:hivemall/smile/utils/SmileExtUtils.class */
public final class SmileExtUtils {

    /* renamed from: hivemall.smile.utils.SmileExtUtils$4, reason: invalid class name */
    /* loaded from: input_file:hivemall/smile/utils/SmileExtUtils$4.class */
    static /* synthetic */ class AnonymousClass4 {
        static final /* synthetic */ int[] $SwitchMap$smile$data$Attribute$Type = new int[Attribute.Type.values().length];

        static {
            try {
                $SwitchMap$smile$data$Attribute$Type[Attribute.Type.NOMINAL.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$smile$data$Attribute$Type[Attribute.Type.NUMERIC.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
        }
    }

    private SmileExtUtils() {
    }

    @Nullable
    public static hivemall.smile.data.Attribute[] resolveAttributes(@Nullable String str) throws UDFArgumentException {
        if (str == null) {
            return null;
        }
        String[] split = str.split(",");
        int length = split.length;
        Attribute.NumericAttribute numericAttribute = new Attribute.NumericAttribute();
        hivemall.smile.data.Attribute[] attributeArr = new hivemall.smile.data.Attribute[length];
        for (int i = 0; i < length; i++) {
            String str2 = split[i];
            if ("Q".equals(str2)) {
                attributeArr[i] = numericAttribute;
            } else {
                if (!"C".equals(str2)) {
                    throw new UDFArgumentException("Unexpected type: " + str2);
                }
                attributeArr[i] = new Attribute.NominalAttribute();
            }
        }
        return attributeArr;
    }

    @Nonnull
    public static hivemall.smile.data.Attribute[] attributeTypes(@Nullable final hivemall.smile.data.Attribute[] attributeArr, @Nonnull Matrix matrix) {
        int i;
        if (attributeArr == null) {
            hivemall.smile.data.Attribute[] attributeArr2 = new hivemall.smile.data.Attribute[matrix.numColumns()];
            Arrays.fill(attributeArr2, new Attribute.NumericAttribute());
            return attributeArr2;
        }
        if (matrix.isRowMajorMatrix()) {
            VectorProcedure vectorProcedure = new VectorProcedure() { // from class: hivemall.smile.utils.SmileExtUtils.1
                @Override // hivemall.math.vector.VectorProcedure
                public void apply(int i2, double d) {
                    int i3;
                    hivemall.smile.data.Attribute attribute = attributeArr[i2];
                    if (attribute.type != Attribute.AttributeType.NOMINAL || (i3 = ((int) d) + 1) <= attribute.getSize()) {
                        return;
                    }
                    attribute.setSize(i3);
                }
            };
            int numRows = matrix.numRows();
            for (int i2 = 0; i2 < numRows; i2++) {
                matrix.eachNonNullInRow(i2, vectorProcedure);
            }
        } else if (matrix.isColumnMajorMatrix()) {
            final MutableInt mutableInt = new MutableInt(0);
            VectorProcedure vectorProcedure2 = new VectorProcedure() { // from class: hivemall.smile.utils.SmileExtUtils.2
                @Override // hivemall.math.vector.VectorProcedure
                public void apply(int i3, double d) {
                    int i4 = (int) d;
                    if (i4 > MutableInt.this.getValue()) {
                        MutableInt.this.setValue(i4);
                    }
                }
            };
            int length = attributeArr.length;
            for (int i3 = 0; i3 < length; i3++) {
                hivemall.smile.data.Attribute attribute = attributeArr[i3];
                if (attribute.type == Attribute.AttributeType.NOMINAL && attribute.getSize() == -1) {
                    mutableInt.setValue(0);
                    matrix.eachNonNullInColumn(i3, vectorProcedure2);
                    attribute.setSize(mutableInt.getValue() + 1);
                }
            }
        } else {
            int length2 = attributeArr.length;
            for (int i4 = 0; i4 < length2; i4++) {
                hivemall.smile.data.Attribute attribute2 = attributeArr[i4];
                if (attribute2.type == Attribute.AttributeType.NOMINAL && attribute2.getSize() == -1) {
                    int i5 = 0;
                    int numRows2 = matrix.numRows();
                    for (int i6 = 0; i6 < numRows2; i6++) {
                        double d = matrix.get(i6, i4, Double.NaN);
                        if (!Double.isNaN(d) && (i = (int) d) > i5) {
                            i5 = i;
                        }
                    }
                    attribute2.setSize(i5 + 1);
                }
            }
        }
        return attributeArr;
    }

    @Nonnull
    public static hivemall.smile.data.Attribute[] convertAttributeTypes(@Nonnull smile.data.Attribute[] attributeArr) {
        int length = attributeArr.length;
        Attribute.NumericAttribute numericAttribute = new Attribute.NumericAttribute();
        hivemall.smile.data.Attribute[] attributeArr2 = new hivemall.smile.data.Attribute[length];
        for (int i = 0; i < length; i++) {
            smile.data.Attribute attribute = attributeArr[i];
            switch (AnonymousClass4.$SwitchMap$smile$data$Attribute$Type[attribute.type.ordinal()]) {
                case SizeOf.BYTE /* 1 */:
                    attributeArr2[i] = new Attribute.NominalAttribute();
                    break;
                case 2:
                    attributeArr2[i] = numericAttribute;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported type: " + attribute.type);
            }
        }
        return attributeArr2;
    }

    /* JADX WARN: Type inference failed for: r0v5, types: [int[], int[][]] */
    @Nonnull
    public static ColumnMajorIntMatrix sort(@Nonnull hivemall.smile.data.Attribute[] attributeArr, @Nonnull Matrix matrix) {
        int numRows = matrix.numRows();
        int numColumns = matrix.numColumns();
        ?? r0 = new int[numColumns];
        if (matrix.isSparse()) {
            int i = numRows / 10;
            final DoubleArrayList doubleArrayList = new DoubleArrayList(i);
            final IntArrayList intArrayList = new IntArrayList(i);
            VectorProcedure vectorProcedure = new VectorProcedure() { // from class: hivemall.smile.utils.SmileExtUtils.3
                @Override // hivemall.math.vector.VectorProcedure
                public void apply(int i2, double d) {
                    DoubleArrayList.this.add(d);
                    intArrayList.add(i2);
                }
            };
            ColumnMajorMatrix columnMajorMatrix = matrix.toColumnMajorMatrix();
            for (int i2 = 0; i2 < numColumns; i2++) {
                if (attributeArr[i2].type == Attribute.AttributeType.NUMERIC) {
                    columnMajorMatrix.eachNonNullInColumn(i2, vectorProcedure);
                    if (!intArrayList.isEmpty()) {
                        int[] array = intArrayList.toArray();
                        QuickSort.sort(doubleArrayList.array(), array, array.length);
                        r0[i2] = array;
                        doubleArrayList.clear();
                        intArrayList.clear();
                    }
                }
            }
        } else {
            double[] dArr = new double[numRows];
            for (int i3 = 0; i3 < numColumns; i3++) {
                if (attributeArr[i3].type == Attribute.AttributeType.NUMERIC) {
                    for (int i4 = 0; i4 < numRows; i4++) {
                        dArr[i4] = matrix.get(i4, i3);
                    }
                    r0[i3] = QuickSort.sort(dArr);
                }
            }
        }
        return new ColumnMajorDenseIntMatrix2d(r0, numRows);
    }

    @Nonnull
    public static int[] classLabels(@Nonnull int[] iArr) throws HiveException {
        int[] unique = Math.unique(iArr);
        Arrays.sort(unique);
        if (unique.length < 2) {
            throw new HiveException("Only one class.");
        }
        for (int i = 0; i < unique.length; i++) {
            if (unique[i] < 0) {
                throw new HiveException("Negative class label: " + unique[i]);
            }
            if (i > 0 && unique[i] - unique[i - 1] > 1) {
                throw new HiveException("Missing class: " + (unique[i - 1] + 1));
            }
        }
        return unique;
    }

    @Nonnull
    public static DecisionTree.SplitRule resolveSplitRule(@Nullable String str) {
        return "gini".equalsIgnoreCase(str) ? DecisionTree.SplitRule.GINI : "entropy".equalsIgnoreCase(str) ? DecisionTree.SplitRule.ENTROPY : "classification_error".equalsIgnoreCase(str) ? DecisionTree.SplitRule.CLASSIFICATION_ERROR : DecisionTree.SplitRule.GINI;
    }

    public static int computeNumInputVars(float f, @Nonnull Matrix matrix) {
        return f <= 0.0f ? (int) Math.ceil(Math.sqrt(matrix.numColumns())) : (f <= 0.0f || f > 1.0f) ? (int) f : (int) (f * matrix.numColumns());
    }

    public static long generateSeed() {
        return Thread.currentThread().getId() * System.nanoTime();
    }

    public static void shuffle(@Nonnull int[] iArr, @Nonnull PRNG prng) {
        for (int length = iArr.length; length > 1; length--) {
            swap(iArr, length - 1, prng.nextInt(length));
        }
    }

    @Nonnull
    public static Matrix shuffle(@Nonnull Matrix matrix, @Nonnull int[] iArr, long j) {
        int numRows = matrix.numRows();
        if (numRows != iArr.length) {
            throw new IllegalArgumentException("x.length (" + numRows + ") != y.length (" + iArr.length + ')');
        }
        if (j == -1) {
            j = generateSeed();
        }
        PRNG createPRNG = RandomNumberGeneratorFactory.createPRNG(j);
        if (matrix.swappable()) {
            for (int i = numRows; i > 1; i--) {
                int nextInt = createPRNG.nextInt(i);
                int i2 = i - 1;
                matrix.swap(i2, nextInt);
                swap(iArr, i2, nextInt);
            }
            return matrix;
        }
        int[] permutation = MathUtils.permutation(numRows);
        for (int i3 = numRows; i3 > 1; i3--) {
            int nextInt2 = createPRNG.nextInt(i3);
            int i4 = i3 - 1;
            swap(permutation, i4, nextInt2);
            swap(iArr, i4, nextInt2);
        }
        return MatrixUtils.shuffle(matrix, permutation);
    }

    @Nonnull
    public static Matrix shuffle(@Nonnull Matrix matrix, @Nonnull double[] dArr, @Nonnull long j) {
        int numRows = matrix.numRows();
        if (numRows != dArr.length) {
            throw new IllegalArgumentException("x.length (" + numRows + ") != y.length (" + dArr.length + ')');
        }
        if (j == -1) {
            j = generateSeed();
        }
        PRNG createPRNG = RandomNumberGeneratorFactory.createPRNG(j);
        if (matrix.swappable()) {
            for (int i = numRows; i > 1; i--) {
                int nextInt = createPRNG.nextInt(i);
                int i2 = i - 1;
                matrix.swap(i2, nextInt);
                swap(dArr, i2, nextInt);
            }
            return matrix;
        }
        int[] permutation = MathUtils.permutation(numRows);
        for (int i3 = numRows; i3 > 1; i3--) {
            int nextInt2 = createPRNG.nextInt(i3);
            int i4 = i3 - 1;
            swap(permutation, i4, nextInt2);
            swap(dArr, i4, nextInt2);
        }
        return MatrixUtils.shuffle(matrix, permutation);
    }

    private static void swap(int[] iArr, int i, int i2) {
        int i3 = iArr[i];
        iArr[i] = iArr[i2];
        iArr[i2] = i3;
    }

    private static void swap(double[] dArr, int i, int i2) {
        double d = dArr[i];
        dArr[i] = dArr[i2];
        dArr[i2] = d;
    }

    @Nonnull
    public static int[] bagsToSamples(@Nonnull int[] iArr) {
        int i = -1;
        for (int i2 : iArr) {
            if (i2 > i) {
                i = i2;
            }
        }
        return bagsToSamples(iArr, i + 1);
    }

    @Nonnull
    public static int[] bagsToSamples(@Nonnull int[] iArr, int i) {
        int[] iArr2 = new int[i];
        for (int i2 : iArr) {
            iArr2[i2] = iArr2[i2] + 1;
        }
        return iArr2;
    }

    public static boolean containsNumericType(@Nonnull hivemall.smile.data.Attribute[] attributeArr) {
        for (hivemall.smile.data.Attribute attribute : attributeArr) {
            if (attribute.type == Attribute.AttributeType.NUMERIC) {
                return true;
            }
        }
        return false;
    }

    @Nonnull
    public static String resolveFeatureName(int i, @Nullable String[] strArr) {
        if (strArr != null && i < strArr.length) {
            return strArr[i];
        }
        return "feature#" + i;
    }

    @Nonnull
    public static String resolveName(int i, @Nullable String[] strArr) {
        if (strArr != null && i < strArr.length) {
            return strArr[i];
        }
        return String.valueOf(i);
    }

    public static double[] getColorBrew(@Nonnegative int i) {
        Preconditions.checkArgument(i >= 1);
        double d = 360.0d / i;
        double[] dArr = new double[i];
        for (int i2 = 0; i2 < i; i2++) {
            dArr[i2] = (i2 * d) / 360.0d;
        }
        return dArr;
    }
}
