package hivemall.smile.regression;

import hivemall.annotations.VisibleForTesting;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.VariableOrder;
import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.function.Consumer;
import hivemall.utils.function.IntPredicate;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import hivemall.utils.sampling.IntReservoirSampler;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.Matrix;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import matrix4j.vector.VectorProcedure;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.IntConsumer;
import org.roaringbitmap.RoaringBitmap;
import smile.math.Math;
import smile.regression.Regression;

/* loaded from: input_file:hivemall/smile/regression/RegressionTree.class */
public final class RegressionTree implements Regression<Vector> {
    private static final Log logger = LogFactory.getLog(RegressionTree.class);
    /** Training feature matrix, indexed as {@code _X.get(row, feature, default)}. */
    private final Matrix _X;
    /** Training response for each row of {@link #_X}. */
    private final double[] _y;

    /**
     * Per-row sample count. Rows whose count is 0 are skipped during training, and a child's
     * sample total is the sum of these counts over its rows.
     * NOTE(review): presumably bootstrap multiplicities — confirm against the constructor.
     */
    @Nonnull
    private final int[] _samples;

    /**
     * Per-feature row orderings used by the threshold-split scan. Partitioned in step with
     * {@link #_sampleIndex} on every split so that a node's rows stay within its {@code [low, high)}
     * range.
     */
    @Nonnull
    private final VariableOrder _order;

    /**
     * Maps positions to row indices. It is reordered on every split so that the rows of each
     * training node occupy a contiguous {@code [low, high)} range of positions.
     */
    @Nonnull
    private final int[] _sampleIndex;

    /** Features treated as nominal (split by equality); all other features are split by threshold. */
    @Nonnull
    private final RoaringBitmap _nominalAttrs;
    /** Per-feature importance: incremented by the score of every split that is kept. */
    private final Vector _importance;
    /** Presumably the root of the trained tree (assigned outside the visible code). */
    private final Node _root;
    /** A node whose depth is at least this value is not split further. */
    private final int _maxDepth;
    /**
     * A node needs more than this many samples to be split, and each side of a candidate split
     * must get at least this many.
     */
    private final int _minSamplesSplit;
    /** A split is discarded if either child would receive fewer samples than this. */
    private final int _minSamplesLeaf;
    /** Reservoir size of the candidate features drawn for every node. */
    private final int _numVars;
    /** Seeds the per-node feature sampler. */
    private final PRNG _rnd;

    /**
     * A node of a trained regression tree.
     *
     * <p>
     * A leaf (both children {@code null}) predicts {@link #output}. An internal node sends an input
     * to {@link #trueChild} when {@code x[splitFeature] <= splitValue} (quantitative feature) or
     * {@code x[splitFeature] == splitValue} (nominal feature). Otherwise the input goes to
     * {@link #falseChild}. A missing value is read as NaN, fails both comparisons, and so always
     * takes the false branch.
     */
    public static final class Node implements Externalizable {
        /** Value predicted when this node is a leaf; only leaves serialize it. */
        double output;
        /** Index of the tested feature, or -1 when no split is set. */
        int splitFeature;
        /** {@code true} for a {@code <=} threshold split, {@code false} for an equality split. */
        boolean quantitativeFeature;
        /** Threshold or category value of the split; NaN when no split is set. */
        double splitValue;
        /** Score of the chosen split (0 when none); a higher score is preferred. */
        double splitScore;
        /** Child taken when the split test holds. */
        Node trueChild;
        /** Child taken when the split test fails (including missing values). */
        Node falseChild;
        /** Initial output given to the true child when this node is split. */
        double trueChildOutput;
        /** Initial output given to the false child when this node is split. */
        double falseChildOutput;

        /** Creates a split-less node with output 0 (also the no-arg constructor Externalizable needs). */
        public Node() {
            this(0.0d);
        }

        /**
         * Creates a node without a split.
         *
         * @param output the value predicted while this node is a leaf
         */
        public Node(double output) {
            this.output = output;
            this.splitFeature = -1;
            this.quantitativeFeature = true;
            this.splitValue = Double.NaN;
            this.splitScore = 0.0d;
            this.trueChildOutput = 0.0d;
            this.falseChildOutput = 0.0d;
        }

        /** Returns whether this node has no children. */
        public boolean isLeaf() {
            return this.trueChild == null && this.falseChild == null;
        }

        /** Drops the split and both children, turning this node into a leaf predicting {@link #output}. */
        public void markAsLeaf() {
            this.splitFeature = -1;
            this.splitValue = Double.NaN;
            this.splitScore = 0.0d;
            this.trueChild = null;
            this.falseChild = null;
        }

        /** Returns whether feature value {@code v} passes this internal node's split test. */
        private boolean routesToTrueChild(double v) {
            return quantitativeFeature ? v <= splitValue : v == splitValue;
        }

        /** Predicts the response for a dense feature array. */
        @VisibleForTesting
        public double predict(@Nonnull double[] x) {
            return predict((Vector) new DenseVector(x));
        }

        /**
         * Predicts the response for {@code x} by walking from this node down to a leaf.
         * The walk is iterative, so very deep trees cannot overflow the call stack.
         */
        public double predict(@Nonnull Vector x) {
            Node current = this;
            while (!current.isLeaf()) {
                final double value = x.get(current.splitFeature, Double.NaN);
                current = current.routesToTrueChild(value) ? current.trueChild : current.falseChild;
            }
            return current.output;
        }

