package hex.tree.drf;

import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import hex.tree.drf.DRFModel;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/drf/DrfPojoWriter.class */
class DrfPojoWriter extends SharedTreePojoWriter {

    /**
     * When {@code true}, the generated POJO gets an extra call to
     * {@code GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB)}.
     */
    private final boolean _balance_classes;

    /**
     * Builds a writer directly from a DRF model.
     *
     * <p>The model's key, output, gen-model categorical encoding, binomial-optimization flag and
     * tree statistics go to the superclass. The {@code _balance_classes} flag is taken from the
     * model's DRF parameters.
     *
     * @param model the DRF model being exported
     * @param trees the compressed trees to render into the POJO
     */
    public DrfPojoWriter(DRFModel model, CompressedTree[][] trees) {
        super(model._key,
              model._output,
              model.getGenModelEncoding(),
              model.binomialOpt(),
              trees,
              ((DRFModel.DRFOutput) model._output)._treeStats);
        _balance_classes = ((DRFModel.DRFParameters) model._parms)._balance_classes;
    }

    /**
     * Builds a writer from an arbitrary model plus explicitly supplied settings.
     *
     * <p>Unlike the DRF-specific constructor, no tree statistics are forwarded ({@code null} is
     * passed to the superclass).
     *
     * @param model          the model supplying the key and output metadata
     * @param encoding       categorical encoding to use in the generated code
     * @param binomialOpt    whether the binomial shortcut (only preds[1] scored) applies
     * @param trees          the compressed trees to render into the POJO
     * @param balanceClasses whether to emit the class-probability correction call
     */
    public DrfPojoWriter(Model<?, ?, ?> model, CategoricalEncoding encoding, boolean binomialOpt,
                         CompressedTree[][] trees, boolean balanceClasses) {
        super(model._key, model._output, encoding, binomialOpt, trees, null);
        _balance_classes = balanceClasses;
    }

    /**
     * Emits the Java statements that turn the raw per-tree scores in {@code preds} into the final
     * prediction.
     *
     * <p>With a single class (regression), {@code preds[0]} is divided by the number of trees.
     * Every other case is handled by {@link #emitDrfClassifierUnification(SBPrintStream)}.
     *
     * @param body the stream the generated POJO source is written to
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        if (_output.nclasses() != 1) {
            emitDrfClassifierUnification(body);
        } else {
            // Regression: presumably preds[0] holds the summed tree outputs, so this averages them.
            body.ip("preds[0] /= " + _trees.length + ";").nl();
        }
    }

    /**
     * Emits classification post-processing in three steps:
     * <ol>
     *   <li>normalize the class scores;</li>
     *   <li>optionally re-weight them by the class distributions;</li>
     *   <li>write the chosen label into {@code preds[0]} using the model's default threshold.</li>
     * </ol>
     *
     * @param body the stream the generated POJO source is written to
     */
    private void emitDrfClassifierUnification(SBPrintStream body) {
        final boolean binomialShortcut = _output.nclasses() == 2 && _binomialOpt;
        if (binomialShortcut) {
            // Binomial shortcut: scale preds[1] by the tree count, then derive preds[2] as 1 - preds[1].
            body.ip("preds[1] /= " + _trees.length + ";").nl();
            body.ip("preds[2] = 1.0 - preds[1];").nl();
        } else {
            // General case: divide preds[1..] by their sum, but only when that sum is positive.
            body.ip("double sum = 0;").nl();
            body.ip("for(int i=1; i<preds.length; i++) { sum += preds[i]; }").nl();
            body.ip("if (sum>0) for(int i=1; i<preds.length; i++) { preds[i] /= sum; }").nl();
        }

        if (_balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }

        final String labelStatement =
                "preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, "
                        + _output.defaultThreshold() + ");";
        body.ip(labelStatement).nl();
    }
}
