package hex.tree.drf;

import hex.genmodel.GenModel;
import hex.tree.CompressedForest;
import hex.tree.CompressedTree;
import hex.tree.SharedTree;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Iced;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.ModelUtils;
import water.util.RandomUtils;

/* loaded from: input_file:hex/tree/drf/TreeMeasuresCollector.class */
/**
 * Map/reduce task that re-scores the out-of-bag (OOB) rows of every chunk with each tree of a
 * forest <em>individually</em> and accumulates per-tree measures:
 * <ul>
 *   <li>classification ({@code nclasses > 1}): weighted count of OOB rows predicted correctly;</li>
 *   <li>regression: weighted sum of squared errors over the OOB rows;</li>
 *   <li>always: weighted count of OOB rows that were scored.</li>
 * </ul>
 * When {@code _var >= 0}, the values of column {@code _var} are permuted among the sampled OOB rows
 * of each chunk before scoring, so comparing a permuted run with an unpermuted run (see
 * {@link TreeMeasures#imp}) shows how much each tree depends on that column.
 */
public class TreeMeasuresCollector extends MRTask<TreeMeasuresCollector> {
    /** Forest whose trees are scored one at a time. */
    private final CompressedForest _cforest;
    /** Sample rate forwarded to {@link ModelUtils#sampleOOBRows}. */
    private final float _rate;
    /** Index of the column to permute among OOB rows, or a negative value for no permutation. */
    private final int _var;
    /** Whether to draw OOB rows with the tree's own RNG; always {@code true} in this class. */
    private final boolean _oob;
    /** Number of leading chunks copied into the row buffer that is scored. */
    private final int _ncols;
    private final int _nclasses;
    private final boolean _classification;
    /** Threshold forwarded to {@link GenModel#getPrediction}. */
    private final double _threshold;
    /** Supplies the response/weight chunks of the training frame. */
    private final SharedTree _st;
    private final int _ntrees;
    /** Per-tree weighted count of correct predictions (classification only, else {@code null}). */
    private double[] _votes;
    /** Per-tree weighted count of scored OOB rows. */
    private double[] _nrows;
    /** Per-tree weighted sum of squared errors (regression only, else {@code null}). */
    private float[] _sse;
    /** Node-local view of {@link #_cforest}, fetched in {@link #setupLocal()}. */
    private transient CompressedForest.LocalCompressedForest _forest;

    /** {@link Random} whose {@link #nextFloat()} always returns {@code 1.0f}. */
    public static final class DummyRandom extends Random {
        private DummyRandom() {
        }

        @Override // java.util.Random
        public final float nextFloat() {
            return 1.0f;
        }
    }

    /** Task that shuffles the values of a {@link Vec} within each chunk (rows never cross chunks). */
    public static class ShuffleTask extends MRTask<ShuffleTask> {
        /**
         * Writes a random permutation of {@code src}'s values into {@code dst}, seeded by the chunk
         * index so the result is reproducible.
         */
        @Override // water.MRTask
        public void map(Chunk src, Chunk dst) {
            final int len = src._len;
            if (len == 0) {
                return;
            }
            final Random rng = RandomUtils.getRNG(seed(src.cidx()));
            // Inside-out Fisher-Yates: after step i, dst[0..i] is a uniform permutation of src[0..i].
            dst.set(0, src.atd(0));
            for (int i = 1; i < len; i++) {
                final int j = rng.nextInt(i + 1);
                if (j != i) {
                    dst.set(i, dst.atd(j));
                }
                dst.set(j, src.atd(i));
            }
        }

        /**
         * Deterministic per-chunk seed: a fixed base plus the chunk index shifted into the high
         * 32 bits.
         *
         * @param cidx chunk index
         * @return the seed for that chunk
         */
        public static long seed(int cidx) {
            // Widen before shifting: for an int, `cidx << 32` is `cidx << 0` (shift distance is
            // masked to 5 bits), which would leave the high 32 bits untouched.
            return -2291796408025514455L + ((long) cidx << 32);
        }

