package water.rapids.ast.prims.models;

import hex.AUC2;
import hex.Model;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.HypergeometricDistribution;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.stat.inference.GTest;
import water.DKV;
import water.Key;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValMapFrame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:water/rapids/ast/prims/models/AstFairnessMetrics.class */
public class AstFairnessMetrics extends AstPrimitive {

    /* loaded from: input_file:water/rapids/ast/prims/models/AstFairnessMetrics$FairnessMRTask.class */
    public static class FairnessMRTask extends MRTask {
        // Group-size threshold (10000); getPValue compares group totals against this same value
        // when choosing between Fisher's exact test and the G-test.
        public static final int GTEST_THRESHOLD = 10000;
        // Relative tolerance factor; fishersTest multiplies the observed table probability by this
        // same value before comparing other table probabilities against it.
        public static final double FISHER_TEST_REL_ERROR = 1.0000001d;
        // Column indices of the protected columns in the frame the task runs over.
        int[] protectedColsIdx;
        // Number of levels per protected column; the last level (cardinality - 1) encodes a missing value.
        int[] cardinalities;
        // Column index of the response.
        int responseIdx;
        // Column index of the predicted label; map() reads the two following columns as probabilities.
        int predictionIdx;
        // Offsets of the individual counters inside one group's slot of _results.
        final int tpIdx = 0;
        final int tnIdx = 1;
        final int fpIdx = 2;
        final int fnIdx = 3;
        // Offset of the accumulated log-loss sum inside one group's slot.
        final int llsIdx = 4;
        // Number of entries each group occupies in _results.
        final int essentialMetrics = 5;
        // Number of distinct group keys (product of all cardinalities), set in the constructor.
        final int maxIndex;
        // Index of the favourable class; map() compares it against 1 to decide whether labels are flipped.
        final int favourableClass;
        // Flattened per-group counters: essentialMetrics consecutive ints per group key.
        int[] _results;
        // One AUC builder per group key; null for groups with no rows seen so far.
        AUC2.AUCBuilder[] _aucs;
        // Compiler-generated assertion switch; assigned in the static initializer and checked in map().
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Creates a task that aggregates confusion-matrix counts, a log-loss sum and an AUC builder
         * for every combination of protected-column levels.
         *
         * @param iArr  column indices of the protected columns
         * @param iArr2 number of levels of each protected column; the last level of each column is
         *              used as the "missing value" level
         * @param i     column index of the response
         * @param i2    column index of the predicted label
         * @param i3    index of the favourable class
         * @throws IllegalArgumentException if {@code iArr2} is empty
         * @throws RuntimeException if the number of level combinations exceeds {@link Integer#MAX_VALUE}
         */
        public FairnessMRTask(int[] iArr, int[] iArr2, int i, int i2, int i3) {
            if (iArr2.length == 0) {
                // Previously surfaced as an opaque NoSuchElementException from OptionalDouble.getAsDouble().
                throw new IllegalArgumentException("At least one protected column is required!");
            }
            this.protectedColsIdx = iArr;
            this.cardinalities = iArr2;
            this.responseIdx = i;
            this.predictionIdx = i2;
            this.favourableClass = i3;
            // Multiply in double precision so an overflowing product is detected instead of wrapping.
            double combinations = 1.0d;
            for (int cardinality : iArr2) {
                combinations *= cardinality;
            }
            if (combinations > Integer.MAX_VALUE) {
                throw new RuntimeException("Too many combinations of categories! Maximum number of category combinations is 2147483647!");
            }
            this.maxIndex = (int) combinations;
        }

        /**
         * Computes the group key of a single row from the values of its protected columns.
         *
         * @param chunkArr chunks of the frame being processed
         * @param i        row index within the chunks
         * @return the flattened group key, as produced by {@link #pColsToKey(int[])}
         */
        private int pColsToKey(Chunk[] chunkArr, int i) {
            int[] levels = new int[this.protectedColsIdx.length];
            for (int col = 0; col < this.protectedColsIdx.length; col++) {
                Chunk chunk = chunkArr[this.protectedColsIdx[col]];
                if (chunk.isNA(i)) {
                    // Missing values map to the extra last level of the column.
                    levels[col] = this.cardinalities[col] - 1;
                } else {
                    // The slot is still zero here, so accumulating into it is a plain assignment.
                    levels[col] = (int) chunk.at8(i);
                }
            }
            return pColsToKey(levels);
        }

        /**
         * Flattens one level index per protected column into a single group key, using a
         * mixed-radix encoding in which column {@code k} has radix {@code cardinalities[k]}.
         *
         * @param iArr level index for each protected column
         * @return the flattened group key
         */
        public int pColsToKey(int[] iArr) {
            int key = 0;
            int stride = 1;
            int col = 0;
            while (col < this.protectedColsIdx.length) {
                key += iArr[col] * stride;
                stride *= this.cardinalities[col];
                col++;
            }
            return key;
        }

        /**
         * Decodes a flattened group key back into one level per protected column.
         *
         * @param i flattened group key
         * @return level index per column, with {@code NaN} where the decoded level is the last
         *         (missing-value) level of that column
         */
        private double[] keyToPCols(int i) {
            final int columnCount = this.cardinalities.length;
            double[] levels = new double[columnCount];
            int remaining = i;
            for (int col = 0; col < columnCount; col++) {
                final int cardinality = this.cardinalities[col];
                final int level = remaining % cardinality;
                remaining /= cardinality;
                levels[col] = (level == cardinality - 1) ? Double.NaN : level;
            }
            return levels;
        }

        /**
         * Builds a label for a group key: the comma-separated level names of the protected
         * columns, with every character other than ASCII letters, digits and commas replaced by
         * an underscore.
         *
         * @param i     flattened group key
         * @param frame frame whose column domains provide the level names
         * @return the sanitized group label; a missing-value level is rendered as {@code NaN}
         */
        protected String keyToString(int i, Frame frame) {
            final double[] levels = keyToPCols(i);
            final String[] labels = new String[this.protectedColsIdx.length];
            for (int col = 0; col < labels.length; col++) {
                if (!Double.isFinite(levels[col])) {
                    labels[col] = "NaN";
                } else {
                    String[] domain = frame.vec(this.protectedColsIdx[col]).domain();
                    labels[col] = domain[(int) levels[col]];
                }
            }
            return String.join(",", labels).replaceAll("[^A-Za-z0-9,]", "_");
        }

        /**
         * Accumulates, for every row of the chunk, the confusion-matrix counter, the log-loss
         * contribution and the AUC statistics of the group the row belongs to.
         *
         * <p>When the favourable class is not 1, both the actual and the predicted label are
         * flipped ({@code 1 - label}) and the probability is read from the first column after the
         * predicted label instead of the second.
         *
         * @param chunkArr chunks of the frame being processed
         */
        @Override // water.MRTask
        public void map(Chunk[] chunkArr) {
            if (!$assertionsDisabled && this._results != null) {
                throw new AssertionError();
            }
            this._results = new int[this.maxIndex * this.essentialMetrics];
            this._aucs = new AUC2.AUCBuilder[this.maxIndex];
            final boolean favourableIsOne = this.favourableClass == 1;
            for (int row = 0; row < chunkArr[0]._len; row++) {
                final int key = pColsToKey(chunkArr, row);
                final long actual = favourableIsOne ? chunkArr[this.responseIdx].at8(row) : 1 - chunkArr[this.responseIdx].at8(row);
                final long predicted = favourableIsOne ? chunkArr[this.predictionIdx].at8(row) : 1 - chunkArr[this.predictionIdx].at8(row);
                final double favourableProb = favourableIsOne ? chunkArr[this.predictionIdx + 2].atd(row) : chunkArr[this.predictionIdx + 1].atd(row);
                final int base = this.essentialMetrics * key;

                // Pick the confusion-matrix cell for this row.
                final int cell;
                if (actual == predicted) {
                    cell = (actual == 1) ? this.tpIdx : this.tnIdx;
                } else {
                    cell = (predicted == 1) ? this.fpIdx : this.fnIdx;
                }
                this._results[base + cell]++;

                // NOTE(review): _results is an int[], so the running log-loss sum is truncated to an
                // int after every row. Storing it as a double would need a field change outside this method.
                final double rowLogLoss = -((actual * Math.log(favourableProb)) + ((1 - actual) * Math.log(1.0d - favourableProb)));
                this._results[base + this.llsIdx] = (int) (this._results[base + this.llsIdx] + rowLogLoss);

                if (this._aucs[key] == null) {
                    this._aucs[key] = new AUC2.AUCBuilder(400);
                }
                this._aucs[key].perRow(favourableProb, (int) actual, 1.0d);
            }
        }

        /**
         * Merges the partial results of another task into this one: counters are summed
         * element-wise and AUC builders are adopted or reduced per group.
         *
         * @param mRTask the other {@code FairnessMRTask} whose results are merged in
         */
        @Override // water.MRTask
        public void reduce(MRTask mRTask) {
            final FairnessMRTask other = (FairnessMRTask) mRTask;
            // Both tasks share the same result array: nothing to merge.
            if (other._results == this._results) {
                return;
            }
            for (int slot = 0; slot < this._results.length; slot++) {
                this._results[slot] += other._results[slot];
            }
            for (int key = 0; key < this.maxIndex; key++) {
                final AUC2.AUCBuilder theirs = other._aucs[key];
                if (this._aucs[key] == null) {
                    this._aucs[key] = theirs;
                } else if (theirs != null) {
                    this._aucs[key].reduce(theirs);
                }
            }
        }

        /**
         * Builds the overview frame with one row per non-empty group.
         *
         * <p>Columns are, in order: the protected columns, one column per field of
         * {@link FairnessMetrics}, one {@code AIR_<field>} column (group value divided by the
         * reference group's value) for every field except {@code total} and {@code relativeSize},
         * and finally {@code p.value}.
         *
         * @param strArr  names of the protected columns
         * @param frame   frame providing row count and column domains
         * @param model   model whose key is embedded in the resulting frame key
         * @param strArr2 reference level for each protected column, or {@code null} to use the
         *                group with the most rows (the first such group on ties)
         * @param str     name embedded in the resulting frame key
         * @return the metrics frame
         */
        public Frame getMetrics(String[] strArr, Frame frame, Model model, String[] strArr2, String str) {
            final FairnessMetrics[] groups = new FairnessMetrics[this.maxIndex];
            final long numRows = frame.numRows();
            for (int key = 0; key < this.maxIndex; key++) {
                final int base = key * this.essentialMetrics;
                groups[key] = new FairnessMetrics(
                        this._results[base + this.tpIdx],
                        this._results[base + this.tnIdx],
                        this._results[base + this.fpIdx],
                        this._results[base + this.fnIdx],
                        this._results[base + this.llsIdx],
                        this._aucs[key],
                        numRows);
            }

            // Resolve the reference group.
            int referenceKey = 0;
            if (strArr2 != null) {
                int[] referenceLevels = new int[strArr.length];
                for (int col = 0; col < strArr.length; col++) {
                    referenceLevels[col] = ArrayUtils.find(frame.vec(strArr[col]).domain(), strArr2[col]);
                }
                referenceKey = pColsToKey(referenceLevels);
            } else {
                double largest = 0.0d;
                for (int key = 0; key < this.maxIndex; key++) {
                    if (groups[key].total > largest) {
                        largest = groups[key].total;
                        referenceKey = key;
                    }
                }
            }

            // Groups without any rows are left out of the output.
            int emptyGroups = 0;
            for (FairnessMetrics group : groups) {
                if (group.total == 0.0d) {
                    emptyGroups++;
                }
            }

            final String[] fieldsWithoutRatio = {"total", "relativeSize"};
            // NOTE(review): column order follows getDeclaredFields(), whose ordering the Java API
            // does not guarantee, and every declared field is read with getDouble().
            final Field[] metricFields = FairnessMetrics.class.getDeclaredFields();
            final int protectedCount = strArr.length;
            final int metricColumnCount = metricFields.length + (metricFields.length - fieldsWithoutRatio.length) + 1;
            final double[][] columns = new double[protectedCount + metricColumnCount][groups.length - emptyGroups];
            final FairnessMetrics reference = groups[referenceKey];

            int outRow = 0;
            for (int key = 0; key < this.maxIndex; key++) {
                final FairnessMetrics group = groups[key];
                if (group.total != 0.0d) {
                    final double[] levels = keyToPCols(key);
                    for (int col = 0; col < strArr.length; col++) {
                        columns[col][outRow] = levels[col];
                    }
                    int skipped = 0;
                    for (int f = 0; f < metricFields.length; f++) {
                        try {
                            final double value = metricFields[f].getDouble(group);
                            columns[protectedCount + f][outRow] = value;
                            if (ArrayUtils.contains(fieldsWithoutRatio, metricFields[f].getName())) {
                                skipped++;
                            } else {
                                columns[protectedCount + metricFields.length + f - skipped][outRow] = value / metricFields[f].getDouble(reference);
                            }
                        } catch (IllegalAccessException e) {
                            throw new RuntimeException(e);
                        }
                    }
                    try {
                        columns[columns.length - 1][outRow] = getPValue(reference, group);
                    } catch (Exception ignored) {
                        // Best effort: a failing significance test yields a missing p-value.
                        columns[columns.length - 1][outRow] = Double.NaN;
                    }
                    outRow++;
                }
            }

            // Column names mirror the layout of the value columns above.
            final String[] names = new String[protectedCount + metricColumnCount];
            System.arraycopy(strArr, 0, names, 0, strArr.length);
            int skippedNames = 0;
            for (int f = 0; f < metricFields.length; f++) {
                final String fieldName = metricFields[f].getName();
                names[protectedCount + f] = fieldName;
                if (ArrayUtils.contains(fieldsWithoutRatio, fieldName)) {
                    skippedNames++;
                } else {
                    names[protectedCount + metricFields.length + f - skippedNames] = "AIR_" + fieldName;
                }
            }
            names[names.length - 1] = "p.value";

            final Vec[] vecs = new Vec[protectedCount + metricColumnCount];
            for (int col = 0; col < protectedCount; col++) {
                vecs[col] = Vec.makeVec(columns[col], frame.domains()[this.protectedColsIdx[col]], Vec.newKey());
            }
            for (int col = 0; col < metricColumnCount; col++) {
                vecs[protectedCount + col] = Vec.makeVec(columns[protectedCount + col], Vec.newKey());
            }
            return new Frame(Key.make("fairness_metrics_" + str + "_for_model_" + model._key), names, vecs);
        }

        /**
         * Builds one "thresholds and metrics" frame per group that has an AUC builder, stores each
         * frame in the DKV and returns them keyed by {@code thresholds_and_metrics_<group label>}.
         *
         * @param model model whose key is embedded in the frame keys
         * @param frame frame whose column domains provide the group labels
         * @param str   frame name embedded in the frame keys
         * @return map from result name to the frame stored in the DKV
         */
        public Map<String, Frame> getROCInfo(Model model, Frame frame, String str) {
            final Map<String, Frame> rocFrames = new HashMap<>();
            for (int key = 0; key < this.maxIndex; key++) {
                if (this._aucs[key] == null) {
                    continue;
                }
                final AUC2 auc = new AUC2(this._aucs[key]);
                final AUC2.ThresholdCriterion[] criteria = AUC2.ThresholdCriterion.VALUES;

                // Header layout: "Threshold", one column per criterion, then "idx".
                final int columnCount = criteria.length + 2;
                final String[] colHeaders = new String[columnCount];
                final String[] colTypes = new String[columnCount];
                final String[] colFormats = new String[columnCount];
                colHeaders[0] = "Threshold";
                colTypes[0] = "double";
                colFormats[0] = "%f";
                for (int c = 0; c < criteria.length; c++) {
                    colHeaders[c + 1] = criteria[c].toString();
                    colTypes[c + 1] = criteria[c]._isInt ? "long" : "double";
                    colFormats[c + 1] = criteria[c]._isInt ? "%d" : "%f";
                }
                colHeaders[columnCount - 1] = "idx";
                colTypes[columnCount - 1] = "int";
                colFormats[columnCount - 1] = "%d";

                final TwoDimTable table = new TwoDimTable("Metrics for Thresholds", "Binomial metrics as a function of classification thresholds", new String[auc._nBins], colHeaders, colTypes, colFormats, null);
                for (int bin = 0; bin < auc._nBins; bin++) {
                    // The threshold is stored directly; formatting it to a String and parsing it
                    // back yields the same double.
                    table.set(bin, 0, Double.valueOf(auc._ths[bin]));
                    for (int c = 0; c < criteria.length; c++) {
                        final double value = criteria[c].exec(auc, bin);
                        table.set(bin, 1 + c, criteria[c]._isInt ? Long.valueOf((long) value) : Double.valueOf(value));
                    }
                    table.set(bin, 1 + criteria.length, Integer.valueOf(bin));
                }

                final String groupLabel = keyToString(key, frame);
                final Frame asFrame = table.asFrame(Key.make("thresholds_and_metrics_" + groupLabel + "_for_model_" + model._key + "_for_frame_" + str));
                DKV.put(asFrame);
                rocFrames.put("thresholds_and_metrics_" + groupLabel, asFrame);
            }
            return rocFrames;
        }

        /**
         * Computes a p-value for a 2x2 table from the hypergeometric distribution: the sum of the
         * probabilities of all attainable values of the first cell whose probability does not
         * exceed the observed one (within the {@link #FISHER_TEST_REL_ERROR} tolerance).
         *
         * @param j  first cell of the table (the observed value)
         * @param j2 second cell; {@code j + j2} is the distribution's number of successes
         * @param j3 third cell; {@code j + j3} is the distribution's sample size
         * @param j4 fourth cell
         * @return the p-value, or {@code NaN} when the table total exceeds {@link Integer#MAX_VALUE}
         */
        private static double fishersTest(long j, long j2, long j3, long j4) {
            final long population = j + j2 + j3 + j4;
            if (population > Integer.MAX_VALUE) {
                return Double.NaN;
            }
            final HypergeometricDistribution distribution = new HypergeometricDistribution((int) population, (int) (j + j2), (int) (j + j3));
            final double cutoff = distribution.probability((int) j) * FISHER_TEST_REL_ERROR;
            final long upper = Math.min(j + j2, j + j3);
            double pValue = 0.0d;
            int k = (int) Math.max(j - j4, 0L);
            while (k <= upper) {
                final double tableProbability = distribution.probability(k);
                if (tableProbability <= cutoff) {
                    pValue += tableProbability;
                }
                k++;
            }
            return pValue;
        }

        /**
         * Compares the selected / not-selected counts of a group against those of the reference
         * group. Fisher's test is used when both groups have fewer than {@link #GTEST_THRESHOLD}
         * rows or when any of the four counts is zero; otherwise the G-test is used.
         *
         * @param fairnessMetrics  metrics of the reference group
         * @param fairnessMetrics2 metrics of the group being compared
         * @return the p-value of the comparison
         */
        private static double getPValue(FairnessMetrics fairnessMetrics, FairnessMetrics fairnessMetrics2) {
            final long groupSelected = (long) fairnessMetrics2.selected;
            final long referenceSelected = (long) fairnessMetrics.selected;
            final long groupRejected = (long) (fairnessMetrics2.total - fairnessMetrics2.selected);
            final long referenceRejected = (long) (fairnessMetrics.total - fairnessMetrics.selected);

            final boolean bothSmall = fairnessMetrics.total < GTEST_THRESHOLD && fairnessMetrics2.total < GTEST_THRESHOLD;
            final boolean hasEmptyCell = groupSelected == 0 || referenceSelected == 0 || groupRejected == 0 || referenceRejected == 0;
            if (bothSmall || hasEmptyCell) {
                return fishersTest(groupSelected, referenceSelected, groupRejected, referenceRejected);
            }
            return new GTest().gTestDataSetsComparison(new long[]{groupSelected, groupRejected}, new long[]{referenceSelected, referenceRejected});
        }

        // Initializes the assertion switch checked in map(): true when assertions are NOT
        // desired for AstFairnessMetrics, so the assertion check is skipped.
        static {
            $assertionsDisabled = !AstFairnessMetrics.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:water/rapids/ast/prims/models/AstFairnessMetrics$FairnessMetrics.class */
    /**
     * Per-group binomial metrics derived from confusion-matrix counts, a log-loss sum and an
     * optional AUC builder.
     *
     * <p>FairnessMRTask.getMetrics enumerates the declared fields of this class reflectively
     * and reads each one with {@code getDouble}, so the declarations below are kept in their
     * original order and are all of type {@code double}.
     */
    public static class FairnessMetrics {
        double tp;
        double fp;
        double tn;
        double fn;
        double total;
        double relativeSize;
        double accuracy;
        double precision;
        double f1;
        double tpr;
        double tnr;
        double fpr;
        double fnr;
        double auc;
        double aucpr;
        double gini;
        double selected;
        double selectedRatio;
        double logloss;

        /**
         * Derives all metrics of one group.
         *
         * @param d          true positives
         * @param d2         true negatives
         * @param d3         false positives
         * @param d4         false negatives
         * @param d5         sum of per-row log-loss values
         * @param aUCBuilder AUC builder of the group, or {@code null} if none was created
         * @param d6         divisor for {@code relativeSize}
         */
        public FairnessMetrics(double d, double d2, double d3, double d4, double d5, AUC2.AUCBuilder aUCBuilder, double d6) {
            // Raw counts.
            this.tp = d;
            this.tn = d2;
            this.fp = d3;
            this.fn = d4;

            // Rows predicted as the favourable class, and the group size.
            final double predictedPositive = d + d3;
            this.selected = predictedPositive;
            this.total = predictedPositive + d2 + d4;
            this.selectedRatio = predictedPositive / this.total;
            this.relativeSize = this.total / d6;
            this.logloss = d5 / this.total;

            // Rates; an empty denominator yields NaN.
            this.accuracy = (d + d2) / this.total;
            this.precision = d / (d3 + d);
            this.f1 = (2.0d * d) / (((2.0d * d) + d3) + d4);
            this.tpr = d / (d + d4);
            this.tnr = d2 / (d2 + d3);
            this.fpr = d3 / (d3 + d2);
            this.fnr = d4 / (d4 + d);

            if (aUCBuilder == null) {
                this.auc = Double.NaN;
                this.aucpr = Double.NaN;
                this.gini = Double.NaN;
            } else {
                final AUC2 computed = new AUC2(aUCBuilder);
                this.auc = computed._auc;
                this.aucpr = computed._pr_auc;
                this.gini = computed._gini;
            }
        }
    }

    /**
     * Names of the arguments this primitive accepts, in the order {@code apply} reads them.
     *
     * @return the argument names
     */
    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        final String[] argumentNames = {"model", "test_frame", "protected_columns", "reference", "favourable_class"};
        return argumentNames;
    }

    /**
     * Number of AST slots of this primitive: {@code apply} reads its five arguments from
     * indices 1 to 5, so six slots are reported.
     *
     * @return 6
     */
    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        final int slotCount = 6;
        return slotCount;
    }

