package biz.k11i.xgboost.tree;

import biz.k11i.xgboost.tree.RegTreeImpl;
import biz.k11i.xgboost.util.FVec;

/**
 * Per-tree TreeSHAP: attributes one regression tree's output for a feature vector to the
 * individual features.
 *
 * <p>Each instance owns a scratch buffer ({@code unique_path_data}) that every call to
 * {@link #calculateContributions} reuses, so a single instance must not be used by several
 * threads at the same time.
 */
public class TreeSHAP {
    private final RegTreeImpl.Node[] nodes;
    private final RegTreeImpl.RTreeNodeStat[] stats;
    /** Mean leaf value of the tree rooted at node 0, each child weighted by its sum_hess. */
    private final float expectedTreeValue;
    /**
     * Scratch buffer in which every recursion level of {@code treeShap} keeps its own copy of the
     * unique path, placed directly after its parent's copy.
     */
    private final PathElement[] unique_path_data;

    /** One entry of a unique path: a split feature and the fractions routed through it. */
    public static class PathElement {
        /** Feature split on at this entry; the placeholder entry at the root uses -1. */
        int feature_index;
        /** Share of the parent's sum_hess reaching the chosen branch (multiplied over re-splits). */
        float zero_fraction;
        /** 1 when the input itself follows the chosen branch, 0 for the branch it does not follow. */
        float one_fraction;
        /** Permutation weight maintained by extendPath / unwindPath. */
        float pweight;

        private PathElement() {
        }

        /** Clears every field back to zero. */
        void reset() {
            this.feature_index = 0;
            this.zero_fraction = 0.0f;
            this.one_fraction = 0.0f;
            this.pweight = 0.0f;
        }
    }

    /** A view into {@code path} whose logical index 0 is the array slot {@code position}. */
    public static class PathPointer {
        PathElement[] path;
        int position;

        PathPointer(PathElement[] pathElementArr) {
            this.path = pathElementArr;
        }

        PathPointer(PathElement[] pathElementArr, int i) {
            this.path = pathElementArr;
            this.position = i;
        }

        /** Returns the element at logical index {@code i} of this view. */
        PathElement get(int i) {
            return this.path[this.position + i];
        }

        /**
         * Copies entries {@code 0..uniqueDepth} (inclusive) of this path into the slots that start
         * right after them and returns a view on the copy.
         *
         * <p>All {@code uniqueDepth + 1} entries must be copied: when a conditional split skips
         * {@code extendPath}, the child relies on entry {@code uniqueDepth} being the parent's
         * value. Placing the copy after that entry also guarantees the parent's entries are never
         * overwritten, so a sibling call still sees them intact.
         */
        PathPointer move(int uniqueDepth) {
            int target = this.position + uniqueDepth + 1;
            for (int k = 0; k <= uniqueDepth; k++) {
                PathElement src = this.path[this.position + k];
                PathElement dst = this.path[target + k];
                dst.feature_index = src.feature_index;
                dst.zero_fraction = src.zero_fraction;
                dst.one_fraction = src.one_fraction;
                dst.pweight = src.pweight;
            }
            return new PathPointer(this.path, target);
        }
    }

    /**
     * Prepares SHAP computation for the given tree.
     *
     * @param regTreeImpl tree whose nodes and node statistics are used
     */
    public TreeSHAP(RegTreeImpl regTreeImpl) {
        this.nodes = regTreeImpl.getNodes();
        this.stats = regTreeImpl.getStats();
        // Must run after nodes/stats are assigned: as a field initializer it would run first
        // and dereference null arrays.
        this.expectedTreeValue = treeMeanValue();
        int maxDepth = treeDepth() + 2;
        // Triangular size: each level stores a path copy at most one entry longer than its parent's.
        this.unique_path_data = new PathElement[(maxDepth * (maxDepth + 1)) / 2];
        for (int i = 0; i < this.unique_path_data.length; i++) {
            this.unique_path_data[i] = new PathElement();
        }
    }

