package hex.tree.drf;

import hex.tree.SharedTreeModel;
import water.Key;
import water.util.MathUtils;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/drf/DRFModel.class */
/**
 * Distributed Random Forest model.
 *
 * <p>Tree evaluation is delegated to {@link SharedTreeModel}. This class post-processes the
 * accumulated per-tree outputs:
 * <ul>
 *   <li>regression and the optimized binomial case are averaged over the number of trees;</li>
 *   <li>the remaining classification cases are normalized to sum to one.</li>
 * </ul>
 * The same post-processing is emitted into generated Java (POJO) scoring code.
 */
public class DRFModel extends SharedTreeModel<DRFModel, DRFParameters, DRFOutput> {

    /** Model output for DRF; adds nothing beyond {@link SharedTreeModel.SharedTreeOutput}. */
    public static class DRFOutput extends SharedTreeModel.SharedTreeOutput {
        /**
         * Creates the output container for a DRF model.
         *
         * @param drf the DRF model builder this output is created for
         */
        public DRFOutput(DRF drf) {
            super(drf);
        }
    }

    /** Parameters for Distributed Random Forest, with DRF-specific defaults. */
    public static class DRFParameters extends SharedTreeModel.SharedTreeParameters {
        /**
         * When {@code false} (the default), {@link DRFModel#binomialOpt()} returns {@code true}.
         * Binomial scoring then averages only {@code preds[1]} and derives {@code preds[2]} as
         * its complement.
         */
        public boolean _binomial_double_trees = false;

        /**
         * Set to {@code -1} by the constructor.
         * NOTE(review): -1 presumably means "use the builder's default heuristic"; confirm in DRF.
         */
        public int _mtries;

        public String algoName() {
            return "DRF";
        }

        public String fullName() {
            return "Distributed Random Forest";
        }

        public String javaName() {
            return DRFModel.class.getName();
        }

        /** Creates parameters with DRF-specific defaults. */
        public DRFParameters() {
            // The original assigned this twice; a single assignment is equivalent.
            this._mtries = -1;
            // Exactly the float literal 0.632f widened to double; kept verbatim so the default
            // value is unchanged.
            this._sample_rate = 0.6320000290870667d;
            this._max_depth = 20;
            this._min_rows = 1.0d;
        }
    }

    /**
     * Creates a DRF model.
     *
     * @param key    key under which the model is stored
     * @param parms  parameters used to build the model
     * @param output model output container
     */
    public DRFModel(Key<DRFModel> key, DRFParameters parms, DRFOutput output) {
        super(key, parms, output);
    }

    /**
     * Reports whether the binomial optimization is active.
     *
     * <p>When active, binomial scoring averages only {@code preds[1]} over the trees and sets
     * {@code preds[2] = 1 - preds[1]}. It is active unless
     * {@link DRFParameters#_binomial_double_trees} is set.
     *
     * <p>Declared {@code public} here; the decompiler notes the original was {@code protected}.
     *
     * @return {@code true} when {@code _binomial_double_trees} is {@code false}
     */
    @Override // hex.tree.SharedTreeModel
    public boolean binomialOpt() {
        return !((DRFParameters) this._parms)._binomial_double_trees;
    }

    /**
     * Scores one row and turns the summed tree outputs into DRF predictions.
     *
     * <p>{@code super.score0} fills {@code preds} first. Then:
     * <ul>
     *   <li>regression ({@code nclasses() == 1}): {@code preds[0]} is divided by the tree count;</li>
     *   <li>binomial with {@link #binomialOpt()}: {@code preds[1]} is divided by the tree count
     *       and {@code preds[2] = 1 - preds[1]};</li>
     *   <li>otherwise: {@code preds} is divided by its sum when that sum is positive.</li>
     * </ul>
     * Divisions by the tree count are skipped when there are no trees.
     *
     * <p>Declared {@code public} here; the decompiler notes the original was {@code protected}.
     *
     * @param data  input row, passed through to {@code super.score0}
     * @param preds prediction array, filled in place and returned
     * @param d     passed through unchanged to {@code super.score0}
     * @param i     passed through unchanged to {@code super.score0}
     * @return {@code preds}
     */
    @Override // hex.tree.SharedTreeModel
    public double[] score0(double[] data, double[] preds, double d, int i) {
        super.score0(data, preds, d, i);
        DRFOutput output = (DRFOutput) this._output;
        int ntrees = output._ntrees;
        int nclasses = output.nclasses();
        if (nclasses == 1) {
            if (ntrees >= 1) {
                preds[0] = preds[0] / ntrees;
            }
        } else if (nclasses == 2 && binomialOpt()) {
            if (ntrees >= 1) {
                preds[1] = preds[1] / ntrees;
            }
            preds[2] = 1.0d - preds[1];
        } else {
            // Sums the whole array, including preds[0].
            // NOTE(review): the generated POJO code sums from index 1. The two agree only if
            // preds[0] is 0 after super.score0 for classification; confirm.
            double sum = MathUtils.sum(preds);
            if (sum > 0.0d) {
                MathUtils.div(preds, sum);
            }
        }
        return preds;
    }

    /**
     * Emits Java source that performs the same post-processing as {@link #score0} on a local
     * {@code preds} array.
     *
     * <p>For classification it additionally emits:
     * <ul>
     *   <li>a class-balance probability correction, when {@code _balance_classes} is set;</li>
     *   <li>the final label assignment to {@code preds[0]}, via
     *       {@code hex.genmodel.GenModel.getPrediction}.</li>
     * </ul>
     *
     * <p>The division by the tree count is emitted only when there is at least one tree. This
     * matches {@link #score0}, which skips that division. Without the guard the generated code
     * would compute {@code x / 0}, yielding NaN or infinity.
     *
     * @param body stream receiving the generated source lines
     */
    @Override // hex.tree.SharedTreeModel
    protected void toJavaUnifyPreds(SBPrintStream body) {
        DRFOutput output = (DRFOutput) this._output;
        int ntrees = output._ntrees;
        int nclasses = output.nclasses();
        if (nclasses == 1) {
            if (ntrees >= 1) {
                body.ip("preds[0] /= " + ntrees + ";").nl();
            }
            return;
        }
        if (nclasses == 2 && binomialOpt()) {
            if (ntrees >= 1) {
                body.ip("preds[1] /= " + ntrees + ";").nl();
            }
            body.ip("preds[2] = 1.0 - preds[1];").nl();
        } else {
            body.ip("double sum = 0;").nl();
            body.ip("for(int i=1; i<preds.length; i++) { sum += preds[i]; }").nl();
            body.ip("if (sum>0) for(int i=1; i<preds.length; i++) { preds[i] /= sum; }").nl();
        }
        if (((DRFParameters) this._parms)._balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
    }

    /**
     * Returns a MOJO writer for this model.
     *
     * <p>The decompiler renamed this from {@code getMojo} after merging it with a bridge method.
     * The name is kept as-is so existing references still resolve.
     *
     * @return a new {@link DrfMojoWriter} wrapping this model
     */
    public DrfMojoWriter m232getMojo() {
        return new DrfMojoWriter(this);
    }
}
