package hex.tree.gbm;

import hex.Distribution;
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.SharedTreeModel;
import water.Key;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/gbm/GBMModel.class */
/**
 * Gradient Boosting Machine model built on top of {@link SharedTreeModel}.
 *
 * <p>On top of what the superclass provides, this class post-processes the prediction array
 * according to the configured distribution family, both when scoring in-process
 * ({@code score0}) and when emitting generated Java scoring code ({@code toJavaUnifyPreds}).
 */
public class GBMModel extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput> {

    /* loaded from: input_file:hex/tree/gbm/GBMModel$GBMOutput.class */
    /** Output container for a GBM model; adds nothing beyond the shared tree output. */
    public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
        /**
         * @param gbm the builder handed straight to the superclass constructor
         */
        public GBMOutput(GBM gbm) {
            super(gbm);
        }
    }

    /* loaded from: input_file:hex/tree/gbm/GBMModel$GBMParameters.class */
    /** Parameter set for GBM: the shared tree parameters plus the GBM-only fields below. */
    public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
        // Defaults of the parameters declared by this class.
        public double _learn_rate = 0.1d;
        public double _learn_rate_annealing = 1.0d;
        public double _col_sample_rate = 1.0d;
        public double _max_abs_leafnode_pred = Double.MAX_VALUE;
        public double _pred_noise_bandwidth = 0.0d;

        /** Creates the parameter set and overrides the inherited defaults used for GBM. */
        public GBMParameters() {
            _sample_rate = 1.0d;
            _ntrees = 50;
            _max_depth = 5;
        }

        /** @return the short algorithm name, {@code "GBM"} */
        public String algoName() {
            return "GBM";
        }

        /** @return the human-readable algorithm name */
        public String fullName() {
            return "Gradient Boosting Machine";
        }

        /** @return the fully qualified class name of {@link GBMModel} */
        public String javaName() {
            return GBMModel.class.getName();
        }
    }

    /**
     * Creates a model from its key, parameters and output; all three go to the superclass.
     *
     * @param key           key of this model
     * @param gBMParameters parameters of the model
     * @param gBMOutput     output of the model
     */
    public GBMModel(Key<GBMModel> key, GBMParameters gBMParameters, GBMOutput gBMOutput) {
        super(key, gBMParameters, gBMOutput);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Scores one row: runs the superclass scorer, then rewrites the prediction array in place
     * depending on the distribution family.
     *
     * <ul>
     *   <li>bernoulli / modified_huber: slot 2 becomes {@code linkInv(slot1 + _init_f + d)} and
     *       slot 1 becomes {@code 1 - slot2};</li>
     *   <li>multinomial: with exactly two classes, slot 1 gets {@code _init_f + d} added and
     *       slot 2 becomes its negation; the array is then passed to
     *       {@code GenModel.GBM_rescale};</li>
     *   <li>anything else: slot 0 becomes {@code linkInv(slot0 + _init_f + d)}.</li>
     * </ul>
     *
     * @param dArr  input row, forwarded unchanged to the superclass
     * @param dArr2 prediction array, filled by the superclass and adjusted here
     * @param d     additive term applied together with {@code _init_f} (presumably the row
     *              offset - confirm against the superclass contract)
     * @param i     forwarded unchanged to the superclass
     * @return {@code dArr2}, the same array instance that was passed in
     */
    @Override // hex.tree.SharedTreeModel
    public double[] score0(double[] dArr, double[] dArr2, double d, int i) {
        super.score0(dArr, dArr2, d, i);

        final GBMParameters parms = (GBMParameters) this._parms;
        final GBMOutput out = (GBMOutput) this._output;

        if (isBernoulliOrModifiedHuber(parms)) {
            final Distribution dist = new Distribution(parms);
            dArr2[2] = dist.linkInv(dArr2[1] + out._init_f + d);
            dArr2[1] = 1.0d - dArr2[2];
            return dArr2;
        }

        if (parms._distribution == DistributionFamily.multinomial) {
            if (out.nclasses() == 2) {
                // Only slot 1 is adjusted directly; slot 2 is derived as its mirror image.
                dArr2[1] = dArr2[1] + out._init_f + d;
                dArr2[2] = -dArr2[1];
            }
            GenModel.GBM_rescale(dArr2);
            return dArr2;
        }

        final Distribution dist = new Distribution(parms);
        dArr2[0] = dist.linkInv(dArr2[0] + out._init_f + d);
        return dArr2;
    }

    /**
     * Emits the Java statements that turn the generated scorer's {@code preds} array into final
     * predictions. The emitted code mirrors the three cases handled by {@code score0}, except
     * that no per-row additive term is written out, only {@code _init_f}.
     *
     * @param sBPrintStream sink receiving the generated source lines
     */
    @Override // hex.tree.SharedTreeModel
    protected void toJavaUnifyPreds(SBPrintStream sBPrintStream) {
        final GBMParameters parms = (GBMParameters) this._parms;
        final GBMOutput out = (GBMOutput) this._output;

        if (isBernoulliOrModifiedHuber(parms)) {
            sBPrintStream.ip("preds[2] = preds[1] + ").p(out._init_f).p(";").nl();
            final String inverseLinkLine =
                    "preds[2] = " + new Distribution(parms).linkInvString("preds[2]") + ";";
            sBPrintStream.ip(inverseLinkLine).nl();
            sBPrintStream.ip("preds[1] = 1.0-preds[2];").nl();
            emitLabelSelection(sBPrintStream, parms);
            return;
        }

        if (out.nclasses() == 1) {
            sBPrintStream.ip("preds[0] += ").p(out._init_f).p(";").nl();
            final String inverseLinkLine =
                    "preds[0] = " + new Distribution(parms).linkInvString("preds[0]") + ";";
            sBPrintStream.ip(inverseLinkLine).nl();
            return;
        }

        if (out.nclasses() == 2) {
            sBPrintStream.ip("preds[1] += ").p(out._init_f).p(";").nl();
            sBPrintStream.ip("preds[2] = - preds[1];").nl();
        }
        sBPrintStream.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
        emitLabelSelection(sBPrintStream, parms);
    }

    /**
     * Tells whether the configured distribution is one of the two families that share the
     * single-score inverse-link handling.
     */
    private static boolean isBernoulliOrModifiedHuber(GBMParameters parms) {
        return parms._distribution == DistributionFamily.bernoulli
                || parms._distribution == DistributionFamily.modified_huber;
    }

    /**
     * Emits the common tail of the classification paths: the optional
     * {@code correctProbabilities} call (only when {@code _balance_classes} is set) followed by
     * the {@code getPrediction} call that fills {@code preds[0]}.
     */
    private void emitLabelSelection(SBPrintStream sink, GBMParameters parms) {
        if (parms._balance_classes) {
            sink.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        sink.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
    }

    /* renamed from: getMojo, reason: merged with bridge method [inline-methods] */
    /**
     * Creates the MOJO writer for this model.
     *
     * @return a new {@link GbmMojoWriter} wrapping this model
     */
    public GbmMojoWriter m230getMojo() {
        return new GbmMojoWriter(this);
    }
}
