package ai.h2o.mojos.runtime.xgb;

import ai.h2o.mojos.runtime.tree.CompOp;
import ai.h2o.mojos.runtime.utils.BitUtils;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ai/h2o/mojos/runtime/xgb/TreeShap.class */
/**
 * Computes per-feature SHAP contributions of a single tree stored in a {@code BinaryTree} byte
 * buffer, using the path-based TreeSHAP recursion (ExtendPath / UnwindPath / UnwoundPathSum).
 *
 * <p>Instances are not thread-safe: {@link #calculateContributions} keeps its arguments and the
 * scratch path in instance fields for the duration of the call.
 */
public class TreeShap {
    private static final CompOp[] COMP_OPS = CompOp.values();

    // Byte offsets (relative to a node's start offset in BinaryTree.bb) of the node fields read here.
    /** Flag byte: 0 marks a leaf; for internal nodes the low 4 bits select the {@code CompOp}. */
    private static final int NODE_FLAGS = 0;
    /** Double: the leaf value for a leaf, the split threshold for an internal node. */
    private static final int NODE_VALUE = 1;
    /** Double: the node's data count, used to split weight between the two children. */
    private static final int NODE_DATA_COUNT = 9;
    /** Int: index (into the input row) of the feature an internal node splits on. */
    private static final int NODE_FEATURE = 17;
    /** Int: offset of the child followed when {@link #treeCompare} returns {@code true}. */
    private static final int NODE_TRUE_CHILD = 21;
    /** Int: offset of the child followed when {@link #treeCompare} returns {@code false}. */
    private static final int NODE_FALSE_CHILD = 25;
    /** Mask extracting the {@code CompOp} ordinal from the flag byte. */
    private static final int COMP_OP_MASK = 15;

    /** Scratch storage for all nested unique paths of the current recursion. */
    private List<PathElement> pathArray;
    /** Serialized tree being explained. */
    private byte[] bb;
    /** Feature values of the row being explained. */
    private double[] inputs;
    /** Contribution accumulator; the slot at {@code inputs.length} receives the bias term. */
    private double[] phi;

    /** One element of the unique feature path maintained by the TreeSHAP recursion. */
    public static class PathElement {
        /** Feature that split at this path position; -1 for the root placeholder. */
        public int featureIndex;
        /** Fraction of the data that flows through this path position. */
        public double zeroFraction;
        /** Fraction of "one" paths: the incoming value on the row's own (hot) branch, 0 on a cold branch. */
        public double oneFraction;
        /** Permutation weight associated with this subset size. */
        public double pweight;

        private PathElement() {
        }

        @Override
        public String toString() {
            return "PathElement{featureIndex=" + this.featureIndex + ", zeroFraction=" + this.zeroFraction + ", oneFraction=" + this.oneFraction + ", pweight=" + this.pweight + '}';
        }
    }

    /**
     * Evaluates the split of the internal node at {@code node} for a feature value.
     *
     * @param node  start offset of the node in {@link #bb}
     * @param value feature value of the row; may be NaN
     * @return {@code true} if the row follows the child stored at {@link #NODE_TRUE_CHILD}
     */
    private boolean treeCompare(int node, double value) {
        int flags = BitUtils.getByte(this.bb, node + NODE_FLAGS);
        if (Double.isNaN(value)) {
            // Missing values are routed by the flag byte instead of the comparison.
            // NOTE(review): presumably a "missing goes to the true child" bit; confirm against the writer.
            return flags < -64;
        }
        return COMP_OPS[flags & COMP_OP_MASK].compare(value, BitUtils.getDouble(this.bb, node + NODE_VALUE));
    }

