package hex.tree.gbm;

import hex.Distribution;
import hex.DistributionFactory;
import hex.LinkFunction;
import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import hex.tree.gbm.GBMModel;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/gbm/GbmPojoWriter.class */
/**
 * POJO writer for GBM models.
 *
 * <p>On top of the shared tree-POJO generation, this writer emits the GBM-specific Java
 * statements that turn the raw per-tree sums accumulated in {@code preds} into final
 * predictions. Those statements add the model's initial value, apply the inverse link or
 * rescale, optionally correct probabilities, and choose the predicted label.
 */
class GbmPojoWriter extends SharedTreePojoWriter {
    /** Initial value that the generated code adds to the summed tree output. */
    private final double _init_f;
    /** When true, the generated code calls {@code GenModel.correctProbabilities}. */
    private final boolean _balance_classes;
    /** Distribution family; selects which unification branch is emitted. */
    private final DistributionFamily _distribution_family;
    /** Link function whose inverse (as a Java expression string) is emitted. */
    private final LinkFunction _link_function;

    /**
     * Builds a writer directly from a trained GBM model.
     *
     * <p>Reads the following from the model:
     * <ul>
     *   <li>The initial value from {@code GBMOutput._init_f}.</li>
     *   <li>The balance-classes flag from {@code GBMParameters._balance_classes}.</li>
     *   <li>The distribution family and link function from the {@link Distribution} returned by
     *       {@link DistributionFactory#getDistribution} for the model's parameters.</li>
     * </ul>
     * The output's {@code _treeStats} are forwarded to the superclass.
     *
     * <p>Note: the decompiler reported that it widened this constructor from package-private.
     *
     * @param model the GBM model to export
     * @param trees the compressed trees to render
     */
    public GbmPojoWriter(GBMModel model, CompressedTree[][] trees) {
        super(model._key, model._output, model.getGenModelEncoding(), model.binomialOpt(), trees,
                ((GBMModel.GBMOutput) model._output)._treeStats);
        final GBMModel.GBMOutput gbmOutput = (GBMModel.GBMOutput) model._output;
        _init_f = gbmOutput._init_f;
        final GBMModel.GBMParameters gbmParms = (GBMModel.GBMParameters) model._parms;
        _balance_classes = gbmParms._balance_classes;
        final Distribution dist = DistributionFactory.getDistribution(model._parms);
        _distribution_family = dist._family;
        _link_function = dist._linkFunction;
    }

    /**
     * Builds a writer from explicitly supplied settings rather than from GBM output/parameters.
     *
     * <p>Passes {@code null} as the tree statistics to the superclass.
     * NOTE(review): confirm {@code SharedTreePojoWriter} tolerates null tree stats.
     * The decompiler also reported that it widened this constructor from package-private.
     *
     * @param model              model providing the key and output
     * @param encoding           categorical encoding forwarded to the superclass
     * @param binomialOpt        flag forwarded to the superclass (same slot as
     *                           {@code GBMModel.binomialOpt()} in the other constructor)
     * @param trees              the compressed trees to render
     * @param initF              initial value added to the summed tree output
     * @param balanceClasses     whether to emit the probability-correction call
     * @param distributionFamily distribution family selecting the unification branch
     * @param linkFunction       link function whose inverse is emitted
     */
    public GbmPojoWriter(Model<?, ?, ?> model, CategoricalEncoding encoding, boolean binomialOpt,
                         CompressedTree[][] trees, double initF, boolean balanceClasses,
                         DistributionFamily distributionFamily, LinkFunction linkFunction) {
        super(model._key, model._output, encoding, binomialOpt, trees, null);
        _init_f = initF;
        _balance_classes = balanceClasses;
        _distribution_family = distributionFamily;
        _link_function = linkFunction;
    }

    /**
     * Emits the statements that unify the raw {@code preds} array into final predictions.
     *
     * <ul>
     *   <li>bernoulli / quasibinomial / modified_huber: offset {@code preds[1]} by the initial
     *       value into {@code preds[2]}, then apply the inverse link. Set
     *       {@code preds[1] = 1 - preds[2]}, then emit the label selection.</li>
     *   <li>one class (regression): offset {@code preds[0]} and apply the inverse link.</li>
     *   <li>otherwise: for exactly two classes, offset {@code preds[1]} and mirror it into
     *       {@code preds[2]}. In all such cases, emit {@code GenModel.GBM_rescale(preds)}
     *       (presumably a softmax-style normalisation; verify in {@code GenModel}), then emit
     *       the label selection.</li>
     * </ul>
     *
     * @param body stream receiving the generated Java source
     */
    @Override // hex.tree.SharedTreePojoWriter
    protected void toJavaUnifyPreds(SBPrintStream body) {
        if (isTwoProbabilityFamily(_distribution_family)) {
            writeTwoProbabilityUnification(body);
        } else if (_output.nclasses() == 1) {
            writeRegressionUnification(body);
        } else {
            if (_output.nclasses() == 2) {
                // Only preds[1] carries the tree sum here; preds[2] is its negation.
                body.ip("preds[1] += ").p(_init_f).p(";").nl();
                body.ip("preds[2] = - preds[1];").nl();
            }
            body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
            writeCorrectionAndLabel(body);
        }
    }

    /** Returns true for the families whose POJO output is produced via a single inverse link. */
    private static boolean isTwoProbabilityFamily(DistributionFamily family) {
        // Identity comparisons (not a switch) so a null family simply falls through.
        return family == DistributionFamily.bernoulli
                || family == DistributionFamily.quasibinomial
                || family == DistributionFamily.modified_huber;
    }

    /** Emits the bernoulli-style branch: offset, inverse link, complement, then label. */
    private void writeTwoProbabilityUnification(SBPrintStream body) {
        body.ip("preds[2] = preds[1] + ").p(_init_f).p(";").nl();
        final String inverseLinked = _link_function.linkInvString("preds[2]");
        body.ip("preds[2] = " + inverseLinked + ";").nl();
        body.ip("preds[1] = 1.0-preds[2];").nl();
        writeCorrectionAndLabel(body);
    }

    /** Emits the single-output (regression) branch: offset then inverse link on preds[0]. */
    private void writeRegressionUnification(SBPrintStream body) {
        body.ip("preds[0] += ").p(_init_f).p(";").nl();
        final String inverseLinked = _link_function.linkInvString("preds[0]");
        body.ip("preds[0] = " + inverseLinked + ";").nl();
    }

    /**
     * Emits the optional probability correction (only when balance-classes is on), followed
     * by the {@code getPrediction} call that uses the output's default threshold.
     */
    private void writeCorrectionAndLabel(SBPrintStream body) {
        if (_balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        final String labelLine = "preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + _output.defaultThreshold() + ");";
        body.ip(labelLine).nl();
    }
}