        /**
         * Returns a new vector holding the values of {@code vec} shuffled within each chunk.
         *
         * @param vec source vector (not modified)
         * @return a new vector created via {@link Vec#makeZero()} and filled with the shuffle
         */
        public static Vec shuffle(Vec vec) {
            final Vec shuffled = vec.makeZero();
            new ShuffleTask().doAll(vec, shuffled);
            return shuffled;
        }
    }

    /**
     * Per-tree measures (one entry per predictor/tree) with a weighted row count for each.
     *
     * @param <T> the concrete measure type, used for {@link #imp} and {@link #append}
     */
    public abstract static class TreeMeasures<T extends TreeMeasures<T>> extends Iced {
        /** Number of filled entries. */
        protected int _ntrees;
        /** Per-tree weighted row count. */
        protected double[] _nrows;

        /** Creates an empty container with room for {@code capacity} trees. */
        public TreeMeasures(int capacity) {
            this._nrows = new double[capacity];
        }

        /** Wraps existing row counts; {@code ntrees} entries are considered filled. */
        public TreeMeasures(double[] nrows, int ntrees) {
            this._nrows = nrows;
            this._ntrees = ntrees;
        }

        public final double[] nrows() {
            return _nrows;
        }

        /** Returns the number of filled entries (trees). */
        public final int npredictors() {
            return _ntrees;
        }

        /** Returns the measure of tree {@code tree} normalized by its row count. */
        public abstract double accuracy(int tree);

        /** Returns {@link #accuracy(int)} for every filled tree. */
        public final double[] accuracy() {
            final double[] result = new double[_ntrees];
            for (int t = 0; t < _ntrees; t++) {
                result[t] = accuracy(t);
            }
            return result;
        }

        /**
         * Compares these measures with {@code other}, tree by tree.
         *
         * @return {@code {mean, standard error}} of the per-tree normalized differences
         */
        public abstract double[] imp(T other);

        /** Appends all filled entries of {@code other} to this container and returns {@code this}. */
        public abstract T append(T other);
    }

    /** Per-tree (weighted) sums of squared errors, for regression. */
    public static class TreeSSE extends TreeMeasures<TreeSSE> {
        private float[] _sse;

        public TreeSSE(int capacity) {
            super(capacity);
            this._sse = new float[capacity];
        }

        public TreeSSE(float[] sse, double[] nrows, int ntrees) {
            super(nrows, ntrees);
            this._sse = sse;
        }

        /** Returns the mean squared error of tree {@code tree} ({@code sse / nrows}). */
        @Override // hex.tree.drf.TreeMeasuresCollector.TreeMeasures
        public double accuracy(int tree) {
            return _sse[tree] / _nrows[tree];
        }

        /**
         * Mean and standard error over trees of {@code (this.sse - other.sse) / nrows}.
         * NOTE(review): a positive value means {@code this} has the larger error, so presumably
         * {@code this} holds the permuted-data measures and {@code other} the original ones.
         */
        @Override // hex.tree.drf.TreeMeasuresCollector.TreeMeasures
        public double[] imp(TreeSSE other) {
            assert npredictors() == other.npredictors();
            final int n = npredictors();
            double sum = 0.0;
            double sumSq = 0.0;
            for (int t = 0; t < n; t++) {
                assert other.nrows()[t] == nrows()[t];
                final double diff = (_sse[t] - other._sse[t]) / nrows()[t];
                sum += diff;
                sumSq += diff * diff;
            }
            return meanAndStdErr(sum, sumSq, n);
        }

        @Override // hex.tree.drf.TreeMeasuresCollector.TreeMeasures
        public TreeSSE append(TreeSSE other) {
            // Snapshot the count so appending an instance to itself terminates.
            final int n = other.npredictors();
            for (int t = 0; t < n; t++) {
                append(other._sse[t], other._nrows[t]);
            }
            return this;
        }

        /** Appends one tree's SSE and row count; returns {@code this}. */
        public TreeSSE append(float sse, double nrows) {
            assert _sse.length > _ntrees && _sse.length == _nrows.length : "TreeSSE inconsistency!";
            _sse[_ntrees] = sse;
            _nrows[_ntrees] = nrows;
            _ntrees++;
            return this;
        }
    }