    /**
     * Recursive TreeSHAP step. It extends the parent's unique path with the parent split and
     * descends into both children. At a leaf it adds each path feature's contribution to {@link #phi}.
     *
     * @param node               start offset of the current node in {@link #bb}
     * @param uniqueDepth        number of features on the parent's unique path (excluding the placeholder)
     * @param parentPath         index in {@link #pathArray} where the parent's path segment starts
     * @param parentZeroFraction zero fraction contributed by the parent split
     * @param parentOneFraction  one fraction contributed by the parent split
     * @param parentFeatureIndex feature the parent split on (-1 at the root)
     * @param conditionFraction  weight scaling this subtree's contribution; 0 prunes the subtree
     */
    private void treeShap(int node, int uniqueDepth, int parentPath, double parentZeroFraction,
                          double parentOneFraction, int parentFeatureIndex, double conditionFraction) {
        if (conditionFraction == 0.0d) {
            return;
        }
        // This node's segment starts right after the parent's (uniqueDepth + 1) elements,
        // so the parent's segment is left intact for the sibling recursion.
        int path = parentPath + uniqueDepth + 1;
        stdcopy(parentPath, parentPath + uniqueDepth + 1, path);
        extendPath(path, uniqueDepth, parentZeroFraction, parentOneFraction, parentFeatureIndex);

        if (BitUtils.getByte(this.bb, node + NODE_FLAGS) == 0) {
            // Leaf: attribute the leaf value to every feature on the unique path (index 0 is the placeholder).
            double leafValue = BitUtils.getDouble(this.bb, node + NODE_VALUE);
            for (int i = 1; i <= uniqueDepth; i++) {
                double weight = unwoundPathSum(path, uniqueDepth, i);
                PathElement el = this.pathArray.get(path + i);
                double contribution = weight * (el.oneFraction - el.zeroFraction) * leafValue * conditionFraction;
                this.phi[el.featureIndex] += contribution;
            }
            return;
        }

        // Internal node: the "hot" child is the one the row actually follows.
        int splitFeature = BitUtils.getInt(this.bb, node + NODE_FEATURE);
        boolean followsTrue = treeCompare(node, getFeatureValue(splitFeature));
        int trueChild = BitUtils.getInt(this.bb, node + NODE_TRUE_CHILD);
        int falseChild = BitUtils.getInt(this.bb, node + NODE_FALSE_CHILD);
        int hotChild = followsTrue ? trueChild : falseChild;
        int coldChild = followsTrue ? falseChild : trueChild;

        double nodeCount = BitUtils.getDouble(this.bb, node + NODE_DATA_COUNT);
        if (nodeCount == 0.0d) {
            throw new UnsupportedOperationException("Tree model: dataCount is zero, cannot compute SHAP");
        }
        double hotZeroFraction = BitUtils.getDouble(this.bb, hotChild + NODE_DATA_COUNT) / nodeCount;
        double coldZeroFraction = BitUtils.getDouble(this.bb, coldChild + NODE_DATA_COUNT) / nodeCount;

        // If this feature already split higher up, undo that split so it can be redone here.
        double incomingZeroFraction = 1.0d;
        double incomingOneFraction = 1.0d;
        int previousIndex = findFeatureSplit(path, uniqueDepth, splitFeature);
        if (previousIndex >= 0) {
            PathElement previous = this.pathArray.get(path + previousIndex);
            incomingZeroFraction = previous.zeroFraction;
            incomingOneFraction = previous.oneFraction;
            unwindPath(path, uniqueDepth, previousIndex);
            uniqueDepth--;
        }

        treeShap(hotChild, uniqueDepth + 1, path, hotZeroFraction * incomingZeroFraction,
                incomingOneFraction, splitFeature, conditionFraction);
        treeShap(coldChild, uniqueDepth + 1, path, coldZeroFraction * incomingZeroFraction,
                0.0d, splitFeature, conditionFraction);
    }

    /**
     * Finds the position of {@code feature} on the unique path.
     *
     * @return the position in {@code [0, uniqueDepth]}, or -1 if the feature is not on the path
     */
    private int findFeatureSplit(int path, int uniqueDepth, int feature) {
        for (int i = 0; i <= uniqueDepth; i++) {
            if (this.pathArray.get(path + i).featureIndex == feature) {
                return i;
            }
        }
        return -1;
    }

    /** Returns the row's value for feature {@code featureIndex}. */
    private double getFeatureValue(int featureIndex) {
        return this.inputs[featureIndex];
    }

    /**
     * Returns the depth of the subtree rooted at {@code node} (0 for a leaf).
     *
     * @throws UnsupportedOperationException if a negative node offset is encountered (the tree is
     *                                       dumped via {@code NodeUtils.dumpTree} first)
     */
    private static int maxDepth(byte[] bArr, int node) {
        if (node < 0) {
            NodeUtils.dumpTree(bArr);
            throw new UnsupportedOperationException("startOffset=" + node);
        }
        if (BitUtils.getByte(bArr, node + NODE_FLAGS) == 0) {
            return 0;
        }
        int trueDepth = maxDepth(bArr, BitUtils.getInt(bArr, node + NODE_TRUE_CHILD)) + 1;
        int falseDepth = maxDepth(bArr, BitUtils.getInt(bArr, node + NODE_FALSE_CHILD)) + 1;
        return Math.max(trueDepth, falseDepth);
    }

    /**
     * Adds the SHAP contributions of {@code tree} for one input row to {@code contributions}.
     *
     * @param tree          tree to explain; must be a {@code BinaryTree} (cast unconditionally)
     * @param row           feature values of the row
     * @param contributions accumulator of length at least {@code row.length + 1}; contributions are
     *                      added (not assigned) at feature indices, and the tree's expected value is
     *                      added at index {@code row.length}
     */
    public void calculateContributions(Tree tree, double[] row, double[] contributions) {
        BinaryTree binaryTree = (BinaryTree) tree;
        this.inputs = row;
        this.phi = contributions;
        this.bb = binaryTree.bb;
        contributions[row.length] += binaryTree.expectedValue;
        // Nested path segments grow by (depth + 1) per level; this triangular bound covers them all.
        int maxd = maxDepth(this.bb, 0) + 2;
        this.pathArray = createPathArray((maxd * (maxd + 1)) / 2);
        treeShap(0, 0, 0, 1.0d, 1.0d, -1, 1.0d);
    }