        /**
         * Predicts the response for {@code x} and reports the whole decision path to
         * {@code handler}. Every branch taken is reported through
         * {@link PredictionHandler#visitBranch}, followed by the reached leaf through
         * {@link PredictionHandler#visitLeaf}.
         *
         * @return the output of the reached leaf
         */
        public double predict(@Nonnull Vector x, @Nonnull PredictionHandler handler) {
            Node current = this;
            while (!current.isLeaf()) {
                final double value = x.get(current.splitFeature, Double.NaN);
                final boolean goesTrue = current.routesToTrueChild(value);
                final PredictionHandler.Operator op;
                if (current.quantitativeFeature) {
                    op = goesTrue ? PredictionHandler.Operator.LE : PredictionHandler.Operator.GT;
                } else {
                    op = goesTrue ? PredictionHandler.Operator.EQ : PredictionHandler.Operator.NE;
                }
                handler.visitBranch(op, current.splitFeature, value, current.splitValue);
                // Keep reporting below the first branch: the handler must see the full path.
                current = goesTrue ? current.trueChild : current.falseChild;
            }
            handler.visitLeaf(current.output);
            return current.output;
        }

        /**
         * Predicts using integer features only. Every split, including quantitative ones, is
         * treated as an equality test against {@code (int) splitValue}.
         */
        public double predict(int[] x) {
            Node current = this;
            while (!current.isLeaf()) {
                current = x[current.splitFeature] == (int) current.splitValue ? current.trueChild
                        : current.falseChild;
            }
            return current.output;
        }

        /**
         * Appends this subtree as nested JavaScript {@code if/else} statements.
         *
         * @param builder output buffer
         * @param featureNames optional feature names; when {@code null}, features are written as
         *        {@code x[index]}
         * @param depth indentation level of this node
         */
        public void exportJavascript(@Nonnull StringBuilder builder, @Nullable String[] featureNames,
                int depth) {
            RegressionTree.indent(builder, depth);
            if (isLeaf()) {
                builder.append(output).append(";\n");
                return;
            }
            final String op = quantitativeFeature ? " <= " : " == ";
            if (featureNames == null) {
                builder.append("if( x[").append(splitFeature).append(']');
            } else {
                builder.append("if( ")
                       .append(SmileExtUtils.resolveFeatureName(splitFeature, featureNames));
            }
            builder.append(op).append(splitValue).append(") {\n");
            trueChild.exportJavascript(builder, featureNames, depth + 1);
            RegressionTree.indent(builder, depth);
            builder.append("} else {\n");
            falseChild.exportJavascript(builder, featureNames, depth + 1);
            RegressionTree.indent(builder, depth);
            builder.append("}\n");
        }

        /**
         * Appends this subtree as Graphviz DOT node and edge statements, numbering nodes in
         * pre-order.
         *
         * @param builder output buffer
         * @param featureNames optional feature names used in split labels
         * @param outputName label used for leaf values ({@code outputName = value})
         * @param nodeIdSeq id counter; its current value is this node's id and it is advanced once
         *        before each child
         * @param parentId id of the parent node (equal to this node's id for the root, which then
         *        gets no incoming edge)
         */
        public void exportGraphviz(@Nonnull StringBuilder builder, @Nullable String[] featureNames,
                @Nonnull String outputName, @Nonnull MutableInt nodeIdSeq, int parentId) {
            final int id = nodeIdSeq.getValue();
            if (isLeaf()) {
                builder.append(String.format(
                    " %d [label=<%s = %s>, fillcolor=\"#00000000\", shape=ellipse];\n",
                    Integer.valueOf(id), outputName, Double.toString(output)));
                appendGraphvizEdge(builder, parentId, id);
                return;
            }
            final String labelFormat = quantitativeFeature
                    ? " %d [label=<%s &le; %s>, fillcolor=\"#00000000\"];\n"
                    : " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n";
            builder.append(String.format(labelFormat, Integer.valueOf(id),
                SmileExtUtils.resolveFeatureName(splitFeature, featureNames),
                Double.toString(splitValue)));
            appendGraphvizEdge(builder, parentId, id);
            nodeIdSeq.addValue(1);
            trueChild.exportGraphviz(builder, featureNames, outputName, nodeIdSeq, id);
            nodeIdSeq.addValue(1);
            falseChild.exportGraphviz(builder, featureNames, outputName, nodeIdSeq, id);
        }

        /**
         * Appends the DOT edge {@code parentId -> id}, or nothing when both ids are equal (root).
         * Edges leaving node 0 are labelled: child 1 is "True", any other child is "False".
         */
        private static void appendGraphvizEdge(@Nonnull StringBuilder builder, int parentId, int id) {
            if (id == parentId) {
                return;
            }
            builder.append(' ').append(parentId).append(" -> ").append(id);
            if (parentId == 0) {
                if (id == 1) {
                    builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
                } else {
                    builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                }
            }
            builder.append(";\n");
        }

        /**
         * Emits a stack-machine style listing for this subtree.
         *
         * <p>
         * An internal node emits {@code push x[f]}, {@code push v} and a conditional jump
         * ({@code ifle}/{@code ifeq}). The jump target is the first instruction of the false
         * branch.
         *
         * @param code instruction list to append to
         * @param startPos index at which this subtree's first instruction lands. It is expected to
         *        equal {@code code.size()} on entry, because the jump placeholder is patched by index.
         * @return number of instructions emitted for this subtree
         */
        @Deprecated
        public int opCodegen(@Nonnull List<String> code, int startPos) {
            if (isLeaf()) {
                code.add("push " + output);
                code.add("goto last");
                return 2;
            }
            final String jump = quantitativeFeature ? "ifle" : "ifeq";
            code.add("push x[" + splitFeature + "]");
            code.add("push " + splitValue);
            code.add(jump + " "); // placeholder; patched once the true-branch size is known
            final int truePos = startPos + 3;
            final int trueSize = trueChild.opCodegen(code, truePos);
            code.set(truePos - 1, jump + " " + (truePos + trueSize));
            final int falseSize = falseChild.opCodegen(code, truePos + trueSize);
            return 3 + trueSize + falseSize;
        }