    /** Appends a new entry at index {@code uniqueDepth} and updates all permutation weights. */
    private void extendPath(PathPointer uniquePath, int uniqueDepth, float zeroFraction,
                            float oneFraction, int featureIndex) {
        PathElement last = uniquePath.get(uniqueDepth);
        last.feature_index = featureIndex;
        last.zero_fraction = zeroFraction;
        last.one_fraction = oneFraction;
        last.pweight = uniqueDepth == 0 ? 1.0f : 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            uniquePath.get(i + 1).pweight +=
                    oneFraction * uniquePath.get(i).pweight * (i + 1) / (uniqueDepth + 1);
            uniquePath.get(i).pweight =
                    zeroFraction * uniquePath.get(i).pweight * (uniqueDepth - i) / (uniqueDepth + 1);
        }
    }

    /**
     * Removes entry {@code pathIndex} from the path, undoing its effect on the permutation weights
     * and shifting later entries down by one.
     */
    private void unwindPath(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        float oneFraction = uniquePath.get(pathIndex).one_fraction;
        float zeroFraction = uniquePath.get(pathIndex).zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            PathElement el = uniquePath.get(i);
            if (oneFraction != 0.0f) {
                float previous = el.pweight;
                el.pweight = nextOnePortion * (uniqueDepth + 1) / ((i + 1) * oneFraction);
                nextOnePortion = previous - el.pweight * zeroFraction * (uniqueDepth - i) / (uniqueDepth + 1);
            } else {
                el.pweight = el.pweight * (uniqueDepth + 1) / (zeroFraction * (uniqueDepth - i));
            }
        }
        for (int i = pathIndex; i < uniqueDepth; i++) {
            PathElement dst = uniquePath.get(i);
            PathElement src = uniquePath.get(i + 1);
            dst.feature_index = src.feature_index;
            dst.zero_fraction = src.zero_fraction;
            dst.one_fraction = src.one_fraction;
        }
    }

    /**
     * Returns the total permutation weight the path would have if entry {@code pathIndex} were
     * unwound, without modifying the path.
     *
     * @throws IllegalStateException if both fractions of the entry are zero but a weight is not
     */
    private float unwoundPathSum(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        float oneFraction = uniquePath.get(pathIndex).one_fraction;
        float zeroFraction = uniquePath.get(pathIndex).zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        float total = 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; i--) {
            float pweight = uniquePath.get(i).pweight;
            // Float division is required here: with int operands (uniqueDepth - i) / (uniqueDepth + 1)
            // is always 0, which zeroes the correction term and divides by zero below.
            float depthRatio = (uniqueDepth - i) / (float) (uniqueDepth + 1);
            if (oneFraction != 0.0f) {
                float tmp = nextOnePortion * (uniqueDepth + 1) / ((i + 1) * oneFraction);
                total += tmp;
                nextOnePortion = pweight - tmp * zeroFraction * depthRatio;
            } else if (zeroFraction != 0.0f) {
                total += (pweight / zeroFraction) / depthRatio;
            } else if (pweight != 0.0f) {
                throw new IllegalStateException("Unique path " + i + " must have zero weight");
            }
        }
        return total;
    }

    /**
     * Recursively walks the subtree at {@code nodeIndex}, adding contributions into {@code phi}
     * at every reachable leaf.
     *
     * @param feat                feature vector; NaN values follow the node's default child
     * @param phi                 accumulator indexed by feature index
     * @param nodeIndex           node to visit
     * @param uniqueDepth         number of entries (excluding the placeholder) in the parent path
     * @param parentUniquePath    the parent's path
     * @param parentZeroFraction  zero fraction to record for the split leading here
     * @param parentOneFraction   one fraction to record for the split leading here
     * @param parentFeatureIndex  feature of the split leading here (-1 at the root)
     * @param condition           0 = unconditional; &gt;0 / &lt;0 conditions on conditionFeature
     * @param conditionFeature    feature conditioned on when condition != 0
     * @param conditionFraction   weight flowing into this subtree
     */
    private void treeShap(FVec feat, float[] phi, int nodeIndex, int uniqueDepth,
                          PathPointer parentUniquePath, float parentZeroFraction,
                          float parentOneFraction, int parentFeatureIndex, int condition,
                          int conditionFeature, float conditionFraction) {
        RegTreeImpl.Node node = this.nodes[nodeIndex];
        // Nothing flows into this subtree, so it cannot contribute.
        if (conditionFraction == 0.0f) {
            return;
        }

        PathPointer uniquePath = parentUniquePath.move(uniqueDepth);
        // The conditioned feature is kept out of the path.
        if (condition == 0 || conditionFeature != parentFeatureIndex) {
            extendPath(uniquePath, uniqueDepth, parentZeroFraction, parentOneFraction, parentFeatureIndex);
        }
        int splitIndex = node.split_index();

        if (node.is_leaf()) {
            for (int i = 1; i <= uniqueDepth; i++) {
                float weight = unwoundPathSum(uniquePath, uniqueDepth, i);
                PathElement el = uniquePath.get(i);
                phi[el.feature_index] +=
                        weight * (el.one_fraction - el.zero_fraction) * node.leaf_value * conditionFraction;
            }
            return;
        }

        // "Hot" is the branch the input actually follows, "cold" the other one.
        float fvalue = feat.fvalue(splitIndex);
        int hotIndex;
        if (Float.isNaN(fvalue)) {
            hotIndex = node.cdefault();
        } else if (fvalue < node.split_cond) {
            hotIndex = node.cleft_;
        } else {
            hotIndex = node.cright_;
        }
        int coldIndex = hotIndex == node.cleft_ ? node.cright_ : node.cleft_;
        float nodeHess = this.stats[nodeIndex].sum_hess;
        float hotZeroFraction = this.stats[hotIndex].sum_hess / nodeHess;
        float coldZeroFraction = this.stats[coldIndex].sum_hess / nodeHess;
        float incomingZeroFraction = 1.0f;
        float incomingOneFraction = 1.0f;

        // If this feature was already split on higher up, unwind that entry so it can be redone here.
        int pathIndex = 0;
        while (pathIndex <= uniqueDepth && uniquePath.get(pathIndex).feature_index != splitIndex) {
            pathIndex++;
        }
        if (pathIndex != uniqueDepth + 1) {
            incomingZeroFraction = uniquePath.get(pathIndex).zero_fraction;
            incomingOneFraction = uniquePath.get(pathIndex).one_fraction;
            unwindPath(uniquePath, uniqueDepth, pathIndex);
            uniqueDepth--;
        }

        // Split the incoming weight between the children when splitting on the conditioned feature.
        float hotConditionFraction = conditionFraction;
        float coldConditionFraction = conditionFraction;
        if (condition > 0 && splitIndex == conditionFeature) {
            coldConditionFraction = 0.0f;
            uniqueDepth--;
        } else if (condition < 0 && splitIndex == conditionFeature) {
            hotConditionFraction *= hotZeroFraction;
            coldConditionFraction *= coldZeroFraction;
            uniqueDepth--;
        }

        treeShap(feat, phi, hotIndex, uniqueDepth + 1, uniquePath,
                hotZeroFraction * incomingZeroFraction, incomingOneFraction,
                splitIndex, condition, conditionFeature, hotConditionFraction);
        treeShap(feat, phi, coldIndex, uniqueDepth + 1, uniquePath,
                coldZeroFraction * incomingZeroFraction, 0.0f,
                splitIndex, condition, conditionFeature, coldConditionFraction);
    }

    /**
     * Adds this tree's SHAP contributions for {@code feat} into {@code outContribs}.
     *
     * <p>NOTE(review): the expected value added to the last slot, and the size of the scratch
     * buffer, are both derived from node 0. A {@code rootId} other than 0 is presumably unsupported;
     * confirm against callers.
     *
     * @param feat             feature vector
     * @param rootId           node at which traversal starts
     * @param outContribs      accumulator indexed by feature; the last slot receives the tree's
     *                         expected value when {@code condition == 0}
     * @param condition        0 for plain contributions; &gt;0 sends all weight down the branch the
     *                         input follows at splits on {@code conditionFeature}; &lt;0 splits the
     *                         weight between both branches by their sum_hess share
     * @param conditionFeature feature conditioned on when {@code condition != 0}
     */
    public void calculateContributions(FVec feat, int rootId, float[] outContribs, int condition,
                                       int conditionFeature) {
        if (condition == 0) {
            outContribs[outContribs.length - 1] += this.expectedTreeValue;
        }
        treeShap(feat, outContribs, rootId, 0, getWorkspace(), 1.0f, 1.0f, -1,
                condition, conditionFeature, 1.0f);
    }

    /** Resets the first scratch entry and returns a view on the start of the scratch buffer. */
    private PathPointer getWorkspace() {
        this.unique_path_data[0].reset();
        return new PathPointer(this.unique_path_data);
    }

    /** Depth of the tree rooted at node 0, counting a lone leaf as depth 1. */
    private int treeDepth() {
        return nodeDepth(this.nodes, 0);
    }

    private static int nodeDepth(RegTreeImpl.Node[] nodeArr, int i) {
        RegTreeImpl.Node node = nodeArr[i];
        if (node.is_leaf()) {
            return 1;
        }
        return 1 + Math.max(nodeDepth(nodeArr, node.cleft_), nodeDepth(nodeArr, node.cright_));
    }

    /** Mean leaf value of the tree rooted at node 0. */
    private float treeMeanValue() {
        return nodeMeanValue(this.nodes, this.stats, 0);
    }

    /** Mean leaf value under node {@code i}, each child's mean weighted by its sum_hess. */
    private static float nodeMeanValue(RegTreeImpl.Node[] nodeArr, RegTreeImpl.RTreeNodeStat[] rTreeNodeStatArr, int i) {
        RegTreeImpl.Node node = nodeArr[i];
        if (node.is_leaf()) {
            return node.leaf_value;
        }
        float weightedLeft = rTreeNodeStatArr[node.cleft_].sum_hess * nodeMeanValue(nodeArr, rTreeNodeStatArr, node.cleft_);
        float weightedRight = rTreeNodeStatArr[node.cright_].sum_hess * nodeMeanValue(nodeArr, rTreeNodeStatArr, node.cright_);
        return (weightedLeft + weightedRight) / rTreeNodeStatArr[i].sum_hess;
    }
}
