package ai.h2o.mojos.runtime.transforms;

import ai.h2o.mojos.runtime.api.BasePipelineListener;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.frame.NumericFrame;
import ai.h2o.mojos.runtime.transforms.util.MojoTransformBuilderUtils;
import ai.h2o.mojos.runtime.xgb.Tree;
import ai.h2o.mojos.runtime.xgb.TreeShap;
import java.util.Arrays;
import javassist.bytecode.Opcode;

/* loaded from: input_file:ai/h2o/mojos/runtime/transforms/MojoTransformTreeModelBuilder.class */
/**
 * Transform that runs an array of {@link Tree}s over the input columns and writes the
 * results into the output columns.
 *
 * <p>Each tree is paired with an entry of {@code treeInfo} that selects which output column
 * (by position within {@code oindices}) the tree writes to. {@link #transform(MojoFrame)}
 * zero-fills every output column, invokes every tree, and finally adds {@code baseMargin}
 * to every output column.
 *
 * <p>{@link #setTreeBooster(Tree[], int[])} must be called before {@link #transform(MojoFrame)}
 * or {@link #computeShap(double[], double[][])}; both dereference the tree array
 * unconditionally.
 */
public class MojoTransformTreeModelBuilder extends ShapCapableTransform implements HasTraceableSteps {
    /** Receives one notification per evaluated tree; defaults to {@link BasePipelineListener#NOOP}. */
    private BasePipelineListener listener;
    /** Constant added to every output column after all trees have been evaluated. */
    private double baseMargin;
    /** Trees to evaluate; {@code null} until {@link #setTreeBooster(Tree[], int[])} is called. */
    private Tree[] trees;
    /** For tree {@code k}, {@code treeInfo[k]} is the position of its output column within {@code oindices}. */
    private int[] treeInfo;

    /**
     * Creates the transform and checks the column types of its inputs and outputs.
     *
     * @param mojoFrameMeta frame metadata used to look up column types
     * @param iArr forwarded as the first argument to the superclass constructor
     *     (presumably the input column indices, i.e. {@code iindices} - confirm against
     *     {@code ShapCapableTransform})
     * @param iArr2 forwarded as the second argument to the superclass constructor
     *     (presumably the output column indices, i.e. {@code oindices} - confirm)
     */
    public MojoTransformTreeModelBuilder(MojoFrameMeta mojoFrameMeta, int[] iArr, int[] iArr2) {
        super(iArr, iArr2);
        this.listener = BasePipelineListener.NOOP;
        validate(mojoFrameMeta);
    }

    /**
     * Checks that input and output columns have acceptable types and, when assertions are
     * enabled, that the first input column and the first output column share the same type.
     *
     * @param meta frame metadata used to look up column types
     */
    private void validate(MojoFrameMeta meta) {
        // NOTE(review): Opcode.INSTANCEOF is a javassist bytecode-opcode constant. Its use as the
        // type argument here looks like a decompiler mis-resolution of a numeric constant -
        // confirm the intended constant against the original sources before changing it.
        MojoTransformBuilderUtils.assertTypes(meta, this.iindices, Opcode.INSTANCEOF, "Input columns must have the same float type");
        MojoTransformBuilderUtils.assertTypes(meta, this.oindices, Opcode.INSTANCEOF, "Output columns must have the same float type");
        // Only the first column of each side is compared; the check is skipped when either side is empty.
        assert this.iindices.length == 0
                || this.oindices.length == 0
                || meta.getColumnType(this.iindices[0]) == meta.getColumnType(this.oindices[0])
                : "Input and output columns must be of the same type";
    }

    /**
     * Sets the constant that is added to every output column at the end of
     * {@link #transform(MojoFrame)} and to the last slot of every row in
     * {@link #computeShap(double[], double[][])}.
     *
     * @param d the new base margin
     */
    public void setBaseMargin(double d) {
        this.baseMargin = d;
    }

