package hex.tree.gbm;

import hex.DistributionFactory;
import hex.FeatureInteractions;
import hex.FeatureInteractionsCollector;
import hex.FriedmanPopescusHCollector;
import hex.KeyValue;
import hex.Model;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.BranchInteractionConstraints;
import hex.tree.CompressedForest;
import hex.tree.CompressedTree;
import hex.tree.Constraints;
import hex.tree.FriedmanPopescusH;
import hex.tree.GlobalInteractionConstraints;
import hex.tree.Score;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.tree.SharedTreePojoWriter;
import hex.util.EffectiveParametersUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/tree/gbm/GBMModel.class */
public class GBMModel extends SharedTreeModelWithContributions<GBMModel, GBMParameters, GBMOutput> implements Model.StagedPredictions, FeatureInteractionsCollector, FriedmanPopescusHCollector {
    // Assertion switch as emitted by the decompiler: it is assigned exactly once, in the static
    // initializer at the end of this class, to the negation of GBMModel's desired assertion status.
    // Code in this class checks it before throwing AssertionError.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/tree/gbm/GBMModel$GBMOutput.class */
    /**
     * Output of a GBM model. In addition to the shared tree output it stores the class count
     * reported by the builder and a flag telling whether the quasibinomial distribution was used.
     */
    public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
        /** Class names returned by {@link #classNames()} when the quasibinomial flag is set. */
        public String[] _quasibinomialDomains;
        /** True when the builder's parameters select {@code DistributionFamily.quasibinomial}. */
        boolean _quasibinomial;
        /** Value of {@code gbm.nclasses()} captured in the constructor. */
        int _nclasses;

        /**
         * Initializes the shared tree output from the builder, then records the quasibinomial flag
         * and the number of classes.
         *
         * @param gbm the builder this output belongs to
         */
        public GBMOutput(GBM gbm) {
            super(gbm);
            final GBMParameters builderParms = (GBMParameters) gbm._parms;
            this._quasibinomial = DistributionFamily.quasibinomial == builderParms._distribution;
            this._nclasses = gbm.nclasses();
        }

        /** @return the class count captured at construction time */
        public int nclasses() {
            return this._nclasses;
        }

