package hex.tree;

import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.util.ArrayUtils;
import water.util.VecUtils;

/**
 * Friedman and Popescu's H statistic for the strength of interaction between a set of variables
 * in a tree-ensemble model given as {@link SharedTreeSubgraph}s.
 *
 * <p>For every non-empty sub-combination of the variables, centered partial-dependence ("F")
 * values are computed at the unique value combinations present in the data. The statistic is
 * {@code sqrt(sum_r n_r * I_r^2 / sum_r n_r * F_all(r)^2)}. In this formula:
 * <ul>
 *   <li>{@code r} ranges over the unique rows of the full variable set;</li>
 *   <li>{@code n_r} is the count of row {@code r};</li>
 *   <li>{@code F_all} is the F value of the full variable set;</li>
 *   <li>{@code I_r} is the inclusion-exclusion sum of the F values of all sub-combinations
 *       (sign {@code +} for the full set, alternating as the size decreases).</li>
 * </ul>
 * {@code NaN} is returned when the numerator is not smaller than the denominator.
 *
 * <p>See J. H. Friedman and B. E. Popescu (2008), "Predictive learning via rule ensembles",
 * Ann. Appl. Stat. 2:916-954, section 8.1.
 */
public class FriedmanPopescusH {

    /** Absolute tolerance used when matching feature values to rows of an F-value frame. */
    private static final double F_VALUE_MATCH_EPS = 1.0E-5d;

    /**
     * Distributed lookup of the F value stored for a given tuple of feature values.
     *
     * <p>The searched frame holds the F value in column 0, as built by {@code computeFValues}. It
     * holds the feature columns under the names listed in {@code currFValuesNames}. The task keeps
     * the matching row with the lowest global index.
     *
     * <p>A feature value matches when any of these holds:
     * <ul>
     *   <li>both values are NaN;</li>
     *   <li>they are equal (this also covers infinities, whose difference is NaN);</li>
     *   <li>they differ by less than {@code eps}.</li>
     * </ul>
     */
    public static class FindFValue extends MRTask<FindFValue> {
        /** Feature values to look up, one per entry of {@code currNames}. */
        double[] valueToFindFValueFor;
        /** Names of the looked-up feature columns. */
        String[] currNames;
        /** Column names of the searched F-value frame. */
        String[] currFValuesNames;
        /** Absolute matching tolerance. */
        double eps;
        /** F value of the first matching row, or {@code null} when no row matched. */
        public Double result;
        /** Global row index of {@link #result}; {@code Long.MAX_VALUE} while nothing matched. */
        long resultIndex = Long.MAX_VALUE;

        FindFValue(double[] valueToFindFValueFor, String[] currNames, String[] currFValuesNames, double eps) {
            this.valueToFindFValueFor = valueToFindFValueFor;
            this.currNames = currNames;
            this.currFValuesNames = currFValuesNames;
            this.eps = eps;
        }

        @Override
        public void map(Chunk[] cs) {
            long chunkStart = cs[0].start();
            if (chunkStart > resultIndex) {
                return;
            }
            // Column positions do not depend on the row, so resolve them once per chunk.
            int[] cols = new int[valueToFindFValueFor.length];
            for (int k = 0; k < cols.length; k++) {
                cols[k] = ArrayUtils.find(currFValuesNames, currNames[k]);
            }
            for (int row = 0; row < cs[0].len(); row++) {
                if (matches(cs, cols, row)) {
                    // The first match in a chunk has the lowest index within that chunk.
                    if (chunkStart + row < resultIndex) {
                        result = cs[0].atd(row);
                        resultIndex = chunkStart + row;
                    }
                    return;
                }
            }
        }