    /**
     * Installs the trees and their output-column assignment.
     *
     * <p>The arrays are stored by reference, not copied.
     *
     * @param treeArr trees to evaluate
     * @param iArr for each tree, the position of its output column within {@code oindices};
     *     expected to have the same length as {@code treeArr} (checked only when assertions
     *     are enabled)
     */
    public void setTreeBooster(Tree[] treeArr, int[] iArr) {
        // Validate before mutating so that a failed check does not leave this object half-configured.
        assert treeArr.length == iArr.length;
        this.trees = treeArr;
        this.treeInfo = iArr;
    }

    /**
     * Replaces the listener that is notified after each tree is evaluated.
     *
     * @param basePipelineListener the listener to notify
     */
    @Override // ai.h2o.mojos.runtime.transforms.HasTraceableSteps
    public void setListener(BasePipelineListener basePipelineListener) {
        this.listener = basePipelineListener;
    }

    /**
     * Zero-fills the output columns, evaluates every tree, then adds the base margin to every
     * output column.
     *
     * @param mojoFrame frame holding both the input and the output columns
     */
    @Override // ai.h2o.mojos.runtime.transforms.MojoTransform
    public void transform(MojoFrame mojoFrame) {
        NumericFrame inputs = NumericFrame.indexed(mojoFrame, this.iindices);
        NumericFrame outputs = NumericFrame.indexed(mojoFrame, this.oindices);
        int outputCount = this.oindices.length;
        // Outputs are reset first; presumably Tree.predict accumulates into them - confirm.
        for (int col = 0; col < outputCount; col++) {
            outputs.fillColumn(col, 0.0d);
        }
        predict(inputs, outputs);
        for (int col = 0; col < outputCount; col++) {
            outputs.addConstantToColumn(col, this.baseMargin);
        }
    }

    /**
     * Invokes every tree in order, passing it all input columns and the single output column
     * selected by {@code treeInfo}, and reports each step to the listener.
     *
     * <p>Double or float column views are chosen based on {@code inputs.isDouble()}; the output
     * frame is accessed with the same element type.
     *
     * @param inputs view of the input columns
     * @param outputs view of the output columns
     */
    private void predict(NumericFrame inputs, NumericFrame outputs) {
        int treeCount = this.trees.length;
        if (inputs.isDouble()) {
            for (int t = 0; t < treeCount; t++) {
                int target = this.treeInfo[t];
                this.trees[t].predict(inputs.doubleColumns(), outputs.doubleColumns()[target]);
                this.listener.onTransformStep("tree " + t, this.oindices[target]);
            }
        } else {
            for (int t = 0; t < treeCount; t++) {
                int target = this.treeInfo[t];
                this.trees[t].predict(inputs.floatColumns(), outputs.floatColumns()[target]);
                this.listener.onTransformStep("tree " + t, this.oindices[target]);
            }
        }
    }

    /**
     * Computes per-output contributions for a single row of input values.
     *
     * <p>For each output position {@code i}, row {@code dArr2[i]} is reset, every tree assigned
     * to that output adds its contributions via {@link TreeShap#calculateContributions}, and the
     * base margin is added to the slot at index {@code dArr.length}. Each row of {@code dArr2}
     * must therefore have at least {@code dArr.length + 1} elements.
     *
     * @param dArr input values passed to every tree
     * @param dArr2 one row per output column; overwritten with the computed contributions
     */
    @Override // ai.h2o.mojos.runtime.transforms.ShapCapable
    public void computeShap(double[] dArr, double[][] dArr2) {
        TreeShap treeShap = new TreeShap();
        // Slot that receives the base margin (presumably the bias term - confirm against TreeShap).
        int marginSlot = dArr.length;
        for (int out = 0; out < this.oindices.length; out++) {
            double[] contributions = dArr2[out];
            // Negative zero is kept from the original; it is an exact additive identity in IEEE 754.
            Arrays.fill(contributions, -0.0d);
            for (int t = 0; t < this.trees.length; t++) {
                if (this.treeInfo[t] == out) {
                    treeShap.calculateContributions(this.trees[t], dArr, contributions);
                }
            }
            contributions[marginSlot] += this.baseMargin;
        }
    }
}
