package hex.genmodel.algos.tree;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import java.io.Serializable;

/* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/tree/TreeSHAP.class */
/**
 * Computes exact per-feature SHAP contributions for a single decision tree.
 *
 * <p>The recursion follows the TreeSHAP path-tracking scheme: extend / unwind / unwound-path-sum
 * (Lundberg et al., "Consistent Individualized Feature Attribution for Tree Ensembles",
 * Algorithm 2). While descending from the root, it maintains the set of unique features split on
 * along the current path. It also keeps the permutation weights needed to attribute each leaf value
 * to those features.
 *
 * <p>{@code nodes} and {@code stats} are index-aligned: {@code stats[k]} supplies the weight of
 * {@code nodes[k]}. Child indices returned by a node refer to positions in both arrays.
 *
 * @param <R> row type consumed by {@link INode#next}
 * @param <N> node type
 * @param <S> node statistics type (supplies node weights)
 */
public class TreeSHAP<R, N extends INode<R>, S extends INodeStat> implements TreeSHAPPredictor<R> {
    private final int rootNodeId;
    private final N[] nodes;
    private final S[] stats;
    /**
     * Weighted mean of the leaf values reachable from the root (weights taken from {@code stats}).
     * It must be computed in the constructor AFTER {@code nodes}/{@code stats} are assigned. A field
     * initializer would run before the constructor body and see {@code null} arrays.
     */
    private final float expectedTreeValue;

    /**
     * One entry of the unique feature path.
     *
     * <p>NOTE: this class was originally {@code private}; the decompiler (JADX) widened its access
     * to {@code public}. Field names are kept as-is because other code in the package may use them.
     */
    public static class PathElement implements Serializable {
        /** Feature (split column) index of this element; -1 for the root placeholder entry. */
        int feature_index;
        /** Product of child-weight / parent-weight ratios for the splits on this feature so far. */
        float zero_fraction;
        /** 1 while the row's own path follows every split on this feature, otherwise 0. */
        float one_fraction;
        /** Permutation weight maintained by extendPath / unwindPath. */
        float pweight;

        private PathElement() {
        }

        /** Clears all fields back to zero. */
        void reset() {
            this.feature_index = 0;
            this.zero_fraction = 0.0f;
            this.one_fraction = 0.0f;
            this.pweight = 0.0f;
        }
    }

    /**
     * A view into a shared {@link PathElement} array starting at {@code position}.
     *
     * <p>Each recursion level copies its parent's path into a fresh region of the same array (see
     * {@link #move(int)}). Sibling subtrees therefore never clobber the path they both start from.
     */
    public static class PathPointer implements TreeSHAPPredictor.Workspace {
        PathElement[] path;
        int position;

        PathPointer(PathElement[] path) {
            this.path = path;
        }

        PathPointer(PathElement[] path, int position) {
            this.path = path;
            this.position = position;
        }

        /** Returns the element at offset {@code i} relative to this pointer's position. */
        PathElement get(int i) {
            return this.path[this.position + i];
        }

        /**
         * Copies the {@code len} elements starting at the current position into the {@code len}
         * slots that follow them. Returns a pointer to that copy.
         */
        PathPointer move(int len) {
            for (int k = 0; k < len; k++) {
                PathElement source = this.path[this.position + k];
                PathElement target = this.path[this.position + len + k];
                target.feature_index = source.feature_index;
                target.zero_fraction = source.zero_fraction;
                target.one_fraction = source.one_fraction;
                target.pweight = source.pweight;
            }
            return new PathPointer(this.path, this.position + len);
        }

        /**
         * Resets the first element of the backing array (index 0, independent of
         * {@code position}). This is the root slot of a freshly created workspace.
         */
        void reset() {
            this.path[0].reset();
        }

        /** Returns the total capacity of the backing array. */
        @Override // hex.genmodel.algos.tree.TreeSHAPPredictor.Workspace
        public int getSize() {
            return this.path.length;
        }
    }