        /** Returns whether every looked-up feature value matches the given row of the chunk. */
        private boolean matches(Chunk[] cs, int[] cols, int row) {
            for (int k = 0; k < cols.length; k++) {
                double expected = valueToFindFValueFor[k];
                double actual = cs[cols[k]].atd(row);
                boolean bothNaN = Double.isNaN(expected) && Double.isNaN(actual);
                if (!bothNaN && expected != actual && !(Math.abs(expected - actual) < eps)) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public void reduce(FindFValue other) {
            // Keep the match with the lowest global row index.
            if (other == null || other.result == null || resultIndex <= other.resultIndex) {
                return;
            }
            result = other.result;
            resultIndex = other.resultIndex;
        }
    }

    /**
     * Computes {@code sum_r x_r^power * w_r}, where the values {@code x} are column 0 and the
     * weights {@code w} are column 1 of the task's input.
     */
    public static class Transform extends MRTask<Transform> {
        double result;
        int power;

        Transform(int power) {
            this.power = power;
        }

        @Override
        public void map(Chunk[] cs) {
            result = 0.0d;
            int len = cs[0]._len;
            for (int row = 0; row < len; row++) {
                result += Math.pow(cs[0].atd(row), power) * cs[1].atd(row);
            }
        }

        @Override
        public void reduce(Transform other) {
            result += other.result;
        }
    }

    /**
     * Computes the H statistic for the given variables.
     *
     * @param frame               input data; the column index of each variable in this frame is
     *                            used as the tree column id. This assumes the frame has the
     *                            training layout (TODO confirm).
     * @param vars                names of the variables whose interaction is measured
     * @param learnRate           factor applied to every leaf prediction
     * @param sharedTreeSubgraphs tree subgraphs indexed as {@code [tree][subgraph]}
     * @return the H value, or {@code NaN} when the numerator is not below the denominator
     * @throws RuntimeException if a variable is not a column of {@code frame}
     */
    public static double h(Frame frame, String[] vars, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        // Validate the names first so a missing column is reported with a clear message.
        int[] modelIds = getModelIds(frame.names(), vars);
        Frame filteredFrame = filterFrame(frame, vars);
        Map<String, Frame> fValues = new HashMap<>();
        int numCols = filteredFrame.numCols();
        int[] colIds = new int[numCols];
        for (int i = 0; i < numCols; i++) {
            colIds[i] = i;
        }
        for (int size = numCols; size > 0; size--) {
            for (int[] combination : combinations(colIds, size)) {
                int[] combinationModelIds = getCurrentCombinationModelIds(combination, modelIds);
                String[] combinationCols = getCurrCombinationCols(combination, vars);
                fValues.put(Arrays.toString(combination),
                        computeFValues(combinationModelIds, filteredFrame, combinationCols, learnRate, sharedTreeSubgraphs));
            }
        }
        return computeHValue(fValues, filteredFrame, colIds);
    }

    /** Maps positions of a combination (indices into {@code modelIds}) to model column ids. */
    static int[] getCurrentCombinationModelIds(int[] combination, int[] modelIds) {
        int[] result = new int[combination.length];
        for (int i = 0; i < combination.length; i++) {
            result[i] = modelIds[combination[i]];
        }
        return result;
    }

    /**
     * Combines the centered F values of all sub-combinations into the H statistic.
     *
     * @param fValues       F-value frames keyed by {@code Arrays.toString(combination)}
     * @param filteredFrame frame restricted to the variables; a key is assigned if missing
     * @param inds          column indices of all variables in {@code filteredFrame}
     */
    static double computeHValue(Map<String, Frame> fValues, Frame filteredFrame, int[] inds) {
        if (filteredFrame._key == null) {
            filteredFrame._key = Key.make();
        }
        Frame uniqueWithCounts = uniqueRowsWithCounts(filteredFrame);
        long uniqHeight = uniqueWithCounts.numRows();
        // Sub-combinations do not depend on the row: enumerate them once, largest first.
        List<List<int[]>> combinationsBySize = new ArrayList<>();
        for (int size = inds.length; size > 0; size--) {
            combinationsBySize.add(combinations(inds, size));
        }
        Vec fullFValues = fValues.get(Arrays.toString(inds)).vec(0);
        Vec numerEls = Vec.makeZero(uniqHeight);
        Vec denomEls = Vec.makeZero(uniqHeight);
        try {
            try (Vec.Writer numerWriter = numerEls.open(); Vec.Writer denomWriter = denomEls.open()) {
                for (long row = 0; row < uniqHeight; row++) {
                    double interaction = 0.0d;
                    int sign = 1;
                    for (List<int[]> sameSize : combinationsBySize) {
                        for (int[] combination : sameSize) {
                            // Look up at this *unique* row, matching the weights and the denominator.
                            double fValue = findFValue(row, combination,
                                    fValues.get(Arrays.toString(combination)), uniqueWithCounts);
                            interaction += sign * (float) fValue;
                        }
                        sign = -sign;
                    }
                    numerWriter.set(row, interaction);
                    denomWriter.set(row, (float) fullFValues.at(row));
                }
            }
            Vec counts = uniqueWithCounts.vec("nrow");
            double numer = new Transform(2).doAll(numerEls, counts).result;
            double denom = new Transform(2).doAll(denomEls, counts).result;
            return numer < denom ? Math.sqrt(numer / denom) : Double.NaN;
        } finally {
            // Temporary vectors; they are only needed for the sums above.
            numerEls.remove();
            denomEls.remove();
        }
    }

    /** Reads the values of the given columns of {@code frame} at {@code row}. */
    static double[] getValueToFindFValueFor(int[] cols, Frame frame, long row) {
        double[] values = new double[cols.length];
        for (int i = 0; i < cols.length; i++) {
            values[i] = frame.vec(cols[i]).at(row);
        }
        return values;
    }

    /**
     * Returns the F value stored in {@code fValues} for the feature values found at {@code row}
     * of {@code features} in the columns {@code combination}.
     *
     * @throws RuntimeException if no row of {@code fValues} matches
     */
    static double findFValue(long row, int[] combination, Frame fValues, Frame features) {
        double[] values = getValueToFindFValueFor(combination, features, row);
        String[] names = getCurrCombinationNames(combination, features.names());
        Double fValue = new FindFValue(values, names, fValues._names, F_VALUE_MATCH_EPS).doAll(fValues).result;
        if (null == fValue) {
            throw new RuntimeException("FValue was not found!" + Arrays.toString(combination) + "value: " + Arrays.toString(values));
        }
        return fValue;
    }

    /** Selects {@code names[i]} for every index {@code i} in {@code combination}. */
    static String[] getCurrCombinationNames(int[] combination, String[] names) {
        String[] result = new String[combination.length];
        for (int i = 0; i < combination.length; i++) {
            result[i] = names[combination[i]];
        }
        return result;
    }

    /** Selects {@code cols[i]} for every index {@code i} in {@code combination}. */
    static String[] getCurrCombinationCols(int[] combination, String[] cols) {
        return getCurrCombinationNames(combination, cols);
    }

    /** Returns the index of the first numeric column of {@code frame}, or -1 if there is none. */
    static int findFirstNumericalColumn(Frame frame) {
        for (int i = 0; i < frame.names().length; i++) {
            if (frame.vec(i).isNumeric()) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Groups {@code frame} by all of its columns and returns the group columns plus a row-count
     * column named {@code "nrow"}.
     *
     * <p>The grouping is done with a Rapids {@code GB} expression, which needs a numeric column
     * to count over. If the frame has none, a {@code "nrow"} column of ones is added to
     * {@code frame} itself and {@code frame} is returned without grouping. {@code frame} must
     * have a key; it is put into DKV only for the duration of this call.
     */
    static Frame uniqueRowsWithCounts(Frame frame) {
        DKV.put(frame);
        try {
            int countCol = findFirstNumericalColumn(frame);
            if (countCol == -1) {
                frame.add("nrow", Vec.makeOne(frame.numRows()));
                return frame;
            }
            String[] names = frame.names();
            StringBuilder sb = new StringBuilder("(GB ");
            sb.append(frame._key.toString());
            sb.append(" [");
            for (int i = 0; i < names.length; i++) {
                if (i != 0) {
                    sb.append(",");
                }
                sb.append(i);
            }
            sb.append("] ");
            sb.append(" nrow ").append(countCol).append(" \"all\")");
            Val exec = Rapids.exec(sb.toString());
            return exec.getFrame();
        } finally {
            DKV.remove(frame._key);
        }
    }

    /**
     * Computes the centered partial-dependence (F) values of one variable combination at its
     * unique value tuples.
     *
     * @param modelIds      model column ids of the combination's variables
     * @param filteredFrame frame restricted to all variables of the H computation
     * @param cols          names of the combination's variables
     * @return a frame with the centered F values in column 0 (first subgraph of each tree),
     *         followed by the unique value tuples and their {@code "nrow"} counts
     */
    static Frame computeFValues(int[] modelIds, Frame filteredFrame, String[] cols, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        Frame combinationFrame = filterFrame(filteredFrame, cols);
        Frame uniqueWithCounts = uniqueRowsWithCounts(new Frame(Key.make(), combinationFrame.names(), combinationFrame.vecs()));
        Frame fValues = new Frame(partialDependence(modelIds, uniqueWithCounts, learnRate, sharedTreeSubgraphs).vec(0));
        Vec fValueVec = fValues.vec(0);
        // Count-weighted mean over all original rows: the counts sum to filteredFrame.numRows().
        double mean = new VecUtils.DotProduct().doAll(uniqueWithCounts.vec("nrow"), fValueVec).result / filteredFrame.numRows();
        try (Vec.Writer writer = fValueVec.open()) {
            Vec.Reader reader = fValueVec.new Reader();
            for (long row = 0; row < fValues.numRows(); row++) {
                writer.set(row, reader.at(row) - mean);
            }
        }
        return fValues.add(uniqueWithCounts);
    }

    /**
     * Sums the per-tree partial dependence over all trees, once per subgraph index.
     *
     * @return a frame with one column {@code "pdp_C<i>"} per subgraph index {@code i} of
     *         {@code sharedTreeSubgraphs[0]}
     */
    static Frame partialDependence(int[] modelIds, Frame frame, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        Frame result = new Frame(new Vec[0]);
        int subgraphsPerTree = sharedTreeSubgraphs[0].length;
        for (int c = 0; c < subgraphsPerTree; c++) {
            Vec sum = Vec.makeZero(frame.numRows());
            for (SharedTreeSubgraph[] treeSubgraphs : sharedTreeSubgraphs) {
                Vec treePd = partialDependenceTree(treeSubgraphs[c], modelIds, learnRate, frame);
                try (Vec.Writer writer = sum.open()) {
                    Vec.Reader treeReader = treePd.new Reader();
                    Vec.Reader sumReader = sum.new Reader();
                    for (long row = 0; row < frame.numRows(); row++) {
                        writer.set(row, sumReader.at(row) + treeReader.at(row));
                    }
                } finally {
                    // The per-tree vector is consumed by the sum above; do not leak it.
                    treePd.remove();
                }
            }
            result.add("pdp_C" + c, sum);
        }
        return result;
    }

    /** Element-wise sum of two arrays, truncated to the shorter length. */
    public static double[] add(double[] a, double[] b) {
        int len = Math.min(a.length, b.length);
        double[] result = new double[len];
        for (int i = 0; i < len; i++) {
            result[i] = a[i] + b[i];
        }
        return result;
    }

    /** Returns a new frame that holds only the named columns of {@code frame}, sharing its vecs. */
    static Frame filterFrame(Frame frame, String[] cols) {
        Frame result = new Frame(new Vec[0]);
        result.add(cols, frame.vecs(cols));
        return result;
    }

    /**
     * Returns, for each variable, its column index in {@code frameNames}. If a name occurs more
     * than once, the last occurrence wins.
     *
     * @throws RuntimeException if a variable is not present
     */
    static int[] getModelIds(String[] frameNames, String[] vars) {
        int[] modelIds = new int[vars.length];
        for (int i = 0; i < vars.length; i++) {
            int found = -1;
            for (int j = 0; j < frameNames.length; j++) {
                if (vars[i].equals(frameNames[j])) {
                    found = j;
                }
            }
            if (found == -1) {
                throw new RuntimeException("Column " + vars[i] + " is not present in the input frame!");
            }
            modelIds[i] = found;
        }
        return modelIds;
    }

    /** Returns all {@code k}-element combinations of {@code elements}, in lexicographic position order. */
    static List<int[]> combinations(int[] elements, int k) {
        List<int[]> result = new ArrayList<>();
        combinations(elements, k, 0, new int[k], result);
        return result;
    }

    /** Recursively fills the last {@code remaining} slots of {@code current} and collects copies. */
    private static void combinations(int[] elements, int remaining, int start, int[] current, List<int[]> out) {
        if (remaining == 0) {
            out.add(current.clone());
            return;
        }
        for (int i = start; i <= elements.length - remaining; i++) {
            current[current.length - remaining] = elements[i];
            combinations(elements, remaining - 1, i + 1, current, out);
        }
    }

    /**
     * Computes the partial dependence of one tree at every row of {@code frame}.
     *
     * <p>The traversal depends on whether a split's column is one of {@code modelIds}:
     * <ul>
     *   <li>If it is, the row's value decides the branch. Column {@code k} of {@code frame} is
     *       read for {@code modelIds[k]}.</li>
     *   <li>Otherwise, both branches are followed, weighted by the children's share of the node
     *       weight.</li>
     * </ul>
     * Leaf predictions are scaled by {@code learnRate}.
     *
     * @throws RuntimeException if the leaf weights of a row do not sum to 1 (within 0.001)
     */
    static Vec partialDependenceTree(SharedTreeSubgraph tree, int[] modelIds, double learnRate, Frame frame) {
        Vec result = Vec.makeZero(frame.numRows());
        int capacity = tree.nodesArray.size() * 2;
        SharedTreeNode[] nodeStack = new SharedTreeNode[capacity];
        double[] weightStack = new double[capacity];
        try (Vec.Writer writer = result.open()) {
            Vec.Reader[] readers = new Vec.Reader[frame.numCols()];
            for (int col = 0; col < readers.length; col++) {
                readers[col] = frame.vec(col).new Reader();
            }
            for (long row = 0; row < frame.numRows(); row++) {
                int top = 0;
                nodeStack[top] = tree.rootNode;
                weightStack[top] = 1.0d;
                top++;
                double totalWeight = 0.0d;
                double pd = 0.0d;
                while (top > 0) {
                    top--;
                    SharedTreeNode node = nodeStack[top];
                    double weight = weightStack[top];
                    if (node.isLeaf()) {
                        pd += weight * node.getPredValue() * learnRate;
                        totalWeight += weight;
                        continue;
                    }
                    int col = ArrayUtils.find(modelIds, node.getColId());
                    if (col >= 0) {
                        // Deterministic split: replace the node by one child, keeping its weight.
                        nodeStack[top] = readers[col].at(row) <= node.getSplitValue()
                                ? node.getLeftChild()
                                : node.getRightChild();
                        top++;
                    } else {
                        // Split on another variable: follow both children, weighted by their share.
                        double leftShare = node.getLeftChild().getWeight() / node.getWeight();
                        nodeStack[top] = node.getLeftChild();
                        weightStack[top] = weight * leftShare;
                        top++;
                        nodeStack[top] = node.getRightChild();
                        weightStack[top] = weight * (1.0d - leftShare);
                        top++;
                    }
                }
                writer.set(row, pd);
                if (0.999d >= totalWeight || totalWeight >= 1.001d) {
                    throw new RuntimeException("Total weight should be 1.0 but was " + totalWeight);
                }
            }
        }
        return result;
    }
}
