package hex.tree;

import hex.Model;
import hex.PojoWriter;
import hex.genmodel.CategoricalEncoding;
import org.apache.commons.math3.geometry.VectorFormat;
import water.Key;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.JCodeGen;
import water.util.PojoUtils;
import water.util.SB;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/SharedTreePojoWriter.class */
public abstract class SharedTreePojoWriter implements PojoWriter {

    /**
     * Threshold on {@code number of trees * mean leaves per tree} above which
     * {@link #toJavaCheckTooBig()} reports the model as too big for POJO export.
     */
    private static final float MAX_POJO_TOTAL_LEAVES = 1000000.0f;

    protected final Key<?> _modelKey;
    protected final Model.Output _output;
    protected final CategoricalEncoding _encoding;
    protected final boolean _binomialOpt;
    protected final CompressedTree[][] _trees;
    protected final TreeStats _treeStats;

    /**
     * Creates a POJO writer for a shared-tree model.
     *
     * @param modelKey    key of the model; its string form (made a Java identifier via
     *                    {@code JCodeGen.toJavaId}) prefixes every generated class name
     * @param output      model output; supplies feature/class counts and is passed to the
     *                    per-tree code generator
     * @param encoding    categorical encoding advertised by the generated POJO
     * @param binomialOpt when {@code true} and the model has exactly 2 classes, the class-1
     *                    tree of each forest is neither generated nor scored
     * @param trees       compressed trees indexed {@code [forest][class]}; {@code null}
     *                    entries are skipped
     * @param treeStats   tree statistics used by {@link #toJavaCheckTooBig()}; may be
     *                    {@code null}
     */
    public SharedTreePojoWriter(Key<?> modelKey, Model.Output output, CategoricalEncoding encoding, boolean binomialOpt, CompressedTree[][] trees, TreeStats treeStats) {
        this._modelKey = modelKey;
        this._output = output;
        this._encoding = encoding;
        this._binomialOpt = binomialOpt;
        this._trees = trees;
        this._treeStats = treeStats;
    }

    /**
     * Reports whether the model is too big to be exported as a POJO.
     *
     * @return {@code true} when no tree statistics are available, or when
     *         {@code num_trees * mean_leaves} exceeds {@link #MAX_POJO_TOTAL_LEAVES}
     */
    @Override // hex.PojoWriter
    public boolean toJavaCheckTooBig() {
        if (_treeStats == null) {
            // Without statistics the size cannot be estimated, so refuse the export.
            return true;
        }
        return ((float) _treeStats._num_trees) * _treeStats._mean_leaves > MAX_POJO_TOTAL_LEAVES;
    }

    /**
     * Emits the model-metadata methods of the generated POJO class.
     *
     * <p>The emitted methods are {@code isSupervised}, {@code nfeatures} and {@code nclasses}.
     * With Eigen encoding it also emits {@code getOrigProjectionArray}. For any
     * non-{@code AUTO} encoding it also emits {@code getCategoricalEncoding}.
     *
     * @param body     stream receiving the class-body code
     * @param classCtx class-level code pipeline (not used by this implementation)
     * @return {@code body}, for chaining
     */
    @Override // hex.PojoWriter
    public SBPrintStream toJavaInit(SBPrintStream body, CodeGeneratorPipeline classCtx) {
        body.nl();
        body.ip("public boolean isSupervised() { return true; }").nl();
        body.ip("public int nfeatures() { return " + _output.nfeatures() + "; }").nl();
        body.ip("public int nclasses() { return " + _output.nclasses() + "; }").nl();
        if (_encoding == CategoricalEncoding.Eigen) {
            body.ip("public double[] getOrigProjectionArray() { return " + PojoUtils.toJavaDoubleArray(_output._orig_projection_array) + "; }").nl();
        }
        if (_encoding != CategoricalEncoding.AUTO) {
            body.ip("public hex.genmodel.CategoricalEncoding getCategoricalEncoding() { return hex.genmodel.CategoricalEncoding." + _encoding.name() + "; }").nl();
        }
        return body;
    }