    /** Allocates {@code size} fresh path elements. */
    private static ArrayList<PathElement> createPathArray(int size) {
        ArrayList<PathElement> elements = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            elements.add(new PathElement());
        }
        return elements;
    }

    /**
     * Copies the path elements in {@code [from, to)} field-by-field to the positions starting at
     * {@code dest}. It mirrors {@code std::copy}, so the ranges must not overlap destructively.
     */
    private void stdcopy(int from, int to, int dest) {
        for (int src = from, dst = dest; src != to; src++, dst++) {
            PathElement source = this.pathArray.get(src);
            PathElement target = this.pathArray.get(dst);
            target.featureIndex = source.featureIndex;
            target.zeroFraction = source.zeroFraction;
            target.oneFraction = source.oneFraction;
            target.pweight = source.pweight;
        }
    }

    /**
     * Appends a feature at position {@code uniqueDepth} of the path starting at {@code path}. It then
     * updates the permutation weights of all positions to account for the larger subset space.
     */
    private void extendPath(int path, int uniqueDepth, double zeroFraction, double oneFraction, int featureIndex) {
        PathElement last = this.pathArray.get(path + uniqueDepth);
        last.featureIndex = featureIndex;
        last.zeroFraction = zeroFraction;
        last.oneFraction = oneFraction;
        last.pweight = uniqueDepth == 0 ? 1.0d : 0.0d;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            PathElement next = this.pathArray.get(path + i + 1);
            PathElement current = this.pathArray.get(path + i);
            next.pweight += oneFraction * current.pweight * (i + 1) / (uniqueDepth + 1);
            current.pweight = zeroFraction * current.pweight * (uniqueDepth - i) / (uniqueDepth + 1);
        }
    }

    /**
     * Removes the element at {@code pathIndex} from the path. This inverts {@link #extendPath} on
     * the permutation weights and shifts the later elements' feature data down by one position.
     */
    private void unwindPath(int path, int uniqueDepth, int pathIndex) {
        double oneFraction = this.pathArray.get(path + pathIndex).oneFraction;
        double zeroFraction = this.pathArray.get(path + pathIndex).zeroFraction;
        double nextOnePortion = this.pathArray.get(path + uniqueDepth).pweight;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            PathElement el = this.pathArray.get(path + i);
            // Must be floating-point: (uniqueDepth - i) < (uniqueDepth + 1), so int division is always 0.
            double depthRatio = (double) (uniqueDepth - i) / (uniqueDepth + 1);
            if (oneFraction != 0.0d) {
                double previousWeight = el.pweight;
                el.pweight = nextOnePortion * (uniqueDepth + 1) / ((i + 1) * oneFraction);
                nextOnePortion = previousWeight - el.pweight * zeroFraction * depthRatio;
            } else {
                el.pweight = el.pweight / zeroFraction / depthRatio;
            }
        }
        for (int i = pathIndex; i < uniqueDepth; i++) {
            PathElement target = this.pathArray.get(path + i);
            PathElement source = this.pathArray.get(path + i + 1);
            target.featureIndex = source.featureIndex;
            target.zeroFraction = source.zeroFraction;
            target.oneFraction = source.oneFraction;
        }
    }

    /**
     * Returns the total permutation weight the path would have if the element at
     * {@code pathIndex} were unwound. The path itself is left unmodified.
     */
    private double unwoundPathSum(int path, int uniqueDepth, int pathIndex) {
        double oneFraction = this.pathArray.get(path + pathIndex).oneFraction;
        double zeroFraction = this.pathArray.get(path + pathIndex).zeroFraction;
        double nextOnePortion = this.pathArray.get(path + uniqueDepth).pweight;
        double total = 0.0d;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            PathElement el = this.pathArray.get(path + i);
            // Must be floating-point: (uniqueDepth - i) < (uniqueDepth + 1), so int division is always 0.
            double depthRatio = (double) (uniqueDepth - i) / (uniqueDepth + 1);
            double weight;
            if (oneFraction != 0.0d) {
                weight = nextOnePortion * (uniqueDepth + 1) / ((i + 1) * oneFraction);
                nextOnePortion = el.pweight - weight * zeroFraction * depthRatio;
            } else {
                weight = el.pweight / zeroFraction / depthRatio;
            }
            total += weight;
        }
        return total;
    }
}