    /** Per-tree (weighted) counts of correct votes, for classification. */
    public static class TreeVotes extends TreeMeasures<TreeVotes> {
        private double[] _votes;

        public TreeVotes(int capacity) {
            super(capacity);
            this._votes = new double[capacity];
        }

        public TreeVotes(double[] votes, double[] nrows, int ntrees) {
            super(nrows, ntrees);
            this._votes = votes;
        }

        public final double[] votes() {
            return _votes;
        }

        /** Returns the fraction of correct votes of tree {@code tree} ({@code votes / nrows}). */
        @Override // hex.tree.drf.TreeMeasuresCollector.TreeMeasures
        public final double accuracy(int tree) {
            assert tree < _nrows.length && tree < _votes.length;
            return _votes[tree] / _nrows[tree];
        }

        /**
         * Mean and standard error over trees of {@code (other.votes - this.votes) / nrows}.
         * NOTE(review): a positive value means {@code other} has more correct votes, so presumably
         * {@code this} holds the permuted-data measures and {@code other} the original ones.
         */
        @Override // hex.tree.drf.TreeMeasuresCollector.TreeMeasures
        public final double[] imp(TreeVotes other) {
            assert npredictors() == other.npredictors();
            final int n = npredictors();
            double sum = 0.0;
            double sumSq = 0.0;
            for (int t = 0; t < n; t++) {
                assert other.nrows()[t] == nrows()[t];
                final double diff = (other.votes()[t] - votes()[t]) / nrows()[t];
                sum += diff;
                sumSq += diff * diff;
            }
            return meanAndStdErr(sum, sumSq, n);
        }

        /** Appends one tree's vote count and row count; returns {@code this}. */
        public TreeVotes append(double votes, double nrows) {
            assert _votes.length > _ntrees && _votes.length == _nrows.length : "TreeVotes inconsistency!";
            _votes[_ntrees] = votes;
            _nrows[_ntrees] = nrows;
            _ntrees++;
            return this;
        }

        @Override // hex.tree.drf.TreeMeasuresCollector.TreeMeasures
        public TreeVotes append(TreeVotes other) {
            // Snapshot the count so appending an instance to itself terminates.
            final int n = other.npredictors();
            for (int t = 0; t < n; t++) {
                append(other._votes[t], other._nrows[t]);
            }
            return this;
        }
    }

    /**
     * Returns {@code {mean, sqrt((sumSq / n - mean^2) / n)}}, the mean and standard error of
     * {@code n} values given their sum and sum of squares.
     */
    private static double[] meanAndStdErr(double sum, double sumSq, int n) {
        final double mean = sum / n;
        return new double[]{mean, Math.sqrt((sumSq / n - mean * mean) / n)};
    }

    /**
     * @param cforest   forest to score; must contain at least one tree
     * @param nclasses  number of classes (must equal the per-tree key count); {@code > 1} means
     *                  classification
     * @param ncols     number of leading chunks copied into the scored row
     * @param rate      sample rate forwarded to {@link ModelUtils#sampleOOBRows}
     * @param variable  column to permute among OOB rows, or negative for none
     * @param threshold threshold forwarded to {@link GenModel#getPrediction}
     * @param st        source of the response and weight chunks
     */
    private TreeMeasuresCollector(CompressedForest cforest, int nclasses, int ncols, float rate,
                                  int variable, double threshold, SharedTree st) {
        assert cforest._treeKeys.length > 0;
        assert nclasses == cforest._treeKeys[0].length;
        this._cforest = cforest;
        this._ncols = ncols;
        this._rate = rate;
        this._var = variable;
        this._oob = true;
        this._ntrees = cforest._treeKeys.length;
        this._nclasses = nclasses;
        this._classification = nclasses > 1;
        this._threshold = threshold;
        this._st = st;
    }

    /** Fetches the node-local copy of the forest before any {@link #map(Chunk[])} runs. */
    @Override // water.MRTask
    public void setupLocal() {
        this._forest = _cforest.fetch();
    }