    /**
     * String form of this primitive.
     *
     * @return {@code "fairnessMetrics"}
     */
    @Override // water.rapids.ast.AstRoot
    public String str() {
        final String primitiveName = "fairnessMetrics";
        return primitiveName;
    }

    /**
     * Scores the frame with the model and computes fairness metrics for every combination of
     * levels of the protected columns.
     *
     * @param env         execution environment
     * @param stackHelp   helper tracking the evaluated arguments
     * @param astRootArr  AST nodes; indices 1 to 5 hold model, frame, protected columns,
     *                    reference levels and favourable class
     * @return map of result frames: one {@code thresholds_and_metrics_*} frame per group plus
     *         the {@code overview} metrics frame
     * @throws H2OIllegalArgumentException if the model is not binomial, the response column is
     *         missing from the frame, no protected column is given, or a protected column is
     *         not categorical
     * @throws RuntimeException if a protected column, a reference level or the favourable class
     *         cannot be found, or there are more than 1e6 level combinations
     */
    @Override // water.rapids.ast.AstPrimitive
    public ValMapFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Model model = stackHelp.track(astRootArr[1].exec(env)).getModel();
        Frame track = stackHelp.track(astRootArr[2].exec(env).getFrame());
        String[] strs = stackHelp.track(astRootArr[3].exec(env)).getStrs();
        String[] strs2 = stackHelp.track(astRootArr[4].exec(env)).getStrs();
        String str = stackHelp.track(astRootArr[5].exec(env)).getStr();
        String str2 = astRootArr[2].str();
        int find = track.find(model._parms._response_column);
        if (!model._output.isBinomialClassifier()) {
            throw new H2OIllegalArgumentException("Model has to be a binomial model!");
        }
        if (find == -1) {
            // Without this check the -1 index would be passed to track.vec(...) further below.
            throw new H2OIllegalArgumentException("Response column " + model._parms._response_column + " was not found in the frame!");
        }
        for (String str3 : strs) {
            if (track.find(str3) == -1) {
                throw new RuntimeException(str3 + " was not found in the frame!");
            }
            if (!track.vec(str3).isCategorical()) {
                throw new H2OIllegalArgumentException(str3 + " has to be a categorical column!");
            }
        }
        if (strs2.length != strs.length) {
            // A reference of the wrong length is ignored; getMetrics then picks the largest group.
            strs2 = null;
        } else {
            for (int i = 0; i < strs.length; i++) {
                if (!ArrayUtils.contains(track.vec(strs[i]).domain(), strs2[i])) {
                    throw new RuntimeException("Reference group is not present in the protected column");
                }
            }
        }
        if (!ArrayUtils.contains(track.vec(find).domain(), str)) {
            throw new RuntimeException("Favourable class is not present in the response!");
        }
        int find2 = ArrayUtils.find(track.vec(find).domain(), str);
        int[] find3 = track.find(strs);
        // One extra level per column is reserved for missing values.
        int[] array = IntStream.of(find3).map(i2 -> {
            return track.vec(i2).cardinality() + 1;
        }).toArray();
        if (array.length == 0) {
            // Previously reported, misleadingly, as "Too many combinations of categories!".
            throw new H2OIllegalArgumentException("At least one protected column has to be specified!");
        }
        double combinations = 1.0d;
        for (int cardinality : array) {
            combinations *= cardinality;
        }
        if (combinations > 1000000.0d) {
            throw new RuntimeException("Too many combinations of categories! Maximum number of category combinations is 1e6.");
        }
        Frame add = new Frame(track).add(model.score(track));
        DKV.put(add);
        try {
            FairnessMRTask fairnessMRTask = (FairnessMRTask) new FairnessMRTask(find3, array, find, track.numCols(), find2).doAll(add);
            Frame metrics = fairnessMRTask.getMetrics(strs, track, model, strs2, str2);
            Map<String, Frame> rOCInfo = fairnessMRTask.getROCInfo(model, track, str2);
            DKV.put(metrics);
            rOCInfo.put("overview", metrics);
            return new ValMapFrame(rOCInfo);
        } finally {
            // The temporary scored frame is removed from the DKV on success and on failure.
            DKV.remove(add.getKey());
        }
    }
}
