/*
 * Decompiled with CFR 0.152.
 */
package hex.util;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.junit.Ignore;

/**
 * Brute-force reference implementation of SHAP contributions for a single decision tree.
 *
 * <p>The contribution of every feature is obtained directly from the Shapley formula: the
 * expected tree output is evaluated for every subset of the features reported by the tree's
 * nodes, and the weighted marginal gains are summed per feature. The cost is exponential in the
 * number of features (one tree evaluation per subset), so this is only usable for small trees.
 *
 * @param <R> row type understood by the nodes
 * @param <N> node type
 * @param <S> node statistics type; only its weight is read
 */
@Ignore
public class NaiveTreeSHAP<R, N extends INode<R>, S extends INodeStat> {
    /** Largest feature count whose subsets can still be enumerated with an {@code int} bit mask. */
    private static final int MAX_FEATURES = 30;

    private final int rootNodeId;
    private final N[] nodes;
    private final S[] stats;
    private final double baseMargin;

    /**
     * @param nodes      tree nodes, indexed by node id
     * @param stats      per-node statistics, indexed by the same node ids as {@code nodes}
     * @param rootNodeId index of the root node in {@code nodes}
     * @param baseMargin constant added to the bias term of the contributions
     */
    public NaiveTreeSHAP(N[] nodes, S[] stats, int rootNodeId, double baseMargin) {
        this.rootNodeId = rootNodeId;
        this.nodes = nodes;
        this.stats = stats;
        this.baseMargin = baseMargin;
    }

    /**
     * Adds the contributions of this tree for the given row to {@code contribsNaive}.
     *
     * <p>Values are accumulated (not overwritten): the last element receives the weighted mean
     * value of the tree plus the base margin, and the element at index {@code f} receives the
     * contribution of feature {@code f}.
     *
     * @param row           row to explain
     * @param contribsNaive accumulator; must be indexable by every feature index reported by the
     *                      nodes and have one extra trailing slot for the bias term
     * @return the expected value with all features known, i.e. the value of the leaf the row
     *         ends up in (the base margin is not included)
     * @throws IllegalArgumentException if the tree reports more than 30 distinct features
     */
    public double calculateContributions(R row, double[] contribsNaive) {
        Set<Integer> usedFeatures = this.usedFeatures();
        int M = usedFeatures.size();

        // The last slot holds the bias term.
        contribsNaive[contribsNaive.length - 1] += this.treeMeanValue() + this.baseMargin;

        // Pre-compute the expected value for every subset of the features.
        HashMap<Set<Integer>, Double> expVals = new HashMap<Set<Integer>, Double>();
        for (Set<Integer> subset : NaiveTreeSHAP.allSubsets(usedFeatures)) {
            expVals.put(subset, this.expValue(row, subset));
        }

        // Shapley weight |S|! * (M - |S| - 1)! / M! depends only on |S|, the size of the subset
        // WITHOUT the examined feature. Factorials are evaluated in double so that they do not
        // wrap around the way int arithmetic does from 13! onwards.
        double[] weights = new double[M];
        double factM = NaiveTreeSHAP.fact(M);
        for (int size = 0; size < M; ++size) {
            weights[size] = NaiveTreeSHAP.fact(size) * NaiveTreeSHAP.fact(M - size - 1) / factM;
        }

        for (Integer feature : usedFeatures) {
            for (Set<Integer> subset : expVals.keySet()) {
                if (!subset.contains(feature)) {
                    continue;
                }
                Set<Integer> noFeature = new HashSet<Integer>(subset);
                noFeature.remove(feature);
                // Marginal gain of adding the feature to the subset, weighted by the Shapley weight.
                double gain = expVals.get(subset) - expVals.get(noFeature);
                contribsNaive[feature] += weights[noFeature.size()] * gain;
            }
        }

        // With every feature known no branch is averaged, so this is the leaf value for the row.
        return this.expValue(row, usedFeatures);
    }

    /** Expected tree output for row {@code v} when only the features in {@code s} are known. */
    private double expValue(R v, Set<Integer> s) {
        return this.expValue(this.rootNodeId, v, s, 1.0);
    }