    /**
     * Emits the body of the POJO's prediction method.
     *
     * <p>The emitted code zero-fills {@code preds}. It then calls
     * {@code <id>_Forest_<t>.score0(data,preds)} once per forest, and finally appends whatever
     * {@link #toJavaUnifyPreds(SBPrintStream)} emits.
     *
     * <p>For every forest, a generator is registered in {@code fileCtx} that writes:
     * <ul>
     *   <li>the forest class, whose {@code score0} adds each tree's score into {@code preds}.
     *       The slot is {@code preds[0]} for single-class models and {@code preds[c + 1]}
     *       otherwise;</li>
     *   <li>one class per emitted tree, produced by {@code TreeJCodeGen}.</li>
     * </ul>
     *
     * @param body        stream receiving the prediction-method body
     * @param classCtx    class-level code pipeline (not used by this implementation)
     * @param fileCtx     file-level code pipeline that receives the forest and tree classes
     * @param verboseCode forwarded to {@code TreeJCodeGen}
     */
    @Override // hex.PojoWriter
    public void toJavaPredictBody(SBPrintStream body, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final int nclasses = _output.nclasses();
        body.ip("java.util.Arrays.fill(preds,0);").nl();
        final String javaId = JCodeGen.toJavaId(_modelKey.toString());
        for (int forest = 0; forest < _trees.length; forest++) {
            ((SBPrintStream) toJavaForestName(body.i(), javaId, forest)).p(".score0(data,preds);").nl();
            final int forestIdx = forest; // effectively-final copy for the lambda
            fileCtx.add(out -> {
                try {
                    // Forest class: one score0 that accumulates every emitted tree of this forest.
                    out.nl();
                    toJavaForestName(out.ip("class "), javaId, forestIdx).p(" {").nl().ii(1);
                    out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
                    for (int c = 0; c < nclasses; c++) {
                        if (isTreeEmitted(forestIdx, c, nclasses)) {
                            toJavaTreeName(out.ip("preds[").p(nclasses == 1 ? 0 : c + 1).p("] += "), javaId, forestIdx, c).p(".score0(fdata);").nl();
                        }
                    }
                    out.di(1).ip("}").nl(); // end of score0
                    out.di(1).ip("}").nl(); // end of forest class

                    // Per-tree classes. They must use the same predicate as the loop above,
                    // so that every referenced tree class is actually generated.
                    for (int c = 0; c < nclasses; c++) {
                        if (isTreeEmitted(forestIdx, c, nclasses)) {
                            String treeClassName = ((SB) toJavaTreeName(new SB(), javaId, forestIdx, c)).toString();
                            SB treeCode = new SB();
                            new TreeJCodeGen(_output, _trees[forestIdx][c], treeCode, treeClassName, verboseCode).generate();
                            out.p(treeCode);
                        }
                    }
                } catch (Exception e) {
                    throw new RuntimeException("Internal error creating the POJO.", e);
                }
            });
        }
        toJavaUnifyPreds(body);
    }

    /**
     * Emits code that runs after all forests have accumulated into {@code preds}.
     *
     * <p>Implemented by subclasses. Presumably this turns the raw per-class sums into final
     * predictions; confirm against the subclasses.
     *
     * @param body stream receiving the prediction-method body
     */
    protected abstract void toJavaUnifyPreds(SBPrintStream body);

    /**
     * Decides whether the tree of {@code forestIdx} for {@code classIdx} is generated and scored.
     *
     * <p>The tree must exist. In addition, when the binomial optimization is enabled on a
     * 2-class model, the class-1 tree is skipped.
     */
    private boolean isTreeEmitted(int forestIdx, int classIdx, int nclasses) {
        if (_trees[forestIdx][classIdx] == null) {
            return false;
        }
        return !(_binomialOpt && classIdx == 1 && nclasses == 2);
    }

    /** Appends the generated tree class name {@code <id>_Tree_<forest>_class_<class>} to {@code t}. */
    private static <T extends JCodeSB<T>> T toJavaTreeName(T t, String str, int i, int i2) {
        return (T) t.p(str).p("_Tree_").p(i).p("_class_").p(i2);
    }

    /** Appends the generated forest class name {@code <id>_Forest_<forest>} to {@code t}. */
    private static <T extends JCodeSB<T>> T toJavaForestName(T t, String str, int i) {
        return (T) t.p(str).p("_Forest_").p(i);
    }
}
