package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.Dart;
import biz.k11i.xgboost.gbm.GBLinear;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.LinkFunctionFactory;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.utils.LinkFunctionType;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/xgboost/XGBoostPojoWriter.class */
public abstract class XGBoostPojoWriter {
    protected final Predictor _p;
    protected final String _namePrefix;
    protected final XGBoostOutput _output;
    private final double _defaultThreshold;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoostPojoWriter$XGBoostPojoLinearWriter.class */
    /**
     * POJO writer for predictors whose booster is a {@link GBLinear}.
     *
     * <p>For every output group it emits one assignment to {@code preds[group]} that sums
     * the weighted feature terms, the group's bias and the predictor's base score.
     */
    public static class XGBoostPojoLinearWriter extends XGBoostPojoWriter {
        /**
         * Forwards all arguments unchanged to the {@link XGBoostPojoWriter} constructor.
         */
        protected XGBoostPojoLinearWriter(Predictor predictor, String str, XGBoostOutput xGBoostOutput, double d) {
            super(predictor, str, xGBoostOutput, d);
        }

        /**
         * Emits the raw (untransformed) score computation for each output group.
         *
         * @param sBPrintStream         stream receiving the generated statements
         * @param codeGeneratorPipeline not used by the linear writer
         */
        @Override // hex.tree.xgboost.XGBoostPojoWriter
        public void renderComputePredict(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline) {
            // getBooster() is also tested with "instanceof GBTree" in make(), so its declared
            // type is a common supertype of the boosters and has to be narrowed explicitly.
            GBLinear booster = (GBLinear) this._p.getBooster();
            for (int group = 0; group < booster.getNumOutputGroup(); group++) {
                sBPrintStream.ip("preds[").p(group).p("] =").nl();
                sBPrintStream.ii(1);
                for (int feature = 0; feature < booster.getNumFeature(); feature++) {
                    String featureAccessor = getFeatureAccessor(feature);
                    // In the generated expression a NaN feature value contributes 0 to the sum.
                    sBPrintStream.ip("(Double.isNaN(").p(featureAccessor).p(") ? 0 : (").pj(booster.weight(feature, group)).p(" * ").p(featureAccessor).p(")) + ").nl();
                }
                sBPrintStream.ip("").pj(booster.bias(group)).p(" +").nl();
                sBPrintStream.ip("").pj(this._p.getBaseScore()).p(";").nl();
                sBPrintStream.di(1);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoostPojoWriter$XGBoostPojoTreeWriter.class */
    /**
     * POJO writer for predictors whose booster is a {@link GBTree} (including {@link Dart}).
     *
     * <p>Each tree is rendered as its own generated class with a static {@code score0}
     * method; the predict body sums the per-tree scores of every group and adds the
     * predictor's base score.
     */
    public static class XGBoostPojoTreeWriter extends XGBoostPojoWriter {
        /**
         * Forwards all arguments unchanged to the {@link XGBoostPojoWriter} constructor.
         */
        protected XGBoostPojoTreeWriter(Predictor predictor, String str, XGBoostOutput xGBoostOutput, double d) {
            super(predictor, str, xGBoostOutput, d);
        }

        /**
         * Emits, for each tree group, a {@code float preds_<group>} accumulator that sums the
         * {@code score0(data)} of every tree class in the group plus the base score, and then
         * stores it into {@code preds[<group>]}.
         *
         * @param sBPrintStream         stream receiving the generated statements
         * @param codeGeneratorPipeline pipeline to which one generator per tree class is added
         */
        @Override // hex.tree.xgboost.XGBoostPojoWriter
        public void renderComputePredict(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline) {
            // getBooster() is also tested with "instanceof GBTree" in make(), so its declared
            // type is a common supertype of the boosters and has to be narrowed explicitly.
            GBTree booster = (GBTree) this._p.getBooster();
            // Non-null only for Dart boosters; used to scale each tree by its weight.
            Dart dart = booster instanceof Dart ? (Dart) booster : null;
            RegTree[][] groupedTrees = booster.getGroupedTrees();
            for (int group = 0; group < groupedTrees.length; group++) {
                sBPrintStream.ip("float preds_").p(group).p(" = 0f;").nl();
                for (int tree = 0; tree < groupedTrees[group].length; tree++) {
                    sBPrintStream.ip("preds_").p(group).p(" += ").p(renderTreeClass(groupedTrees, group, tree, dart, codeGeneratorPipeline)).p(".score0(data);").nl();
                }
                sBPrintStream.ip("preds_").p(group).p(" += ").pj(this._p.getBaseScore()).p(";").nl();
                sBPrintStream.ip("preds[").p(group).p("] = preds_").p(group).p(";").nl();
            }
        }

        /**
         * Registers a generator for the class that scores a single tree and returns the
         * class name, which is {@code <namePrefix>_Tree_g_<group>_t_<tree>}.
         *
         * @param regTreeArr            trees indexed by group, then by position within the group
         * @param i                     group index
         * @param i2                    tree index within the group; also used as the Dart weight index
         * @param dart                  Dart booster supplying per-tree weights, or {@code null}
         * @param codeGeneratorPipeline pipeline the class generator is added to
         * @return name of the generated tree class
         */
        private String renderTreeClass(RegTree[][] regTreeArr, int i, final int i2, final Dart dart, CodeGeneratorPipeline codeGeneratorPipeline) {
            final RegTree regTree = regTreeArr[i][i2];
            final String str = this._namePrefix + "_Tree_g_" + i + "_t_" + i2;
            codeGeneratorPipeline.add(new CodeGenerator() { // from class: hex.tree.xgboost.XGBoostPojoWriter.XGBoostPojoTreeWriter.1
                public void generate(JCodeSB jCodeSB) {
                    jCodeSB.nl().p("class ").p(str).p(" {").nl();
                    jCodeSB.ii(1);
                    jCodeSB.ip("static float score0(double[] data) {").nl();
                    jCodeSB.ii(1);
                    jCodeSB.ip("return ");
                    if (dart != null) {
                        // Prefix the tree expression with its Dart weight.
                        jCodeSB.pj(dart.weight(i2)).p(" * ");
                    }
                    XGBoostPojoTreeWriter.this.renderTree(jCodeSB, regTree, 0);
                    jCodeSB.p(";").nl();
                    jCodeSB.di(1);
                    jCodeSB.ip("}").nl();
                    jCodeSB.di(1);
                    jCodeSB.ip("}").nl();
                }
            });
            return str;
        }

        /**
         * Recursively renders the subtree rooted at node {@code i} as a nested ternary
         * expression. A leaf is rendered as its leaf value. A split is rendered as
         * {@code ((Double.isNaN(x) || ((float)x) <op> <split>) ? <trueBranch> : <falseBranch>)},
         * so a NaN feature value always takes the condition-true branch.
         *
         * @param jCodeSB output the expression is appended to
         * @param regTree tree whose nodes are being rendered
         * @param i       index of the node to render within {@code regTree.getNodes()}
         */
        /* JADX INFO: Access modifiers changed from: private */
        public void renderTree(JCodeSB jCodeSB, RegTree regTree, int i) {
            String operator;
            int trueChild;
            int falseChild;
            RegTreeNode regTreeNode = regTree.getNodes()[i];
            if (regTreeNode.isLeaf()) {
                jCodeSB.ip("").pj(regTreeNode.getLeafValue());
                return;
            }
            String featureAccessor = getFeatureAccessor(regTreeNode.getSplitIndex());
            if (regTreeNode.default_left()) {
                // NaN or value below the split goes to the left child.
                operator = " < ";
                trueChild = regTreeNode.getLeftChildIndex();
                falseChild = regTreeNode.getRightChildIndex();
            } else {
                // NaN or value at/above the split goes to the right child.
                operator = " >= ";
                trueChild = regTreeNode.getRightChildIndex();
                falseChild = regTreeNode.getLeftChildIndex();
            }
            jCodeSB.ip("((Double.isNaN(").p(featureAccessor).p(") || ((float)").p(featureAccessor).p(")").p(operator).pj(regTreeNode.getSplitCondition()).p(") ?").nl();
            jCodeSB.ii(1);
            renderTree(jCodeSB, regTree, trueChild);
            jCodeSB.nl().ip(":").nl();
            renderTree(jCodeSB, regTree, falseChild);
            jCodeSB.di(1);
            jCodeSB.nl().ip(")");
        }
    }

    /**
     * Creates the writer that matches the predictor's booster: a tree writer when the
     * booster is a {@link GBTree}, otherwise a linear writer.
     *
     * @param predictor     predictor whose booster decides the writer type
     * @param str           name prefix handed to the writer
     * @param xGBoostOutput model output handed to the writer
     * @param d             default threshold handed to the writer
     * @return a new writer instance
     */
    public static XGBoostPojoWriter make(Predictor predictor, String str, XGBoostOutput xGBoostOutput, double d) {
        if (predictor.getBooster() instanceof GBTree) {
            return new XGBoostPojoTreeWriter(predictor, str, xGBoostOutput, d);
        }
        return new XGBoostPojoLinearWriter(predictor, str, xGBoostOutput, d);
    }

    /**
     * Captures the state shared by all writers.
     *
     * @param predictor     predictor supplying the booster, base score and objective name
     * @param str           prefix used when naming generated tree classes
     * @param xGBoostOutput model output consulted for column layout and class count
     * @param d             threshold written into the generated {@code GenModel.getPrediction} call
     */
    protected XGBoostPojoWriter(Predictor predictor, String str, XGBoostOutput xGBoostOutput, double d) {
        _defaultThreshold = d;
        _output = xGBoostOutput;
        _namePrefix = str;
        _p = predictor;
    }

    /**
     * Builds the Java expression that the generated code uses to read feature {@code i}
     * out of the {@code data} array.
     *
     * <p>Feature indices below {@code _catOffsets[_cats]} belong to a categorical column and
     * become an indicator expression comparing {@code data[column]} with the level offset;
     * the remaining indices map directly onto {@code data[column]}. When the output is
     * flagged sparse, a non-matching level yields {@code Float.NaN} and a zero numeric value
     * yields {@code Double.NaN}.
     *
     * @param i feature index
     * @return source text of the accessor expression
     */
    protected String getFeatureAccessor(int i) {
        final XGBoostOutput out = this._output;
        final int numericStart = out._catOffsets[out._cats];

        if (i < numericStart) {
            // Find the categorical column whose offset range contains the feature index.
            int column = 0;
            while (out._catOffsets[column + 1] <= i) {
                column++;
            }
            final int level = i - out._catOffsets[column];
            final String missing = out._sparse ? "Float.NaN" : "0";
            return "(data[" + column + "] == " + level + " ? 1 : " + missing + ")";
        }

        final int column = i - numericStart + out._cats;
        final String cell = "data[" + column + "]";
        if (out._sparse) {
            return "(" + cell + " == 0 ? Double.NaN : " + cell + ")";
        }
        return cell;
    }

    /**
     * Emits a statement that overwrites {@code preds[0]} with the inverse-link expression
     * of the given link function applied to {@code preds[0]}, cast to {@code float}.
     *
     * @param linkFunctionType link function whose inverse is rendered
     * @param sBPrintStream    stream receiving the generated statement
     */
    private void renderPredTransformViaLinkFunction(LinkFunctionType linkFunctionType, SBPrintStream sBPrintStream) {
        final String inverseExpr = LinkFunctionFactory.getLinkFunction(linkFunctionType).linkInvString("preds[0]");
        sBPrintStream.ip("preds[0] = (float) ");
        sBPrintStream.p(inverseExpr);
        sBPrintStream.p(";").nl();
    }

    /**
     * Emits the multi-class transform: the generated code finds the maximum of the first
     * {@code preds.length-1} entries, replaces each with {@code exp(value - max)} and then
     * divides each by their sum.
     *
     * @param sBPrintStream stream receiving the generated statements
     */
    private void renderMultiClassPredTransform(SBPrintStream sBPrintStream) {
        final String[] generatedLines = {
            "double max = preds[0];",
            "for (int i = 1; i < preds.length-1; i++) max = Math.max(preds[i], max); ",
            "double sum = 0.0D;",
            "for (int i = 0; i < preds.length-1; i++) {",
            "  preds[i] = Math.exp(preds[i] - max);",
            "  sum += preds[i];",
            "}",
            "for (int i = 0; i < preds.length-1; i++) {",
            "  preds[i] /= (float) sum;",
            "}",
        };
        for (String line : generatedLines) {
            sBPrintStream.ip(line).nl();
        }
    }

    /**
     * Emits the transform from raw scores to predictions, chosen by the predictor's
     * objective name: log link for gamma/tweedie/poisson, logit link for binary logistic,
     * identity link for linear/squared-error/pairwise-rank, and the multi-class transform
     * for multi softprob.
     *
     * @param sBPrintStream stream receiving the generated statements
     * @throws IllegalArgumentException if the objective name matches none of the above
     */
    private void renderPredTransform(SBPrintStream sBPrintStream) {
        final String objective = this._p.getObjName();
        final LinkFunctionType link;

        if (XGBoostMojoModel.ObjectiveType.REG_GAMMA.getId().equals(objective)
                || XGBoostMojoModel.ObjectiveType.REG_TWEEDIE.getId().equals(objective)
                || XGBoostMojoModel.ObjectiveType.COUNT_POISSON.getId().equals(objective)) {
            link = LinkFunctionType.log;
        } else if (XGBoostMojoModel.ObjectiveType.BINARY_LOGISTIC.getId().equals(objective)) {
            link = LinkFunctionType.logit;
        } else if (XGBoostMojoModel.ObjectiveType.REG_LINEAR.getId().equals(objective)
                || XGBoostMojoModel.ObjectiveType.REG_SQUAREDERROR.getId().equals(objective)
                || XGBoostMojoModel.ObjectiveType.RANK_PAIRWISE.getId().equals(objective)) {
            link = LinkFunctionType.identity;
        } else if (XGBoostMojoModel.ObjectiveType.MULTI_SOFTPROB.getId().equals(objective)) {
            // Multi-class output is not a single link function applied to preds[0].
            renderMultiClassPredTransform(sBPrintStream);
            return;
        } else {
            throw new IllegalArgumentException("Unexpected objFunction " + objective);
        }

        renderPredTransformViaLinkFunction(link, sBPrintStream);
    }

    /**
     * Emits the final rearrangement of {@code preds} for classification outputs; nothing is
     * emitted when the output reports fewer than two classes.
     *
     * <p>For two classes the generated code sets {@code preds[1] = 1f - preds[0]} and
     * {@code preds[2] = preds[0]}; for more than two it shifts every entry up by one slot.
     * In both cases {@code preds[0]} is then assigned the result of
     * {@code GenModel.getPrediction} using the default threshold.
     *
     * @param sBPrintStream stream receiving the generated statements
     */
    private void renderPredPostProcess(SBPrintStream sBPrintStream) {
        final int classCount = this._output.nclasses();
        if (classCount < 2) {
            return;
        }

        if (classCount == 2) {
            sBPrintStream.ip("preds[1] = 1f - preds[0];").nl();
            sBPrintStream.ip("preds[2] = preds[0];").nl();
        } else {
            sBPrintStream.ip("for (int i = preds.length-2; i >= 0; i--)").nl();
            sBPrintStream.ip("  preds[1 + i] = preds[i];").nl();
        }
        // Both classification cases finish with the same label-selection statement.
        sBPrintStream.ip("preds[0] = GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, ").pj(this._defaultThreshold).p(");").nl();
    }

    /**
     * Writes the body of the generated predict method in three steps: the booster-specific
     * raw score computation, the objective-specific transform, and the classification
     * post-processing.
     *
     * @param sBPrintStream         stream receiving the generated statements
     * @param codeGeneratorPipeline pipeline passed on to {@link #renderComputePredict}
     */
    public void renderJavaPredictBody(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline) {
        // Order matters: the later steps emit code that reads the preds values written earlier.
        renderComputePredict(sBPrintStream, codeGeneratorPipeline);
        renderPredTransform(sBPrintStream);
        renderPredPostProcess(sBPrintStream);
    }

    /**
     * Emits the booster-specific statements that compute the raw scores and store them in
     * {@code preds[group]} for each output group.
     *
     * @param sBPrintStream         stream receiving the generated statements
     * @param codeGeneratorPipeline pipeline for additional generated classes; the tree writer
     *                              adds one generator per tree, the linear writer does not use it
     */
    protected abstract void renderComputePredict(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline);
}