        /**
         * Serializes this subtree in pre-order:
         * {@code int splitFeature, byte (1 = quantitative, 2 = nominal), double splitValue,
         * boolean isLeaf}. This is followed either by {@code double output} (leaf) or, for each
         * child, a presence flag and the child's own record.
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeInt(splitFeature);
            out.writeByte(quantitativeFeature ? 1 : 2);
            out.writeDouble(splitValue);
            if (isLeaf()) {
                out.writeBoolean(true);
                out.writeDouble(output);
                return;
            }
            out.writeBoolean(false);
            writeChild(out, trueChild);
            writeChild(out, falseChild);
        }

        /** Writes a presence flag followed by the child record when present. */
        private static void writeChild(@Nonnull ObjectOutput out, @Nullable Node child)
                throws IOException {
            if (child == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                child.writeExternal(out);
            }
        }

        /** Reads a subtree in the format written by {@link #writeExternal(ObjectOutput)}. */
        @Override
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            splitFeature = in.readInt();
            quantitativeFeature = in.readByte() == 1;
            splitValue = in.readDouble();
            if (in.readBoolean()) {
                output = in.readDouble();
                return;
            }
            if (in.readBoolean()) {
                trueChild = new Node();
                trueChild.readExternal(in);
            }
            if (in.readBoolean()) {
                falseChild = new Node();
                falseChild.readExternal(in);
            }
        }
    }

    /**
     * Strategy that computes a leaf's output from the training rows that reach it. Leaf outputs
     * are overwritten with its result after the tree is grown.
     */
    public interface NodeOutput {
        /**
         * Computes the value a leaf should predict.
         *
         * @param samples indices of the training rows in the leaf (rows with a non-zero sample count)
         * @return the leaf output
         */
        double calculate(int[] samples);
    }

    /**
     * Training-time state for one {@link Node}.
     *
     * <p>
     * The rows reaching this node occupy positions {@code [low, high)} of {@code _sampleIndex}.
     * {@link #split} partitions that range, and every ordering in {@code _order}, so that each
     * child again owns a contiguous sub-range.
     */
    private final class TrainNode implements Comparable<TrainNode> {

        /** The tree node being grown. */
        @Nonnull
        final Node node;
        /** Depth of this node; no split is attempted once it reaches {@code _maxDepth}. */
        final int depth;
        /** First position (inclusive) of this node's rows in {@code _sampleIndex}. */
        final int low;
        /** Last position (exclusive) of this node's rows in {@code _sampleIndex}. */
        final int high;
        /** Weighted sample count of this node (for children: the sum of {@code _samples} over their rows). */
        final int samples;

        /** Training state of the true child; set by {@link #split}. */
        @Nullable
        TrainNode trueChild;

        /** Training state of the false child; set by {@link #split}. */
        @Nullable
        TrainNode falseChild;

        /**
         * Features found constant within this node (grown via {@code ArrayUtils.sortedArraySet}).
         * They are skipped by {@link #findBestSplit()}, inherited by the children, and released
         * (set to null) once this node is split.
         */
        @Nullable
        int[] constFeatures;

        /**
         * Creates a node with no known constant features.
         *
         * @param regressionTree unused; kept for signature compatibility (the enclosing tree is
         *        {@code RegressionTree.this})
         */
        public TrainNode(@Nonnull RegressionTree regressionTree, Node node, int depth, int low,
                int high, int samples) {
            this(node, depth, low, high, samples, new int[0]);
        }

        /**
         * @throws IllegalArgumentException if the position range {@code [low, high)} is empty
         */
        public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples,
                @Nonnull int[] constFeatures) {
            if (low >= high) {
                throw new IllegalArgumentException(
                    "Unexpected condition was met. low=" + low + ", high=" + high);
            }
            this.node = node;
            this.depth = depth;
            this.low = low;
            this.high = high;
            this.samples = samples;
            this.constFeatures = constFeatures;
        }

        /**
         * Orders nodes by descending split score, so a {@link PriorityQueue} of training nodes
         * polls the most promising split first.
         */
        @Override
        public int compareTo(TrainNode other) {
            return (int) Math.signum(other.node.splitScore - this.node.splitScore);
        }

        /**
         * Recomputes the output of every leaf in this subtree with {@code nodeOutput}. Each leaf
         * receives the rows with a non-zero sample count that reach it.
         */
        public void calculateOutput(NodeOutput nodeOutput) {
            if (node.trueChild == null && node.falseChild == null) {
                node.output = nodeOutput.calculate(getSamples());
                return;
            }
            if (trueChild != null) {
                trueChild.calculateOutput(nodeOutput);
            }
            if (falseChild != null) {
                falseChild.calculateOutput(nodeOutput);
            }
        }

        /** Returns the row indices of this node whose sample count is positive. */
        @Nonnull
        private int[] getSamples() {
            final int[] sampleIndex = _sampleIndex;
            final int[] sampleCounts = _samples;
            final IntArrayList rows = new IntArrayList(high - low);
            for (int pos = low; pos < high; pos++) {
                final int row = sampleIndex[pos];
                if (sampleCounts[row] > 0) {
                    rows.add(row);
                }
            }
            return rows.toArray(true);
        }

        /**
         * Searches the sampled candidate features for the best split of this node and stores it in
         * {@link #node}.
         *
         * @return {@code true} if a split was found
         */
        public boolean findBestSplit() {
            if (depth >= _maxDepth || samples <= _minSamplesSplit) {
                return false;
            }
            // Snapshot: the per-feature search below may reassign constFeatures.
            final int[] knownConstant = constFeatures;
            // Response total, derived as mean * count (node.output * samples).
            final double sum = node.output * samples;
            for (int feature : variableIndex()) {
                if (ArrayUtils.contains(knownConstant, feature)) {
                    continue;
                }
                final Node candidate = findBestSplit(samples, sum, feature);
                if (candidate.splitScore > node.splitScore) {
                    node.splitFeature = candidate.splitFeature;
                    node.quantitativeFeature = candidate.quantitativeFeature;
                    node.splitValue = candidate.splitValue;
                    node.splitScore = candidate.splitScore;
                    node.trueChildOutput = candidate.trueChildOutput;
                    node.falseChildOutput = candidate.falseChildOutput;
                }
            }
            return node.splitFeature != -1;
        }

        /**
         * Draws the candidate features of this node with an {@link IntReservoirSampler} sized by
         * {@code _numVars}. For a sparse matrix, only columns that have an entry in at least one of
         * this node's rows are eligible. For a dense matrix, every column is eligible.
         */
        @Nonnull
        private int[] variableIndex() {
            final Matrix x = _X;
            final IntReservoirSampler sampler =
                    new IntReservoirSampler(_numVars, _rnd.nextLong());
            if (x.isSparse()) {
                final RoaringBitmap presentColumns = new RoaringBitmap();
                final VectorProcedure collectColumn = new VectorProcedure() {
                    public void apply(int col) {
                        presentColumns.add(col);
                    }
                };
                final int[] sampleIndex = _sampleIndex;
                for (int pos = low; pos < high; pos++) {
                    x.eachColumnIndexInRow(sampleIndex[pos], collectColumn);
                }
                presentColumns.forEach(new IntConsumer() {
                    public void accept(int col) {
                        sampler.add(col);
                    }
                });
            } else {
                final int numColumns = x.numColumns();
                for (int col = 0; col < numColumns; col++) {
                    sampler.add(col);
                }
            }
            return sampler.getSample();
        }

        /**
         * Finds the best split of this node on feature {@code j}. As a side effect, the feature is
         * recorded in {@link #constFeatures} when it takes at most one distinct value here.
         *
         * @param n weighted sample count of this node
         * @param sum response total of this node ({@code node.output * n})
         * @param j candidate feature
         * @return a fresh node carrying the best split on {@code j}; its {@code splitFeature} stays
         *         -1 when no admissible split exists
         */
        private Node findBestSplit(final int n, final double sum, final int j) {
            final Node split = new Node(0.0d);
            if (_nominalAttrs.contains(j)) {
                findBestNominalSplit(n, sum, j, split);
            } else {
                findBestQuantitativeSplit(n, sum, j, split);
            }
            return split;
        }

        /**
         * Evaluates every {@code x[j] == category} split. The rows of that category go to the true
         * side; all other rows, including missing values, go to the false side.
         */
        private void findBestNominalSplit(final int n, final double sum, final int j,
                @Nonnull final Node split) {
            final int[] sampleCounts = _samples;
            final int[] sampleIndex = _sampleIndex;
            final Matrix x = _X;
            final double[] y = _y;

            final Int2DoubleOpenHashMap trueSum = new Int2DoubleOpenHashMap();
            final Int2IntOpenHashMap trueCount = new Int2IntOpenHashMap();
            int countNaN = 0;
            for (int pos = low; pos < high; pos++) {
                // _sampleIndex is partitioned during training: a position is NOT a row index.
                final int row = sampleIndex[pos];
                final int weight = sampleCounts[row];
                if (weight == 0) {
                    continue;
                }
                final double v = x.get(row, j, Double.NaN);
                if (Double.isNaN(v)) {
                    countNaN++;
                    continue;
                }
                final int category = (int) v;
                // Weight by the row's sample count, consistent with n, sum and the threshold scan.
                trueSum.addTo(category, weight * y[row]);
                trueCount.addTo(category, weight);
            }
            if (trueCount.size() + (countNaN == 0 ? 0 : 1) <= 1) {
                constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
            }

            final ObjectIterator<Int2IntMap.Entry> entries = trueCount.int2IntEntrySet().iterator();
            while (entries.hasNext()) {
                final Int2IntMap.Entry entry = entries.next();
                final int category = entry.getIntKey();
                final double tc = entry.getIntValue();
                final double fc = n - tc;
                if (tc < _minSamplesSplit || fc < _minSamplesSplit) {
                    continue;
                }
                final double categorySum = trueSum.get(category);
                final double trueMean = categorySum / tc;
                final double falseMean = (sum - categorySum) / fc;
                // NOTE(review): split.output is always 0 here (fresh Node), so the parent term
                // vanishes. The score is tc*mu_t^2 + fc*mu_f^2 rather than the variance reduction.
                // The argmax within a node is unaffected, but the magnitudes feed the priority
                // queue and _importance. Kept as-is for compatibility.
                final double gain = (tc * trueMean * trueMean + fc * falseMean * falseMean)
                        - n * split.output * split.output;
                if (gain > split.splitScore) {
                    split.splitFeature = j;
                    split.quantitativeFeature = false;
                    split.splitValue = category;
                    split.splitScore = gain;
                    split.trueChildOutput = trueMean;
                    split.falseChildOutput = falseMean;
                }
            }
        }

        /**
         * Scans feature {@code j} in the order given by {@code _order}, which is assumed ascending
         * by value because the midpoint thresholds rely on it. A {@code <=} threshold is evaluated
         * midway between every two consecutive distinct values.
         */
        private void findBestQuantitativeSplit(final int n, final double sum, final int j,
                @Nonnull final Node split) {
            final int[] sampleCounts = _samples;
            final MutableInt countNaN = new MutableInt(0);
            final MutableInt distinctValues = new MutableInt(0);
            _order.eachNonNullInColumn(j, low, high, new Consumer() {
                double trueSum = 0.0d;
                int trueCount = 0;
                double prevx = Double.NaN;
                double lastx = Double.NaN;

                @Override
                public void accept(int pos, int row) {
                    final int weight = sampleCounts[row];
                    if (weight == 0) {
                        return;
                    }
                    final double xj = RegressionTree.this._X.get(row, j, Double.NaN);
                    if (Double.isNaN(xj)) {
                        countNaN.incr();
                        return;
                    }
                    if (lastx != xj) {
                        lastx = xj;
                        distinctValues.incr();
                    }
                    final double yi = RegressionTree.this._y[row];
                    // A threshold can only sit between two different consecutive values.
                    if (!Double.isNaN(prevx) && xj != prevx) {
                        final double fc = n - trueCount;
                        final int minSplit = RegressionTree.this._minSamplesSplit;
                        if (trueCount >= minSplit && fc >= minSplit) {
                            final double trueMean = trueSum / trueCount;
                            final double falseMean = (sum - trueSum) / fc;
                            // NOTE(review): split.output is 0 here; see the nominal search.
                            final double gain =
                                    (trueCount * trueMean * trueMean + fc * falseMean * falseMean)
                                            - n * split.output * split.output;
                            if (gain > split.splitScore) {
                                split.splitFeature = j;
                                split.quantitativeFeature = true;
                                split.splitValue = (xj + prevx) / 2.0d;
                                split.splitScore = gain;
                                split.trueChildOutput = trueMean;
                                split.falseChildOutput = falseMean;
                            }
                        }
                    }
                    prevx = xj;
                    trueSum += weight * yi;
                    trueCount += weight;
                }
            });
            if (distinctValues.get() + (countNaN.get() == 0 ? 0 : 1) <= 1) {
                constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
            }
        }

        /**
         * Applies the split stored in {@link #node}, creating both children.
         *
         * @param queue when non-null, splittable children are enqueued (best-first growth);
         *        when null, they are split recursively (depth-first growth)
         * @return {@code false} if this node ended up as a leaf
         * @throws IllegalStateException if no split feature is set
         */
        public boolean split(@Nullable PriorityQueue<TrainNode> queue) {
            if (node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            final IntPredicate goesTrue = getPredicate();
            final MutableInt trueSamples = new MutableInt(0);
            final MutableInt falseSamples = new MutableInt(0);
            final int pivot = splitSamples(trueSamples, falseSamples, goesTrue);
            final int tc = trueSamples.get();
            final int fc = falseSamples.get();
            if (tc < _minSamplesLeaf || fc < _minSamplesLeaf) {
                node.markAsLeaf();
                return false;
            }
            partitionOrder(low, pivot, high, goesTrue);

            node.trueChild = new Node(node.trueChildOutput);
            trueChild = new TrainNode(node.trueChild, depth + 1, low, pivot, tc,
                constFeatures.clone());
            node.falseChild = new Node(node.falseChildOutput);
            falseChild = new TrainNode(node.falseChild, depth + 1, pivot, high, fc, constFeatures);
            constFeatures = null;

            int leafChildren = 0;
            if (!growChild(trueChild, tc, queue)) {
                leafChildren++;
            }
            if (!growChild(falseChild, fc, queue)) {
                leafChildren++;
            }
            // Two leaves predicting the same value make this split pointless.
            if (leafChildren == 2 && node.trueChild.output == node.falseChild.output) {
                node.markAsLeaf();
                return false;
            }
            _importance.incr(node.splitFeature, node.splitScore);
            return true;
        }

        /**
         * Finds the best split of {@code child}. The child is then enqueued (when {@code queue} is
         * non-null) or split recursively.
         *
         * @return {@code false} if the child stays a leaf
         */
        private boolean growChild(@Nonnull TrainNode child, int childSamples,
                @Nullable PriorityQueue<TrainNode> queue) {
            if (childSamples < _minSamplesSplit || !child.findBestSplit()) {
                return false;
            }
            if (queue != null) {
                queue.add(child);
                return true;
            }
            return child.split(null);
        }

        /**
         * Adds the sample counts of this node's rows to {@code trueSamples} or
         * {@code falseSamples}, depending on the predicate.
         *
         * @return the pivot position: {@code low} plus the number of rows accepted by the predicate
         */
        private int splitSamples(@Nonnull MutableInt trueSamples, @Nonnull MutableInt falseSamples,
                @Nonnull IntPredicate goesTrue) {
            final int[] sampleIndex = _sampleIndex;
            final int[] sampleCounts = _samples;
            int pivot = low;
            for (int pos = low; pos < high; pos++) {
                final int row = sampleIndex[pos];
                final int count = sampleCounts[row];
                if (goesTrue.test(row)) {
                    trueSamples.addValue(count);
                    pivot++;
                } else {
                    falseSamples.addValue(count);
                }
            }
            return pivot;
        }

        /** Partitions {@code [low, high)} of every ordering in {@code _order} and of {@code _sampleIndex} at {@code pivot}. */
        private void partitionOrder(final int from, final int pivot, final int to,
                @Nonnull final IntPredicate goesTrue) {
            final int[] buf = new int[to - pivot];
            _order.eachRow(new Consumer() {
                @Override
                public void accept(int i, @Nonnull SparseIntArray entries) {
                    RegressionTree.partitionArray(entries, from, pivot, to, goesTrue, buf);
                }
            });
            RegressionTree.partitionArray(_sampleIndex, from, pivot, to, goesTrue, buf);
        }

        /** Returns a predicate on row indices telling whether a row passes this node's split test. */
        @Nonnull
        private IntPredicate getPredicate() {
            final int feature = node.splitFeature;
            final double value = node.splitValue;
            if (node.quantitativeFeature) {
                return new IntPredicate() {
                    @Override
                    public boolean test(int row) {
                        return RegressionTree.this._X.get(row, feature, Double.NaN) <= value;
                    }
                };
            }
            return new IntPredicate() {
                @Override
                public boolean test(int row) {
                    return RegressionTree.this._X.get(row, feature, Double.NaN) == value;
                }
            };
        }
    }

    /**
     * Appends {@code depth} levels of two-space indentation to {@code sb}. A non-positive depth
     * appends nothing.
     */
    public static void indent(StringBuilder sb, int depth) {
        for (int remaining = depth; remaining > 0; remaining--) {
            sb.append("  ");
        }
    }

    /**
     * Partitions the entries of {@code sparseIntArray} whose keys lie in {@code [i, i3)}
     * (low, high) in place, using {@code i2} as the pivot position.
     *
     * <p>
     * Keys are positions and values are row indices; {@code intPredicate} is applied to the values.
     * Accepted entries are rewritten with keys {@code i, i+1, ...} and rejected entries with keys
     * {@code i2, i2+1, ...}. Each group keeps its relative order. Only membership in the
     * {@code [low, pivot)} / {@code [pivot, high)} ranges is kept, not the original positions.
     *
     * @param iArr scratch buffer for rejected values
     * @throws IndexOutOfBoundsException if more values are rejected than {@code iArr} can hold
     * @throws IllegalStateException if the rewritten entries do not end exactly at the original
     *         range end (should not happen)
     */
    public static void partitionArray(@Nonnull SparseIntArray sparseIntArray, int i, int i2, int i3, @Nonnull IntPredicate intPredicate, @Nonnull int[] iArr) {
        int[] keys = sparseIntArray.keys();
        int[] values = sparseIntArray.values();
        int size = sparseIntArray.size();
        // Index range of the entries whose keys fall in [low, high).
        // NOTE(review): assumes insertionPoint returns the first index with key >= the argument.
        int insertionPoint = ArrayUtils.insertionPoint(keys, size, i);
        int insertionPoint2 = ArrayUtils.insertionPoint(keys, size, i3);
        int i4 = insertionPoint; // next write index into keys/values
        int i5 = 0; // number of rejected values buffered in iArr
        int i6 = 0; // number of accepted values written so far
        for (int i7 = insertionPoint; i7 < insertionPoint2; i7++) {
            int i8 = values[i7];
            if (intPredicate.test(i8)) {
                // In-place compaction is safe: the write index i4 never passes the read index i7.
                keys[i4] = i + i6;
                values[i4] = i8;
                i4++;
                i6++;
            } else {
                if (i5 >= iArr.length) {
                    throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.size()=%d, buf.length=%d, i=%d, j=%d, k=%d", Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(i3), Integer.valueOf(sparseIntArray.size()), Integer.valueOf(iArr.length), Integer.valueOf(i7), Integer.valueOf(i6), Integer.valueOf(i5)));
                }
                int i9 = i5;
                i5++;
                iArr[i9] = i8;
            }
        }
        // Append the rejected values after the accepted ones, keyed from the pivot.
        for (int i10 = 0; i10 < i5; i10++) {
            keys[i4] = i2 + i10;
            values[i4] = iArr[i10];
            i4++;
        }
        if (i4 != insertionPoint2) {
            throw new IllegalStateException(String.format("pos=%d, startPos=%d, endPos=%d, k=%d", Integer.valueOf(i4), Integer.valueOf(insertionPoint), Integer.valueOf(insertionPoint2), Integer.valueOf(i5)));
        }
    }

    /**
     * Stably partitions {@code a[low, high)}. Elements accepted by {@code predicate} end up in
     * {@code [low, pivot)} and rejected ones in {@code [pivot, high)}, each group keeping its
     * relative order.
     *
     * @param buf scratch buffer receiving the rejected elements; needs room for {@code high - pivot}
     * @throws IndexOutOfBoundsException if {@code high} exceeds {@code a.length}, if {@code buf} is
     *         too small, or if the number of accepted elements is not {@code pivot - low}
     */
    public static void partitionArray(@Nonnull int[] a, int low, int pivot, int high,
            @Nonnull IntPredicate predicate, @Nonnull int[] buf) {
        int j = low; // next slot for an accepted element
        int k = 0; // rejected elements buffered so far
        for (int i = low; i < high; i++) {
            if (i >= a.length) {
                throw new IndexOutOfBoundsException(String.format(
                    "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low,
                    pivot, high, a.length, buf.length, i, j, k));
            }
            final int value = a[i];
            if (!predicate.test(value)) {
                if (k >= buf.length) {
                    throw new IndexOutOfBoundsException(String.format(
                        "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d",
                        low, pivot, high, a.length, buf.length, i, j, k));
                }
                buf[k++] = value;
                continue;
            }
            a[j++] = value;
        }
        final boolean consistent = k == high - pivot && j == pivot;
        if (!consistent) {
            throw new IndexOutOfBoundsException(String.format(
                "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, j=%d, k=%d", low, pivot,
                high, a.length, buf.length, j, k));
        }
        System.arraycopy(buf, 0, a, pivot, k);
    }

    /**
     * Walks the subtree rooted at {@code node} in post-order. Any internal node whose two
     * children are both leaves with exactly equal outputs has its children detached
     * (presumably turning it into a leaf — isLeaf() is not visible here). Its split score is
     * then subtracted from {@code importance} for its split feature.
     *
     * @param node root of the subtree to prune
     * @param importance variable-importance vector to adjust for removed splits
     */
    private static void pruneRedundantLeaves(@Nonnull Node node, @Nonnull Vector importance) {
        if (!node.isLeaf()) {
            // children first, so a collapse below can enable a collapse at this level
            pruneRedundantLeaves(node.trueChild, importance);
            pruneRedundantLeaves(node.falseChild, importance);

            final Node left = node.trueChild;
            final Node right = node.falseChild;
            final boolean redundant =
                    left.isLeaf() && right.isLeaf() && left.output == right.output;
            if (redundant) {
                node.trueChild = null;
                node.falseChild = null;
                importance.decr(node.splitFeature, node.splitScore);
            }
        }
    }

    /**
     * Trains a tree using every column of {@code x} as a split candidate.
     * <p>
     * Defaults: {@code maxDepth = Integer.MAX_VALUE}, {@code minSplits = 5},
     * {@code minLeafSize = 1}, every row weighted once, no NodeOutput, and a PRNG created by
     * RandomNumberGeneratorFactory.
     *
     * @param nominalAttrs nominal column indices, or {@code null} for none
     * @param x training matrix
     * @param y response values, one per row of {@code x}
     * @param maxLeafs maximum number of leaves
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
            @Nonnull double[] y, int maxLeafs) {
        this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null);
    }

    /**
     * Trains a tree using every column of {@code x} as a split candidate, with a
     * caller-supplied random number generator.
     * <p>
     * Defaults: {@code maxDepth = Integer.MAX_VALUE}, {@code minSplits = 5},
     * {@code minLeafSize = 1}, every row weighted once, and no NodeOutput.
     *
     * @param nominalAttrs nominal column indices, or {@code null} for none
     * @param x training matrix
     * @param y response values, one per row of {@code x}
     * @param maxLeafs maximum number of leaves
     * @param rand random number generator, or {@code null} to create a default one
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
            @Nonnull double[] y, int maxLeafs, @Nullable PRNG rand) {
        this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, rand);
    }

    /**
     * Trains a tree with explicit hyper-parameters and no NodeOutput callback.
     *
     * @param nominalAttrs nominal column indices, or {@code null} for none
     * @param x training matrix
     * @param y response values, one per row of {@code x}
     * @param numVars number of variables considered per split
     * @param maxDepth maximum tree depth
     * @param maxLeafs maximum number of leaves
     * @param minSplits minimum samples required to split a node
     * @param minLeafSize minimum samples per leaf
     * @param samples per-row integer weights, or {@code null} to weight every row once
     * @param rand random number generator, or {@code null} to create a default one
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
            int minLeafSize, @Nullable int[] samples, @Nullable PRNG rand) {
        this(nominalAttrs, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, samples,
            null, rand);
    }

    /**
     * Trains a regression tree on {@code x} / {@code y}.
     *
     * @param nominalAttrs nominal column indices, or {@code null} for none (an empty bitmap is
     *        substituted)
     * @param x training matrix; must have {@code y.length} rows
     * @param y response values
     * @param numVars number of variables considered per split; in {@code [1, x.numColumns()]}
     * @param maxDepth maximum tree depth; at least 2
     * @param maxLeafs maximum number of leaves; at least 2. {@code Integer.MAX_VALUE} splits the
     *        root directly via {@code split(null)}. Any other value grows the tree best-first
     *        from a priority queue and then prunes redundant leaves.
     * @param minSplits minimum samples required to split a node; raised (and logged) to
     *        {@code 2 * minLeafSize} when smaller
     * @param minLeafSize minimum samples per leaf; at least 1
     * @param samples per-row integer weights (row {@code r} counts {@code samples[r]} times;
     *        zero-weight rows are excluded), or {@code null} to weight every row once.
     *        Presumably bootstrap counts — confirm with callers.
     * @param output if non-null, passed to {@code calculateOutput} on the root training node
     *        after growing
     * @param rand random number generator, or {@code null} to create one via
     *        RandomNumberGeneratorFactory
     * @throws IllegalArgumentException if any argument fails {@code checkArgument}
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
            int minLeafSize, @Nullable int[] samples, @Nullable NodeOutput output,
            @Nullable PRNG rand) {
        checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);

        this._X = x;
        this._y = y;
        if (nominalAttrs == null) {
            nominalAttrs = new RoaringBitmap();
        }
        this._nominalAttrs = nominalAttrs;
        this._numVars = numVars;
        this._maxDepth = maxDepth;

        // a split needs room for two leaves of at least minLeafSize samples each
        final int minRequiredSplits = minLeafSize * 2;
        if (minSplits < minRequiredSplits) {
            if (logger.isInfoEnabled()) {
                logger.info(String.format(
                    "min_sample_leaf = %d replaces min_sample_split = %d with min_sample_split = %d",
                    minLeafSize, minSplits, minRequiredSplits));
            }
            minSplits = minRequiredSplits;
        }
        this._minSamplesSplit = minSplits;
        this._minSamplesLeaf = minLeafSize;

        if (x.isSparse()) {
            this._importance = new SparseVector();
        } else {
            this._importance = new DenseVector(x.numColumns());
        }
        if (rand == null) {
            rand = RandomNumberGeneratorFactory.createPRNG();
        }
        this._rnd = rand;

        // total (weighted) sample count and weighted response sum for the root node
        int numSamples = 0;
        double weightedSum = 0.0d;
        final int[] sampleIndex;
        if (samples == null) {
            numSamples = y.length;
            samples = new int[numSamples];
            sampleIndex = new int[numSamples];
            for (int row = 0; row < numSamples; row++) {
                samples[row] = 1;
                weightedSum += y[row];
                sampleIndex[row] = row;
            }
        } else {
            final IntArrayList positions = new IntArrayList(0);
            for (int row = 0, numRows = y.length; row < numRows; row++) {
                final int count = samples[row];
                if (count == 0) {
                    continue;
                }
                numSamples += count;
                weightedSum += count * y[row];
                positions.add(row);
            }
            sampleIndex = positions.toArray(true);
        }
        this._samples = samples;
        this._order = SmileExtUtils.sort(nominalAttrs, x, samples);
        this._sampleIndex = sampleIndex;
        this._root = new Node(weightedSum / numSamples);

        final TrainNode trainRoot =
                new TrainNode(this, this._root, 1, 0, this._sampleIndex.length, numSamples);
        if (maxLeafs == Integer.MAX_VALUE) {
            // no leaf budget: split the root directly
            if (trainRoot.findBestSplit()) {
                trainRoot.split(null);
            }
        } else {
            // best-first growth: repeatedly split the queue head (ordering is defined by
            // TrainNode's compareTo, which is not visible here)
            final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<>();
            if (trainRoot.findBestSplit()) {
                nextSplits.add(trainRoot);
            }
            for (int leaves = 1; leaves < maxLeafs; leaves++) {
                final TrainNode node = nextSplits.poll();
                if (node == null) {
                    break;
                }
                if (!node.split(nextSplits)) {
                    leaves--; // a failed split does not consume leaf budget
                }
            }
            pruneRedundantLeaves(this._root, this._importance);
        }

        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }

    /**
     * Validates the training data and hyper-parameters given to the constructor.
     *
     * @throws IllegalArgumentException when the row count of {@code x} differs from
     *         {@code y.length}, {@code y} is empty, {@code numVars} is outside
     *         {@code [1, x.numColumns()]}, {@code maxDepth < 2}, {@code maxLeafs < 2},
     *         {@code minSplits < 2}, or {@code minLeafSize < 1}
     */
    private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars,
            int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
        final int numRows = x.numRows();
        if (numRows != y.length) {
            throw new IllegalArgumentException(String.format(
                "The sizes of X and Y don't match: %d != %d", numRows, y.length));
        }
        if (y.length == 0) {
            throw new IllegalArgumentException("No training example given");
        }
        if (numVars <= 0 || numVars > x.numColumns()) {
            throw new IllegalArgumentException(
                "Invalid number of variables to split on at a node of the tree: " + numVars);
        }
        if (maxDepth < 2) {
            throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
        }
        if (maxLeafs < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
        }
        if (minSplits < 2) {
            throw new IllegalArgumentException(
                "Invalid minimum number of samples required to split an internal node: "
                        + minSplits);
        }
        if (minLeafSize < 1) {
            throw new IllegalArgumentException(
                "Invalid minimum size of leaf nodes: " + minLeafSize);
        }
    }

    /**
     * Returns the variable-importance vector built during training.
     * <p>
     * It is sparse when the training matrix was sparse and dense otherwise. Splits removed by
     * leaf pruning have their scores subtracted from it. The internal instance is returned,
     * not a copy.
     *
     * @return the importance vector
     */
    public Vector importance() {
        return _importance;
    }

    /**
     * Convenience overload that wraps {@code x} in a DenseVector and delegates to
     * {@link #predict(Vector)}.
     *
     * @param x dense feature values
     * @return the predicted response
     */
    @VisibleForTesting
    public double predict(@Nonnull double[] x) {
        final Vector features = new DenseVector(x);
        return predict(features);
    }

    /**
     * Predicts the response for {@code x} by delegating to the root node.
     * <p>
     * This implements {@code Regression<Vector>.predict}, hence {@code @Override}.
     *
     * @param x feature vector
     * @return the predicted response
     */
    @Override
    public double predict(@Nonnull Vector x) {
        return _root.predict(x);
    }

    /**
     * Renders the tree as JavaScript source via the root node's exportJavascript.
     * <p>
     * {@code featureNames} is annotated {@code @Nullable} because {@code toString()} in this
     * class passes {@code null}. NOTE(review): presumably exportJavascript then falls back to
     * index-based feature references — confirm in Node.
     *
     * @param featureNames names used for features in the generated code, or {@code null}
     * @return the generated JavaScript
     */
    @Nonnull
    public String predictJsCodegen(@Nullable String[] featureNames) {
        final StringBuilder buf = new StringBuilder(1024);
        _root.exportJavascript(buf, featureNames, 0);
        return buf.toString();
    }

    /**
     * Generates an op-code script for the tree. The root's opCodegen output is followed by a
     * final {@code "call end"} instruction, and all lines are joined with {@code sep}.
     *
     * @param sep separator placed between op-code lines
     * @return the joined op-code script
     * @deprecated kept only for backward compatibility.
     */
    @Nonnull
    @Deprecated
    public String predictOpCodegen(@Nonnull String sep) {
        // typed list instead of the raw ArrayList the decompiler produced
        final List<String> opcodes = new ArrayList<>();
        _root.opCodegen(opcodes, 0);
        opcodes.add("call end");
        return StringUtils.concat(opcodes, sep);
    }

    /**
     * Serializes the root node of this tree.
     *
     * @param compress {@code true} to use ObjectUtils.toCompressedBytes, {@code false} for
     *        ObjectUtils.toBytes
     * @return the serialized bytes
     * @throws HiveException wrapping any exception raised during serialization
     */
    @Nonnull
    public byte[] serialize(boolean compress) throws HiveException {
        try {
            if (compress) {
                return ObjectUtils.toCompressedBytes((Externalizable) _root);
            }
            return ObjectUtils.toBytes((Externalizable) _root);
        } catch (IOException e) {
            // message previously named "DecisionTree" (copy-paste from the classifier)
            throw new HiveException(
                "IOException occurred while serializing RegressionTree object", e);
        } catch (Exception e) {
            throw new HiveException("Exception occurred while serializing RegressionTree object",
                e);
        }
    }

    /**
     * Reconstructs a tree root Node from serialized bytes. The bytes are presumably the output
     * of {@code serialize(boolean)} with a matching compression flag.
     *
     * @param serializedObj buffer holding the serialized node
     * @param length number of bytes to read from the start of {@code serializedObj}
     * @param compressed {@code true} if the bytes were written compressed
     * @return the deserialized root node
     * @throws HiveException wrapping any exception raised during deserialization
     */
    @Nonnull
    public static Node deserialize(@Nonnull byte[] serializedObj, int length, boolean compressed)
            throws HiveException {
        final Node root = new Node();
        try {
            if (compressed) {
                ObjectUtils.readCompressedObject(serializedObj, 0, length, root);
            } else {
                ObjectUtils.readObject(serializedObj, length, root);
            }
            return root;
        } catch (IOException e) {
            // message previously named "DecisionTree" (copy-paste from the classifier)
            throw new HiveException(
                "IOException occurred while deserializing RegressionTree object", e);
        } catch (Exception e) {
            throw new HiveException(
                "Exception occurred while deserializing RegressionTree object", e);
        }
    }

    /**
     * Returns the tree rendered as JavaScript (with no feature names), or an empty string when
     * no root exists.
     */
    @Override
    public String toString() {
        if (_root == null) {
            return "";
        }
        return predictJsCodegen(null);
    }
}