    /**
     * For every tree: samples the chunk's OOB rows with that tree's RNG, optionally permutes column
     * {@code _var} among them, scores each row with that tree alone and accumulates the weighted
     * measures. Rows with an NA response or a zero weight are skipped.
     */
    @Override // water.MRTask
    public void map(Chunk[] chks) {
        final double[] row = new double[_ncols];
        final double[] preds = new double[_nclasses + 1];
        final Chunk resp = _st.chk_resp(chks);
        final Chunk weights = _st.hasWeightCol() ? _st.chk_weight(chks) : new C0DChunk(1.0d, chks[0]._len);
        final int nrows = resp._len;
        final int cidx = resp.cidx();
        // Pre-sized OOB index buffer (slot 0 holds the count); sampleOOBRows may return a new one.
        int[] oob = new int[2 + Math.round(((1.0f - _rate) * nrows * 1.2f) + 0.5f)];
        int[] permuted = null;

        _nrows = new double[_ntrees];
        _votes = _classification ? new double[_ntrees] : null;
        _sse = _classification ? null : new float[_ntrees];
        final long shuffleSeed = ShuffleTask.seed(cidx);

        for (int tidx = 0; tidx < _ntrees; tidx++) {
            oob = ModelUtils.sampleOOBRows(nrows, _rate, rngForTree(_forest._trees[tidx], cidx), oob);
            final int oobCount = oob[0];
            if (_var >= 0) {
                if (permuted == null || permuted.length < oobCount) {
                    permuted = new int[oobCount];
                }
                ArrayUtils.shuffleArray(oob, oobCount, permuted, shuffleSeed, 1);
            }
            for (int j = 1; j <= oobCount; j++) {
                final int r = oob[j];
                final double w = weights.atd(r);
                if (resp.isNA(r) || w == 0) {
                    continue;
                }
                for (int c = 0; c < _ncols; c++) {
                    row[c] = chks[c].atd(r);
                }
                if (_var >= 0) {
                    // Take the permuted column's value from another OOB row of this chunk.
                    row[_var] = chks[_var].atd(permuted[j - 1]);
                } else {
                    assert permuted == null;
                }
                Arrays.fill(preds, 0);
                _forest.scoreTree(row, preds, tidx);
                if (_classification) {
                    if (GenModel.getPrediction(preds, null, row, _threshold) == (int) resp.at8(r)) {
                        _votes[tidx] += w;
                    }
                } else {
                    final double err = resp.atd(r) - preds[0];
                    // Weighted like _votes/_nrows so that sse / nrows is a weighted mean.
                    _sse[tidx] = (float) (_sse[tidx] + w * err * err);
                }
                _nrows[tidx] += w;
            }
        }
    }

    /** Adds {@code other}'s per-tree accumulators into this task's. */
    @Override // water.MRTask
    public void reduce(TreeMeasuresCollector other) {
        ArrayUtils.add(_votes, other._votes);
        ArrayUtils.add(_nrows, other._nrows);
        ArrayUtils.add(_sse, other._sse);
    }

    /** Wraps the collected vote counts (classification) as {@link TreeVotes}. */
    public TreeVotes resultVotes() {
        return new TreeVotes(_votes, _nrows, _ntrees);
    }

    /** Wraps the collected squared-error sums (regression) as {@link TreeSSE}. */
    public TreeSSE resultSSE() {
        return new TreeSSE(_sse, _nrows, _ntrees);
    }

    /** Returns the OOB sampling RNG of the first per-class tree for chunk {@code cidx}. */
    private Random rngForTree(CompressedTree[] trees, int cidx) {
        return _oob ? trees[0].rngForChunk(cidx) : new DummyRandom();
    }

    /** Casts {@code measures} to {@link TreeVotes}. */
    public static TreeVotes asVotes(TreeMeasures<?> measures) {
        return (TreeVotes) measures;
    }

    /** Casts {@code measures} to {@link TreeSSE}. */
    public static TreeSSE asSSE(TreeMeasures<?> measures) {
        return (TreeSSE) measures;
    }
}