        /**
         * @return the quasibinomial domain for quasibinomial models, otherwise whatever the parent
         *         output reports
         */
        public String[] classNames() {
            if (this._quasibinomial) {
                return this._quasibinomialDomains;
            }
            return super.classNames();
        }
    }

    /* loaded from: input_file:hex/tree/gbm/GBMModel$GBMParameters.class */
    /**
     * GBM-specific training parameters layered on top of the shared tree parameters.
     */
    public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
        public double _learn_rate = 0.1d;
        public double _learn_rate_annealing = 1.0d;
        public double _col_sample_rate = 1.0d;
        public double _max_abs_leafnode_pred;
        public double _pred_noise_bandwidth;
        /** Per-column monotonicity requests; only the sign of each value is used (see {@link #constraints(Frame)}). */
        public KeyValue[] _monotone_constraints;
        public String[][] _interaction_constraints;

        /**
         * Applies the GBM defaults: full row sampling, 50 trees, depth 5, an effectively unbounded
         * leaf prediction cap and no prediction noise.
         */
        public GBMParameters() {
            this._ntrees = 50;
            this._max_depth = 5;
            this._sample_rate = 1.0d;
            this._max_abs_leafnode_pred = Double.MAX_VALUE;
            this._pred_noise_bandwidth = 0.0d;
        }

        /**
         * @return true when the parent already requires column sampling, or when the GBM column
         *         sample rate differs from 1.0
         */
        @Override // hex.tree.SharedTreeModel.SharedTreeParameters
        public boolean useColSampling() {
            if (super.useColSampling()) {
                return true;
            }
            return this._col_sample_rate != 1.0d;
        }

        public String algoName() {
            return "GBM";
        }

        public String fullName() {
            return "Gradient Boosting Machine";
        }

        public String javaName() {
            return GBMModel.class.getName();
        }

        /** Strictly reproducible histograms are requested exactly when monotone constraints are in use. */
        @Override // hex.tree.SharedTreeModel.SharedTreeParameters
        public boolean forceStrictlyReproducibleHistograms() {
            return usesMonotoneConstraints();
        }

        /**
         * Monotone constraints count as "used" when at least one is configured, or when
         * {@link #emptyConstraints(int)} supplies a non-null object for the unconfigured case.
         */
        private boolean usesMonotoneConstraints() {
            if (!areMonotoneConstraintsEmpty()) {
                return true;
            }
            return emptyConstraints(0) != null;
        }

        /** @return true when the constraint array is null or has no entries */
        private boolean areMonotoneConstraintsEmpty() {
            final KeyValue[] configured = this._monotone_constraints;
            return configured == null || configured.length == 0;
        }

        /**
         * Translates the configured monotone constraints into per-column directions for the frame.
         * Entries with value 0 are skipped; negative values map to -1 and all others to +1.
         *
         * @param frame frame whose column names the constraint keys are resolved against
         * @return the constraints, or the result of {@link #emptyConstraints(int)} when none are configured
         * @throws IllegalStateException when a non-zero constraint names a column the frame does not contain
         */
        public Constraints constraints(Frame frame) {
            if (areMonotoneConstraintsEmpty()) {
                return emptyConstraints(frame.numCols());
            }
            final int[] directions = new int[frame.numCols()];
            for (KeyValue constraint : this._monotone_constraints) {
                final double requested = constraint.getValue();
                if (requested == 0.0d) {
                    continue; // zero means "no constraint" for this column
                }
                final int columnIndex = frame.find(constraint.getKey());
                if (columnIndex < 0) {
                    throw new IllegalStateException("Invalid constraint specification, column '" + constraint.getKey() + "' doesn't exist.");
                }
                directions[columnIndex] = requested < 0.0d ? -1 : 1;
            }
            // The last argument is true only for the listed distribution families.
            // NOTE(review): presumably it switches on bound-based handling inside Constraints - confirm there.
            return new Constraints(
                    directions,
                    DistributionFactory.getDistribution(this),
                    this._distribution == DistributionFamily.gaussian
                            || this._distribution == DistributionFamily.bernoulli
                            || this._distribution == DistributionFamily.tweedie
                            || this._distribution == DistributionFamily.quasibinomial
                            || this._distribution == DistributionFamily.multinomial
                            || this._distribution == DistributionFamily.quantile);
        }

        /**
         * Constraints object used when no monotone constraints are configured; this implementation
         * always returns null and ignores the column count.
         * NOTE(review): package-private and non-final, presumably so it can be overridden - confirm.
         */
        Constraints emptyConstraints(int i) {
            return null;
        }

        /** Builds the global interaction constraints from the configured groups and the frame's column names. */
        public GlobalInteractionConstraints interactionConstraints(Frame frame) {
            final String[] columnNames = frame.names();
            return new GlobalInteractionConstraints(this._interaction_constraints, columnNames);
        }

        /** Builds the starting branch constraints from every column index the global constraints allow. */
        public BranchInteractionConstraints initialInteractionConstraints(GlobalInteractionConstraints globalInteractionConstraints) {
            return new BranchInteractionConstraints(
                    globalInteractionConstraints.getAllAllowedColumnIndices());
        }
    }

    /* loaded from: input_file:hex/tree/gbm/GBMModel$StagedPredictionsTask.class */
    /**
     * Task that, for every input row, emits the model's prediction after each successive group of
     * trees: one output value per non-null tree key, in tree-group order.
     */
    private static class StagedPredictionsTask extends MRTask<StagedPredictionsTask> {
        /** Key of the model to score with; the model itself is looked up again in {@link #setupLocal()}. */
        private final Key<GBMModel> _modelKey;
        /** Model resolved from {@link #_modelKey}; transient, so it is not carried with the task. */
        private transient GBMModel _model;

        private StagedPredictionsTask(GBMModel gBMModel) {
            this._modelKey = gBMModel._key;
        }

        /** Resolves the model from its key before any chunk is mapped. */
        protected void setupLocal() {
            this._model = this._modelKey.get();
            assert this._model != null;
        }

        /**
         * Scores each row of the chunk incrementally and appends the staged predictions.
         *
         * @param chunkArr    input chunks, one per column; all are read at the same row index
         * @param newChunkArr output chunks; each row contributes exactly one value to every one of them
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            final GBMOutput output = (GBMOutput) this._model._output;
            final int nclasses = output.nclasses();
            // Single-class outputs use slot 0; otherwise per-class values start at slot 1.
            final int offset = nclasses == 1 ? 0 : 1;
            final double[] row = new double[chunkArr.length];
            for (int r = 0; r < chunkArr[0]._len; r++) {
                for (int c = 0; c < chunkArr.length; c++) {
                    row[c] = chunkArr[c].atd(r);
                }
                // Fresh accumulators per row: 'cumulative' keeps the raw running tree sums, while
                // 'staged' is a scratch copy that the post-processing calls below overwrite.
                final double[] cumulative = new double[offset + nclasses];
                final double[] staged = new double[cumulative.length];
                int outCol = 0;
                for (Key<CompressedTree>[] classTrees : output._treeKeys) {
                    for (int k = 0; k < classTrees.length; k++) {
                        if (classTrees[k] != null) {
                            // Typed local restores the target type the decompiler dropped from get().
                            final CompressedTree tree = DKV.get(classTrees[k]).get();
                            cumulative[offset + k] += tree.score(row, output._domains);
                        }
                        staged[offset + k] = cumulative[offset + k];
                    }
                    this._model.score0Probabilities(staged, 0.0d);
                    this._model.score0PostProcessSupervised(staged, row);
                    for (int k = 0; k < classTrees.length; k++) {
                        if (classTrees[k] != null) {
                            newChunkArr[outCol++].addNum(staged[offset + k]);
                        }
                    }
                }
                // Every output column must have received exactly one value for this row.
                assert outCol == newChunkArr.length;
            }
        }
    }

    /**
     * Creates the model by forwarding its key, parameters and output unchanged to the parent constructor.
     *
     * @param key           key of this model
     * @param gBMParameters parameters the model is built with
     * @param gBMOutput     output object of the model
     */
    public GBMModel(Key<GBMModel> key, GBMParameters gBMParameters, GBMOutput gBMOutput) {
        super(key, gBMParameters, gBMOutput);
    }

    /**
     * Runs the parent's initialization and then lets {@code EffectiveParametersUtils} initialize the
     * fold assignment, histogram type, categorical encoding and calibration method on {@code _parms}.
     */
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(this._parms);
        // The cast selects the overload that takes shared tree parameters.
        EffectiveParametersUtils.initHistogramType((SharedTreeModel.SharedTreeParameters) this._parms);
        // NOTE(review): Enum is presumably the encoding substituted for an automatic setting - confirm in EffectiveParametersUtils.
        EffectiveParametersUtils.initCategoricalEncoding(this._parms, Model.Parameters.CategoricalEncodingScheme.Enum);
        EffectiveParametersUtils.initCalibrationMethod(this._parms);
    }

    /**
     * Initializes the stopping metric and the distribution on {@code _parms}.
     *
     * @param i forwarded to {@code initDistribution} (NOTE(review): presumably the number of classes - confirm with caller)
     * @param z forwarded to {@code initStoppingMetric} (NOTE(review): meaning not visible here - confirm with caller)
     */
    public void initActualParamValuesAfterOutputSetup(int i, boolean z) {
        EffectiveParametersUtils.initStoppingMetric(this._parms, z);
        EffectiveParametersUtils.initDistribution(this._parms, i);
    }

    /**
     * Creates the contributions scoring task.
     * The {@code sharedTreeModel} argument is not used: the task is always constructed from {@code this}.
     */
    @Override // hex.tree.SharedTreeModelWithContributions
    protected SharedTreeModelWithContributions<GBMModel, GBMParameters, GBMOutput>.ScoreContributionsTask getScoreContributionsTask(SharedTreeModel sharedTreeModel) {
        return new SharedTreeModelWithContributions.ScoreContributionsTask(this);
    }

    /**
     * Creates the sorting variant of the contributions task from the supplied model and options.
     * The method name (including its "Soring" spelling) is fixed by the overridden parent method.
     */
    @Override // hex.tree.SharedTreeModelWithContributions
    protected SharedTreeModelWithContributions<GBMModel, GBMParameters, GBMOutput>.ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel sharedTreeModel, Model.Contributions.ContributionsOptions contributionsOptions) {
        return new SharedTreeModelWithContributions.ScoreContributionsSortingTask(sharedTreeModel, contributionsOptions);
    }

    /**
     * Computes staged predictions for the given frame: one output column per name returned by
     * {@code makeAllTreeColumnNames()}.
     *
     * @param frame frame to score; a shallow copy is adapted to the training layout, the argument itself is not adapted
     * @param key   key under which the resulting frame is created
     * @return the frame produced by the staged predictions task
     */
    public Frame scoreStagedPredictions(Frame frame, Key<Frame> key) {
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        final String[] columnNames = makeAllTreeColumnNames();
        // The task has a single constructor that takes the model and stores its key, so the model
        // must be passed in; there is no no-arg constructor.
        final StagedPredictionsTask task = new StagedPredictionsTask(this);
        // (byte) 3 is the output column type code. NOTE(review): presumably the numeric vec type - confirm.
        return ((StagedPredictionsTask) task.doAll(columnNames.length, (byte) 3, adapted)).outputFrame(key, columnNames, (String[][]) null);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Scores one row starting from {@code scoreIncInfo._startTree}, reusing raw tree sums kept in
     * workspace columns of the chunk array, and writes the updated raw sums back before the
     * predictions are post-processed.
     *
     * @param scoreIncInfo describes the first tree to apply and where the workspace columns / prediction slots are
     * @param chunkArr     chunks holding the feature columns and the workspace columns
     * @param d            offset forwarded to the scoring calls
     * @param i            row index within the chunks
     * @param dArr         scratch array that receives the row's feature values; its length must equal the feature count
     * @param dArr2        prediction array, updated in place
     * @return {@code dArr2}
     */
    @Override // hex.tree.SharedTreeModel
    public final double[] score0Incremental(Score.ScoreIncInfo scoreIncInfo, Chunk[] chunkArr, double d, int i, double[] dArr, double[] dArr2) {
        final int featureCount = dArr.length;
        if (!$assertionsDisabled && ((GBMOutput) this._output).nfeatures() != featureCount) {
            throw new AssertionError();
        }
        for (int feature = 0; feature < featureCount; feature++) {
            dArr[feature] = chunkArr[feature].atd(i);
        }

        final int workspaceCols = scoreIncInfo._workspaceColCnt;
        final int workspaceStart = scoreIncInfo._workspaceColIdx;
        final int predsStart = scoreIncInfo._predsAryOffset;

        if (scoreIncInfo._startTree != 0) {
            // Resume from the raw sums saved by an earlier pass.
            for (int w = 0; w < workspaceCols; w++) {
                dArr2[predsStart + w] = chunkArr[workspaceStart + w].atd(i);
            }
        } else {
            // Scoring from the first tree: start from zeroed predictions.
            Arrays.fill(dArr2, 0.0d);
        }

        score0(dArr, dArr2, d, scoreIncInfo._startTree, ((GBMOutput) this._output)._treeKeys.length);

        // Persist the raw sums before they are overwritten by the post-processing below.
        for (int w = 0; w < workspaceCols; w++) {
            chunkArr[workspaceStart + w].set(i, dArr2[predsStart + w]);
        }

        score0Probabilities(dArr2, d);
        score0PostProcessSupervised(dArr2, dArr);
        return dArr2;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Lets the parent accumulate the tree scores into {@code dArr2} and then converts them in place
     * with {@link #score0Probabilities(double[], double)}.
     *
     * @return the converted prediction array
     */
    @Override // hex.tree.SharedTreeModel
    public double[] score0(double[] dArr, double[] dArr2, double d, int i) {
        // The parent's return value is not used; only its in-place update of dArr2 matters here.
        super.score0(dArr, dArr2, d, i);
        final double[] converted = score0Probabilities(dArr2, d);
        return converted;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Converts raw tree sums into final predictions, in place. Three cases:
     * <ul>
     *   <li>bernoulli, quasibinomial, modified_huber, or custom with two classes: slot 2 becomes the
     *       inverse link of (slot 1 + initial value + offset) and slot 1 becomes its complement;</li>
     *   <li>multinomial, or custom with more than two classes: the array is rescaled by
     *       {@code GenModel.GBM_rescale}, after slots 1 and 2 are first set to the shifted score and
     *       its negation when there are exactly two classes;</li>
     *   <li>anything else: slot 0 becomes the inverse link of (slot 0 + initial value + offset).</li>
     * </ul>
     *
     * @param dArr prediction array, modified in place
     * @param d    offset added to the raw score
     * @return {@code dArr}
     */
    public double[] score0Probabilities(double[] dArr, double d) {
        final GBMParameters parms = (GBMParameters) this._parms;
        final GBMOutput output = (GBMOutput) this._output;

        final boolean twoClassLink = parms._distribution == DistributionFamily.bernoulli
                || parms._distribution == DistributionFamily.quasibinomial
                || parms._distribution == DistributionFamily.modified_huber
                || (parms._distribution == DistributionFamily.custom && output.nclasses() == 2);
        if (twoClassLink) {
            dArr[2] = DistributionFactory.getDistribution(this._parms).linkInv(dArr[1] + output._init_f + d);
            dArr[1] = 1.0d - dArr[2];
            return dArr;
        }

        final boolean multiClassRescale = parms._distribution == DistributionFamily.multinomial
                || (parms._distribution == DistributionFamily.custom && output.nclasses() > 2);
        if (multiClassRescale) {
            if (output.nclasses() == 2) {
                // Only one score is accumulated in this case; the second slot is its negation.
                dArr[1] = dArr[1] + output._init_f + d;
                dArr[2] = -dArr[1];
            }
            GenModel.GBM_rescale(dArr);
            return dArr;
        }

        dArr[0] = DistributionFactory.getDistribution(this._parms).linkInv(dArr[0] + output._init_f + d);
        return dArr;
    }

    /**
     * Builds the POJO writer from this model and the trees of a forest created from the output's
     * tree keys and domains and then fetched.
     */
    @Override // hex.tree.SharedTreeModel
    protected SharedTreePojoWriter makeTreePojoWriter() {
        final GBMOutput output = (GBMOutput) this._output;
        final CompressedForest forest = new CompressedForest(output._treeKeys, output._domains);
        return new GbmPojoWriter(this, forest.fetch()._trees);
    }

    /**
     * Creates the MOJO writer for this model.
     *
     * <p>The decompiler had renamed this method to {@code m345getMojo} (its note read: "renamed from:
     * getMojo, reason: merged with bridge method"), which left the class without a {@code getMojo()}
     * of its own. The original name is restored here.
     * NOTE(review): presumably this overrides a parent {@code getMojo()} with a covariant return type - confirm.
     *
     * @return a new {@code GbmMojoWriter} bound to this model
     */
    public GbmMojoWriter getMojo() {
        return new GbmMojoWriter(this);
    }

    /**
     * Decompiler-generated name, kept only so that existing callers of it keep compiling.
     *
     * @return same as {@link #getMojo()}
     * @deprecated use {@link #getMojo()}
     */
    @Deprecated
    public GbmMojoWriter m345getMojo() {
        return getMojo();
    }

    /**
     * Collects feature interactions from every tree of the model and merges them into one result.
     * Trees are visited for indices {@code 0.._parms._ntrees-1}; per tree index, one subgraph is
     * visited per class when the model has more than two classes, otherwise a single one.
     *
     * @param i  forwarded to {@code FeatureInteractions.collectFeatureInteractions}
     * @param i2 forwarded to {@code FeatureInteractions.collectFeatureInteractions}
     * @param i3 forwarded to {@code FeatureInteractions.collectFeatureInteractions}
     * @return the merged interactions of all visited trees
     */
    public FeatureInteractions getFeatureInteractions(int i, int i2, int i3) {
        final GBMOutput output = (GBMOutput) this._output;
        final int treesPerIteration;
        if (output._nclasses > 2) {
            treesPerIteration = output._nclasses;
        } else {
            treesPerIteration = 1;
        }

        final FeatureInteractions merged = new FeatureInteractions();
        final int treeCount = ((GBMParameters) this._parms)._ntrees;
        for (int treeIdx = 0; treeIdx < treeCount; treeIdx++) {
            for (int classIdx = 0; classIdx < treesPerIteration; classIdx++) {
                final FeatureInteractions perTree = new FeatureInteractions();
                final SharedTreeSubgraph subgraph = getSharedTreeSubgraph(treeIdx, classIdx);
                // Raw collection types are kept as decompiled; their element types are not visible here.
                FeatureInteractions.collectFeatureInteractions(
                        subgraph.rootNode, new ArrayList(), 0.0d, 0.0d, 1.0d, 0, 0,
                        perTree, new HashSet(), i, i2, i3, treeIdx, true);
                merged.mergeWith(perTree);
            }
        }
        return merged;
    }

    /**
     * Collects the feature interactions with {@link #getFeatureInteractions(int, int, int)} and
     * converts them to tables via {@code FeatureInteractions.getFeatureInteractionsTable}.
     */
    public TwoDimTable[][] getFeatureInteractionsTable(int i, int i2, int i3) {
        final FeatureInteractions collected = getFeatureInteractions(i, i2, i3);
        return FeatureInteractions.getFeatureInteractionsTable(collected);
    }

    /**
     * Computes the H statistic via {@code FriedmanPopescusH.h} for the given variables, using the
     * frame with special columns removed, the learning rate and all tree subgraphs of the model.
     *
     * @param frame  data to evaluate on
     * @param strArr variables passed to {@code FriedmanPopescusH.h}
     * @return the value returned by {@code FriedmanPopescusH.h}
     * @throws UnsupportedOperationException when any remaining column's vec reports {@code isBad()}
     */
    public double getFriedmanPopescusH(Frame frame, String[] strArr) {
        final Frame cleaned = removeSpecialColumns(frame);
        for (int col = 0; col < cleaned.numCols(); col++) {
            if (!cleaned.vec(col).isBad()) {
                continue;
            }
            // NOTE(review): the message says "row" although the check is per column; text kept unchanged.
            throw new UnsupportedOperationException("Calculating of H statistics error: row " + cleaned.name(col) + " is missing.");
        }

        final GBMOutput output = (GBMOutput) this._output;
        final int treesPerIteration;
        if (output._nclasses > 2) {
            treesPerIteration = output._nclasses;
        } else {
            treesPerIteration = 1;
        }

        final int treeCount = ((GBMParameters) this._parms)._ntrees;
        final SharedTreeSubgraph[][] forest = new SharedTreeSubgraph[treeCount][treesPerIteration];
        for (int treeIdx = 0; treeIdx < treeCount; treeIdx++) {
            final SharedTreeSubgraph[] perClass = forest[treeIdx];
            for (int classIdx = 0; classIdx < treesPerIteration; classIdx++) {
                perClass[classIdx] = getSharedTreeSubgraph(treeIdx, classIdx);
            }
        }
        return FriedmanPopescusH.h(cleaned, strArr, ((GBMParameters) this._parms)._learn_rate, forest);
    }

    // Assigns the assertion switch declared at the top of the class: it becomes true when
    // assertions are NOT enabled for GBMModel (decompiled form of what the compiler generates).
    static {
        $assertionsDisabled = !GBMModel.class.desiredAssertionStatus();
    }
}