    /**
     * Factorial of {@code v}, computed in {@code double} to avoid integer overflow.
     * Returns 1 for {@code v <= 0}.
     */
    private static double fact(int v) {
        double f = 1.0;
        for (int i = 2; i <= v; ++i) {
            f *= i;
        }
        return f;
    }

    /**
     * Enumerates all {@code 2^|s|} subsets of {@code s}, the empty set and {@code s} included.
     *
     * @throws IllegalArgumentException if {@code s} has more elements than an {@code int} bit
     *                                  mask can enumerate
     */
    private static List<Set<Integer>> allSubsets(Set<Integer> s) {
        Integer[] ary = s.toArray(new Integer[0]);
        if (ary.length > MAX_FEATURES) {
            // 1 << 31 is negative and 1 << 32 wraps to 1; the loop below would silently
            // produce no (or too few) subsets instead of failing.
            throw new IllegalArgumentException("Cannot enumerate subsets of " + ary.length
                    + " features, at most " + MAX_FEATURES + " are supported");
        }
        LinkedList<Set<Integer>> result = new LinkedList<Set<Integer>>();
        int subsetCount = 1 << ary.length;
        for (int mask = 0; mask < subsetCount; ++mask) {
            HashSet<Integer> subset = new HashSet<Integer>(s.size());
            // Bit j of the mask decides whether ary[j] is a member of this subset.
            for (int j = 0; j < ary.length; ++j) {
                if ((mask & (1 << j)) != 0) {
                    subset.add(ary[j]);
                }
            }
            result.add(subset);
        }
        return result;
    }

    /**
     * Collects the split index reported by every node in the node array.
     *
     * <p>NOTE(review): leaves are not filtered out, so whatever {@code getSplitIndex()} returns
     * for a leaf is treated as a feature too. Such an index never affects the expected values
     * (leaves are handled before the split index is consulted), but it is used to index the
     * contributions array and enlarges the number of enumerated subsets - confirm this is
     * intended.
     */
    private Set<Integer> usedFeatures() {
        HashSet<Integer> features = new HashSet<Integer>();
        for (N n : this.nodes) {
            features.add(n.getSplitIndex());
        }
        return features;
    }

    /**
     * Expected output of the subtree rooted at {@code node}.
     *
     * @param node index of the subtree root
     * @param v    row being explained
     * @param s    features whose values are considered known
     * @param w    weight accumulated on the path from the tree root to {@code node}
     * @return weighted leaf value(s) of the subtree
     */
    private double expValue(int node, R v, Set<Integer> s, double w) {
        N n = this.nodes[node];
        if (n.isLeaf()) {
            return w * (double)n.getLeafValue();
        }
        if (s.contains(n.getSplitIndex())) {
            // Feature is known: follow the branch the row actually takes.
            return this.expValue(n.next(v), v, s, w);
        }
        // Feature is unknown: average both branches, each weighted by its share of the
        // parent's weight.
        double wP = this.stats[node].getWeight();
        double wL = this.stats[n.getLeftChildIndex()].getWeight();
        double wR = this.stats[n.getRightChildIndex()].getWeight();
        return this.expValue(n.getLeftChildIndex(), v, s, w * wL / wP) + this.expValue(n.getRightChildIndex(), v, s, w * wR / wP);
    }

    /** Weighted mean of the leaf values of the whole tree. */
    private double treeMeanValue() {
        return this.nodeMeanValue(this.rootNodeId);
    }

    /**
     * Weighted mean of the leaf values below {@code node}: the children's means are combined
     * using the children's weights and divided by the weight of {@code node} itself.
     */
    private double nodeMeanValue(int node) {
        N n = this.nodes[node];
        if (n.isLeaf()) {
            return n.getLeafValue();
        }
        return ((double)this.stats[n.getLeftChildIndex()].getWeight() * this.nodeMeanValue(n.getLeftChildIndex()) + (double)this.stats[n.getRightChildIndex()].getWeight() * this.nodeMeanValue(n.getRightChildIndex())) / (double)this.stats[node].getWeight();
    }
}