    /**
     * @param nodes      tree nodes; child indices returned by a node index into this array
     * @param stats      per-node statistics, index-aligned with {@code nodes}
     * @param rootNodeId index of the root node in {@code nodes}/{@code stats}
     */
    public TreeSHAP(N[] nodes, S[] stats, int rootNodeId) {
        this.rootNodeId = rootNodeId;
        this.nodes = nodes;
        this.stats = stats;
        // Must run after the arrays are assigned (see field comment).
        this.expectedTreeValue = treeMeanValue();
    }

    /**
     * Appends a new element at {@code uniqueDepth} and updates the permutation weights of the
     * existing elements 0..uniqueDepth.
     */
    private void extendPath(PathPointer uniquePath, int uniqueDepth, float zeroFraction,
                            float oneFraction, int featureIndex) {
        final PathElement added = uniquePath.get(uniqueDepth);
        added.feature_index = featureIndex;
        added.zero_fraction = zeroFraction;
        added.one_fraction = oneFraction;
        added.pweight = uniqueDepth == 0 ? 1.0f : 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            uniquePath.get(i + 1).pweight +=
                    ((oneFraction * uniquePath.get(i).pweight) * (i + 1)) / (uniqueDepth + 1);
            uniquePath.get(i).pweight =
                    ((zeroFraction * uniquePath.get(i).pweight) * (uniqueDepth - i)) / (uniqueDepth + 1);
        }
    }

    /**
     * Undoes a previous {@link #extendPath} for the element at {@code pathIndex}. It restores the
     * permutation weights and shifts the following elements' feature data down by one. The caller
     * is responsible for decrementing its unique depth afterwards.
     */
    private void unwindPath(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        final float oneFraction = uniquePath.get(pathIndex).one_fraction;
        final float zeroFraction = uniquePath.get(pathIndex).zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            final PathElement element = uniquePath.get(i);
            if (oneFraction != 0.0f) {
                final float previousWeight = element.pweight;
                element.pweight = (nextOnePortion * (uniqueDepth + 1)) / ((i + 1) * oneFraction);
                nextOnePortion = previousWeight
                        - (((element.pweight * zeroFraction) * (uniqueDepth - i)) / (uniqueDepth + 1));
            } else {
                element.pweight = (element.pweight * (uniqueDepth + 1)) / (zeroFraction * (uniqueDepth - i));
            }
        }
        for (int i = pathIndex; i < uniqueDepth; i++) {
            final PathElement target = uniquePath.get(i);
            final PathElement source = uniquePath.get(i + 1);
            target.feature_index = source.feature_index;
            target.zero_fraction = source.zero_fraction;
            target.one_fraction = source.one_fraction;
        }
    }

    /**
     * Returns the total permutation weight the path would have if the element at
     * {@code pathIndex} were unwound. The path itself is not modified.
     *
     * @throws IllegalStateException if an element's fractions are both zero but its weight is not
     */
    private float unwoundPathSum(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        final float oneFraction = uniquePath.get(pathIndex).one_fraction;
        final float zeroFraction = uniquePath.get(pathIndex).zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        float total = 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            // The ratio must be computed in floating point: with int/int it truncates to 0 because
            // 0 < uniqueDepth - i < uniqueDepth + 1.
            final float depthRatio = (uniqueDepth - i) / (float) (uniqueDepth + 1);
            if (oneFraction != 0.0f) {
                final float portion = (nextOnePortion * (uniqueDepth + 1)) / ((i + 1) * oneFraction);
                total += portion;
                nextOnePortion = uniquePath.get(i).pweight - ((portion * zeroFraction) * depthRatio);
            } else if (zeroFraction != 0.0f) {
                total += (uniquePath.get(i).pweight / zeroFraction) / depthRatio;
            } else if (uniquePath.get(i).pweight != 0.0f) {
                throw new IllegalStateException("Unique path " + i + " must have zero getWeight");
            }
        }
        return total;
    }

    /**
     * Recursive TreeSHAP walk.
     *
     * @param row                input row used to choose the "hot" child ({@link INode#next})
     * @param outContribs        contributions array, accumulated in place at leaves
     * @param node               current node
     * @param nodeStat           statistics of {@code node}
     * @param uniqueDepth        number of valid unique-path elements inherited from the parent
     * @param parentUniquePath   parent's path; copied via {@link PathPointer#move(int)}
     * @param parentZeroFraction zero fraction for the split that led here
     * @param parentOneFraction  one fraction for the split that led here
     * @param parentFeatureIndex feature of the split that led here (-1 at the root)
     * @param condition          0 = unconditional; positive / negative = conditioned on
     *                           {@code conditionFeature} (see public entry point)
     * @param conditionFeature   feature being conditioned on (only meaningful if condition != 0)
     * @param conditionFraction  weight of this subtree under the condition; 0 prunes it
     */
    private void treeShap(R row, float[] outContribs, N node, S nodeStat, int uniqueDepth,
                          PathPointer parentUniquePath, float parentZeroFraction, float parentOneFraction,
                          int parentFeatureIndex, int condition, int conditionFeature,
                          float conditionFraction) {
        if (conditionFraction == 0.0f) {
            return;
        }
        final PathPointer uniquePath = parentUniquePath.move(uniqueDepth);
        // The conditioned feature is not added to the path; it is accounted for via conditionFraction.
        if (condition == 0 || conditionFeature != parentFeatureIndex) {
            extendPath(uniquePath, uniqueDepth, parentZeroFraction, parentOneFraction, parentFeatureIndex);
        }
        final int splitIndex = node.getSplitIndex();

        if (node.isLeaf()) {
            // Element 0 is the root placeholder; attribute the leaf value to elements 1..uniqueDepth.
            for (int i = 1; i <= uniqueDepth; i++) {
                final float weight = unwoundPathSum(uniquePath, uniqueDepth, i);
                final PathElement element = uniquePath.get(i);
                outContribs[element.feature_index] +=
                        weight * (element.one_fraction - element.zero_fraction) * node.getLeafValue()
                                * conditionFraction;
            }
            return;
        }

        // "Hot" = the child the row actually goes to, "cold" = the other one.
        final int hotIndex = node.next(row);
        final int coldIndex = hotIndex == node.getLeftChildIndex()
                ? node.getRightChildIndex()
                : node.getLeftChildIndex();
        final float nodeWeight = nodeStat.getWeight();
        final float hotZeroFraction = this.stats[hotIndex].getWeight() / nodeWeight;
        final float coldZeroFraction = this.stats[coldIndex].getWeight() / nodeWeight;
        float incomingZeroFraction = 1.0f;
        float incomingOneFraction = 1.0f;

        // If this feature was already split on higher up, unwind it. The new split then extends
        // the path with the combined fractions instead of adding a duplicate entry.
        int pathIndex = 0;
        while (pathIndex <= uniqueDepth && uniquePath.get(pathIndex).feature_index != splitIndex) {
            pathIndex++;
        }
        if (pathIndex != uniqueDepth + 1) {
            incomingZeroFraction = uniquePath.get(pathIndex).zero_fraction;
            incomingOneFraction = uniquePath.get(pathIndex).one_fraction;
            unwindPath(uniquePath, uniqueDepth, pathIndex);
            uniqueDepth--;
        }

        // Split the condition fraction between the two children when splitting on the conditioned feature.
        float hotConditionFraction = conditionFraction;
        float coldConditionFraction = conditionFraction;
        if (condition > 0 && splitIndex == conditionFeature) {
            coldConditionFraction = 0.0f;
            uniqueDepth--;
        } else if (condition < 0 && splitIndex == conditionFeature) {
            hotConditionFraction *= hotZeroFraction;
            coldConditionFraction *= coldZeroFraction;
            uniqueDepth--;
        }

        treeShap(row, outContribs, this.nodes[hotIndex], this.stats[hotIndex], uniqueDepth + 1, uniquePath,
                hotZeroFraction * incomingZeroFraction, incomingOneFraction, splitIndex,
                condition, conditionFeature, hotConditionFraction);
        treeShap(row, outContribs, this.nodes[coldIndex], this.stats[coldIndex], uniqueDepth + 1, uniquePath,
                coldZeroFraction * incomingZeroFraction, 0.0f, splitIndex,
                condition, conditionFeature, coldConditionFraction);
    }

    /**
     * Adds this tree's unconditional SHAP contributions for {@code row} into {@code contribs}. It
     * allocates a fresh workspace for the call.
     *
     * @return {@code contribs} (the same array, updated in place)
     */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public float[] calculateContributions(R row, float[] contribs) {
        return calculateContributions(row, contribs, 0, -1, makeWorkspace());
    }

    /**
     * Adds this tree's SHAP contributions for {@code row} into {@code contribs}. Values are
     * accumulated; the array is not cleared first.
     *
     * <p>When {@code condition == 0}, the tree's expected value is also added to the last element
     * of {@code contribs} (presumably the bias slot; confirm with callers). A non-zero
     * {@code condition} restricts the computation with respect to {@code conditionFeature}:
     * <ul>
     *   <li>positive: only the row's branch of splits on that feature is followed;</li>
     *   <li>negative: both branches are followed, weighted by their node-weight fractions.</li>
     * </ul>
     *
     * @param workspace must be a {@link PathPointer} of at least {@link #getWorkspaceSize()}
     *                  elements, e.g. from {@link #makeWorkspace()}; any other type causes a
     *                  {@link ClassCastException}
     * @return {@code contribs} (the same array, updated in place)
     */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public float[] calculateContributions(R row, float[] contribs, int condition, int conditionFeature,
                                          TreeSHAPPredictor.Workspace workspace) {
        if (condition == 0) {
            contribs[contribs.length - 1] += this.expectedTreeValue;
        }
        final PathPointer uniquePath = (PathPointer) workspace;
        uniquePath.reset();
        treeShap(row, contribs, this.nodes[this.rootNodeId], this.stats[this.rootNodeId], 0, uniquePath,
                1.0f, 1.0f, -1, condition, conditionFeature, 1.0f);
        return contribs;
    }

    /** Allocates a workspace of {@link #getWorkspaceSize()} freshly constructed path elements. */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public PathPointer makeWorkspace() {
        final PathElement[] elements = new PathElement[getWorkspaceSize()];
        for (int i = 0; i < elements.length; i++) {
            elements[i] = new PathElement();
        }
        return new PathPointer(elements);
    }

    /**
     * Returns the number of path elements a workspace needs: m(m+1)/2 with m = tree depth + 2.
     * Each recursion level copies its parent's (at most depth + 2 element) path into a new region.
     */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public int getWorkspaceSize() {
        final int maxPathLength = treeDepth() + 2;
        return (maxPathLength * (maxPathLength + 1)) / 2;
    }

    /** Depth of the tree measured from the root node (a lone leaf has depth 1). */
    private int treeDepth() {
        return nodeDepth(this.nodes, this.rootNodeId);
    }

    private static <N extends INode<?>> int nodeDepth(N[] nodes, int nodeId) {
        final N node = nodes[nodeId];
        if (node.isLeaf()) {
            return 1;
        }
        return 1 + Math.max(nodeDepth(nodes, node.getLeftChildIndex()), nodeDepth(nodes, node.getRightChildIndex()));
    }

    /** Weighted mean leaf value of the tree, starting at the root node. */
    private float treeMeanValue() {
        return nodeMeanValue(this.nodes, this.stats, this.rootNodeId);
    }

    /** Weighted mean of the leaf values under {@code nodeId}, weighted by the children's stats. */
    private static <N extends INode<?>, S extends INodeStat> float nodeMeanValue(N[] nodes, S[] stats, int nodeId) {
        final N node = nodes[nodeId];
        if (node.isLeaf()) {
            return node.getLeafValue();
        }
        final int left = node.getLeftChildIndex();
        final int right = node.getRightChildIndex();
        return ((stats[left].getWeight() * nodeMeanValue(nodes, stats, left))
                + (stats[right].getWeight() * nodeMeanValue(nodes, stats, right)))
                / stats[nodeId].getWeight();
    }
}
