package water.test.util;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Ignore;

/**
 * Brute-force reference implementation of TreeSHAP for a single tree, intended for tests.
 *
 * <p>Contributions are computed directly from the Shapley-value definition by enumerating every
 * subset of the features the tree splits on. The cost therefore grows as {@code O(M * 2^M)} in
 * the number {@code M} of such features. Use it only to validate a faster implementation on
 * small trees.
 *
 * @param <R> row type passed to the nodes' {@code next} routing method
 * @param <N> tree node type
 * @param <S> per-node statistics type; {@code getWeight()} is used as the node's cover
 */
@Ignore
public class NaiveTreeSHAP<R, N extends INode<R>, S extends INodeStat> {
    private final int rootNodeId;
    private final N[] nodes;
    private final S[] stats;

    /**
     * @param nArr tree nodes, indexed by node id (child indices returned by a node index this array)
     * @param sArr per-node statistics, indexed the same way as {@code nArr}
     * @param i id of the root node
     */
    public NaiveTreeSHAP(N[] nArr, S[] sArr, int i) {
        this.rootNodeId = i;
        this.nodes = nArr;
        this.stats = sArr;
    }

    /**
     * Adds this tree's SHAP contributions for {@code r} into {@code dArr}.
     *
     * <p>The contribution of each split feature {@code f} is added to {@code dArr[f]}. The tree's
     * weighted mean value (the bias term) is added to the last element. Values are accumulated
     * with {@code +=}, never overwritten.
     *
     * @param r the row to explain
     * @param dArr accumulator; needs a slot for every split feature index plus a trailing bias slot
     * @return the tree's expected value with every used feature known, i.e. the value of the leaf
     *     that {@code r} is routed to
     */
    public double calculateContributions(R r, double[] dArr) {
        final Set<Integer> usedFeatures = usedFeatures();
        final int m = usedFeatures.size();
        // The last slot holds the bias term: the tree's weight-averaged mean value.
        dArr[dArr.length - 1] += treeMeanValue();

        // Pre-compute the conditional expectation for every subset of "known" features.
        final Map<Set<Integer>, Double> expValues = new HashMap<>();
        for (Set<Integer> subset : allSubsets(usedFeatures)) {
            expValues.put(subset, expValue(r, subset));
        }

        // Shapley formula: phi_f = sum over S not containing f of
        //   |S|! * (M - |S| - 1)! / M! * (E[S u {f}] - E[S]).
        // Each visited key is S u {f} ("withFeature"); S is derived from it by removing f.
        for (Integer feature : usedFeatures) {
            for (Map.Entry<Set<Integer>, Double> entry : expValues.entrySet()) {
                final Set<Integer> withFeature = entry.getKey();
                if (!withFeature.contains(feature)) {
                    continue;
                }
                final Set<Integer> withoutFeature = new HashSet<>(withFeature);
                withoutFeature.remove(feature);
                // Must be computed in floating point: in int arithmetic the ratio truncates to 0
                // and the factorials overflow once M >= 13.
                final double weight =
                        fact(withoutFeature.size()) * fact(m - withFeature.size()) / fact(m);
                dArr[feature] += weight * (entry.getValue() - expValues.get(withoutFeature));
            }
        }
        // With every used feature known, the expectation collapses to the leaf r is routed to.
        return expValue(r, usedFeatures);
    }

    /** Expected tree output for {@code r} when only the features in {@code set} are known. */
    private double expValue(R r, Set<Integer> set) {
        return expValue(this.rootNodeId, r, set, 1.0d);
    }

    /** Returns {@code i!} as a double, so large factorials do not overflow {@code int}. */
    private static double fact(int i) {
        double result = 1.0d;
        for (int k = 2; k <= i; k++) {
            result *= k;
        }
        return result;
    }

    /**
     * Returns every subset of {@code set} (2^|set| of them), including the empty set and
     * {@code set} itself.
     *
     * @throws IllegalArgumentException if {@code set} is too large for the subset count to fit in
     *     an {@code int}
     */
    private static List<Set<Integer>> allSubsets(Set<Integer> set) {
        final Integer[] elements = set.toArray(new Integer[0]);
        // 1 << 31 is negative; refuse rather than silently enumerate nothing.
        if (elements.length > Integer.SIZE - 2) {
            throw new IllegalArgumentException(
                    "too many features for naive subset enumeration: " + elements.length);
        }
        final int count = 1 << elements.length;
        final List<Set<Integer>> subsets = new ArrayList<>(count);
        for (int mask = 0; mask < count; mask++) {
            final Set<Integer> subset = new HashSet<>();
            for (int bit = 0; bit < elements.length; bit++) {
                if ((mask & (1 << bit)) != 0) {
                    subset.add(elements[bit]);
                }
            }
            subsets.add(subset);
        }
        return subsets;
    }

    /**
     * Returns the split-feature indices of the internal nodes reachable from the root through
     * left/right child links. Leaves are skipped: their split index is not a feature. Features
     * that never influence the reachable tree would only be zero-contribution players.
     */
    private Set<Integer> usedFeatures() {
        final Set<Integer> features = new HashSet<>();
        collectSplitFeatures(this.rootNodeId, features);
        return features;
    }

    /** Adds the split features of the subtree rooted at {@code nodeId} to {@code features}. */
    private void collectSplitFeatures(int nodeId, Set<Integer> features) {
        final N node = this.nodes[nodeId];
        if (node.isLeaf()) {
            return;
        }
        features.add(node.getSplitIndex());
        collectSplitFeatures(node.getLeftChildIndex(), features);
        collectSplitFeatures(node.getRightChildIndex(), features);
    }

    /**
     * Recursive step of the conditional expectation.
     *
     * @param i current node id
     * @param r row being explained
     * @param set features treated as known
     * @param d fraction of weight that reaches this node
     * @return the subtree's expected value, already scaled by {@code d}
     */
    private double expValue(int i, R r, Set<Integer> set, double d) {
        final N node = this.nodes[i];
        if (node.isLeaf()) {
            return d * node.getLeafValue();
        }
        if (set.contains(node.getSplitIndex())) {
            // Known feature: follow the branch the row actually takes.
            return expValue(node.next(r), r, set, d);
        }
        // Unknown feature: average both branches by their share of this node's weight.
        // NOTE(review): a node whose weight is 0 makes this division non-finite.
        final double weight = this.stats[i].getWeight();
        final int left = node.getLeftChildIndex();
        final int right = node.getRightChildIndex();
        return expValue(left, r, set, (d * this.stats[left].getWeight()) / weight)
                + expValue(right, r, set, (d * this.stats[right].getWeight()) / weight);
    }

    /** Weight-averaged mean leaf value of the whole tree (the SHAP bias term). */
    private double treeMeanValue() {
        return nodeMeanValue(this.rootNodeId);
    }

    /** Weight-averaged mean leaf value of the subtree rooted at node {@code i}. */
    private double nodeMeanValue(int i) {
        final N node = this.nodes[i];
        if (node.isLeaf()) {
            return node.getLeafValue();
        }
        final int left = node.getLeftChildIndex();
        final int right = node.getRightChildIndex();
        return ((this.stats[left].getWeight() * nodeMeanValue(left))
                        + (this.stats[right].getWeight() * nodeMeanValue(right)))
                / this.stats[i].getWeight();
    }
}
