package hex;

import hex.AUC2;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import org.junit.Assert;
import org.junit.Test;
import water.util.Log;

/* loaded from: input_file:hex/AUCBuilderTest.class */
public class AUCBuilderTest {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/AUCBuilderTest$RandArrayGen.class */
    public static abstract class RandArrayGen {
        Random _r;

        private RandArrayGen(long j) {
            this._r = new Random(j);
        }

        abstract void fillRandomVals(double[] dArr);
    }

    @Test
    public void testPerRow() {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(10);
        for (int i = 0; i < 100; i++) {
            aUCBuilder.perRow(i / 100.0d, 1, 1.0d);
        }
        double[] dArr = new double[10];
        System.arraycopy(aUCBuilder._ths, 0, dArr, 0, dArr.length);
        Assert.assertArrayEquals(new double[]{0.05d, 0.16d, 0.25d, 0.335d, 0.445d, 0.555d, 0.655d, 0.76d, 0.875d, 0.965d}, dArr, 1.0E-5d);
    }

    @Test
    public void testPerRow_compat() throws Exception {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(400);
        AUC2.AUCBuilder aUCBuilder2 = new AUC2.AUCBuilder(400, false);
        long j = 0;
        long j2 = 0;
        GZIPInputStream gZIPInputStream = new GZIPInputStream(AUCBuilderTest.class.getResourceAsStream("aucbuilder.csv.gz"));
        Throwable th = null;
        try {
            try {
                BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(gZIPInputStream));
                int i = 0;
                while (true) {
                    String readLine = bufferedReader.readLine();
                    if (readLine == null) {
                        break;
                    }
                    String[] split = readLine.split(",");
                    double parseDouble = Double.parseDouble(split[0]);
                    int parseInt = Integer.parseInt(split[1]);
                    long currentTimeMillis = System.currentTimeMillis();
                    aUCBuilder.perRow(parseDouble, parseInt, 1.0d);
                    j += System.currentTimeMillis() - currentTimeMillis;
                    long currentTimeMillis2 = System.currentTimeMillis();
                    aUCBuilder2.perRow(parseDouble, parseInt, 1.0d);
                    j2 += System.currentTimeMillis() - currentTimeMillis2;
                    for (int i2 = 0; i2 < 400; i2++) {
                        Assert.assertEquals("Error in ths, line: " + i, aUCBuilder._ths[i2], aUCBuilder2._ths[i2], 0.0d);
                        Assert.assertEquals("Error in tps, line: " + i, aUCBuilder._tps[i2], aUCBuilder2._tps[i2], 0.0d);
                        Assert.assertEquals("Error in tps, line: " + i, aUCBuilder._fps[i2], aUCBuilder2._fps[i2], 0.0d);
                        Assert.assertEquals("Error in sqe, line: " + i, aUCBuilder._sqe[i2], aUCBuilder2._sqe[i2], 0.0d);
                    }
                    i++;
                }
                if (gZIPInputStream != null) {
                    if (0 != 0) {
                        try {
                            gZIPInputStream.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        gZIPInputStream.close();
                    }
                }
                System.out.println("Total time with speedup: " + j + "ms; orginal time: " + j2 + "ms.");
            } finally {
            }
        } catch (Throwable th3) {
            if (gZIPInputStream != null) {
                if (th != null) {
                    try {
                        gZIPInputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    gZIPInputStream.close();
                }
            }
            throw th3;
        }
    }

    @Test
    public void testBinningQuality() {
        Assert.assertEquals(testHisto(new RandArrayGen(42L) { // from class: hex.AUCBuilderTest.2
            @Override // hex.AUCBuilderTest.RandArrayGen
            void fillRandomVals(double[] dArr) {
                for (int i = 0; i < dArr.length / 2; i++) {
                    dArr[i] = this._r.nextDouble() / 2.0d;
                }
                for (int length = dArr.length / 2; length < dArr.length; length++) {
                    dArr[length] = 0.5d + (this._r.nextDouble() / 2.0d);
                }
            }
        }, 10), 5.0d * testHisto(new RandArrayGen(42L) { // from class: hex.AUCBuilderTest.1
            @Override // hex.AUCBuilderTest.RandArrayGen
            void fillRandomVals(double[] dArr) {
                for (int i = 0; i < dArr.length; i++) {
                    dArr[i] = this._r.nextDouble();
                }
            }
        }, 10), 1.0d);
    }

    private double testHisto(RandArrayGen randArrayGen, int i) {
        double d = 0.0d;
        for (int i2 = 0; i2 < i; i2++) {
            d += testHisto(randArrayGen);
        }
        return d / i;
    }

    private double testHisto(RandArrayGen randArrayGen) {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(11);
        double[] dArr = new double[1000];
        randArrayGen.fillRandomVals(dArr);
        for (double d : dArr) {
            aUCBuilder.perRow(d, 1, 1.0d);
        }
        return histoRMSE(aUCBuilder, dArr);
    }

    private double histoRMSE(AUC2.AUCBuilder aUCBuilder, double[] dArr) {
        double[] copyOf = Arrays.copyOf(aUCBuilder._ths, aUCBuilder._nBins);
        int[] iArr = new int[copyOf.length];
        for (int i = 0; i < copyOf.length; i++) {
            iArr[i] = (int) aUCBuilder._tps[i];
        }
        int[] iArr2 = new int[copyOf.length];
        for (double d : dArr) {
            int binarySearch = Arrays.binarySearch(copyOf, d);
            if (binarySearch >= 0) {
                iArr2[binarySearch] = iArr2[binarySearch] + 1;
            } else {
                int i2 = (-binarySearch) - 1;
                if (i2 == iArr2.length) {
                    i2 = iArr2.length - 1;
                } else if (i2 > 0 && d - copyOf[i2 - 1] < copyOf[i2] - d) {
                    i2--;
                }
                int i3 = i2;
                iArr2[i3] = iArr2[i3] + 1;
            }
        }
        double d2 = 0.0d;
        for (int i4 = 0; i4 < iArr.length; i4++) {
            d2 += Math.pow(iArr[i4] - iArr2[i4], 2.0d);
        }
        System.out.println("Actual  : " + Arrays.toString(iArr));
        System.out.println("Expected: " + Arrays.toString(iArr2));
        double sqrt = Math.sqrt(d2) / iArr.length;
        System.out.println("RMSE    : " + sqrt);
        return sqrt;
    }

    @Test
    public void testPubDev6399ReduceDoesntUsePreviousKnownSSX() {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(10);
        for (int i = 0; i < aUCBuilder._nBins; i++) {
            aUCBuilder.perRow(i, 1, 1.0d);
        }
        aUCBuilder.perRow(5.5d, 1, 1.0d);
        Assert.assertEquals(0L, aUCBuilder._ssx);
        AUC2.AUCBuilder aUCBuilder2 = new AUC2.AUCBuilder(10);
        aUCBuilder2.perRow(9.0d, 1, 1.0d);
        double[] copyOf = Arrays.copyOf(aUCBuilder._ths, aUCBuilder._n);
        aUCBuilder.reduce(aUCBuilder2);
        Assert.assertArrayEquals(copyOf, Arrays.copyOf(aUCBuilder._ths, aUCBuilder._n), 0.0d);
    }

    @Test
    public void testLargeWeights() {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(2);
        aUCBuilder.perRow(0.0d, 1, Double.MAX_VALUE);
        aUCBuilder.perRow(0.3d, 1, Double.MAX_VALUE);
        aUCBuilder.perRow(0.7d, 1, Double.MAX_VALUE);
        aUCBuilder.perRow(1.0d, 1, Double.MAX_VALUE);
        Assert.assertArrayEquals(new double[]{0.0d, 0.85d}, Arrays.copyOf(aUCBuilder._ths, 2), 0.0d);
    }

    @Test
    public void testCombineCenters() {
        double sqrt = Math.sqrt(Double.MAX_VALUE);
        Assert.assertEquals(sqrt, AUC2.AUCBuilder.combine_centers(sqrt, sqrt, sqrt, sqrt), 0.0d);
    }

    @Test
    public void testCMfor01() {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(2);
        aUCBuilder.perRow(0.0d, 0, 1.0d);
        aUCBuilder.perRow(0.0d, 1, 2.0d);
        aUCBuilder.perRow(1.0d, 0, 3.0d);
        aUCBuilder.perRow(1.0d, 1, 4.0d);
        double[][] defaultCM = AUC2.make01AUC(aUCBuilder).defaultCM();
        Log.debug(new Object[]{new ConfusionMatrix(defaultCM, new String[]{"0", "1"}).toASCII()});
        Assert.assertArrayEquals(new double[]{1.0d, 3.0d}, defaultCM[0], 0.0d);
        Assert.assertArrayEquals(new double[]{2.0d, 4.0d}, defaultCM[1], 0.0d);
        checkCMOneObs(0, 0);
        checkCMOneObs(0, 1);
        checkCMOneObs(1, 0);
        checkCMOneObs(1, 1);
    }

    private static void checkCMOneObs(int i, int i2) {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(2);
        aUCBuilder.perRow(i, i2, 1.0d);
        double[][] defaultCM = AUC2.make01AUC(aUCBuilder).defaultCM();
        Log.debug(new Object[]{"pred = " + i + "; act = " + i2});
        Log.debug(new Object[]{new ConfusionMatrix(defaultCM, new String[]{"0", "1"}).toASCII()});
        Assert.assertEquals(1.0d, defaultCM[i2][i], 0.0d);
    }

    @Test
    public void testRestrictToMaxCriterion() {
        AUC2.AUCBuilder aUCBuilder = new AUC2.AUCBuilder(10);
        aUCBuilder.perRow(0.0d, 1, 1.0d);
        aUCBuilder.perRow(0.2d, 0, 1.0d);
        aUCBuilder.perRow(0.4d, 0, 1.0d);
        aUCBuilder.perRow(0.6d, 1, 1.0d);
        aUCBuilder.perRow(0.8d, 1, 1.0d);
        aUCBuilder.perRow(1.0d, 1, 1.0d);
        AUC2 auc2 = new AUC2(aUCBuilder);
        Assert.assertEquals(2L, auc2._max_idx);
        Assert.assertEquals(0.6d, auc2.defaultThreshold(), 0.0d);
        Assert.assertEquals(0.17d, auc2.defaultErr(), 0.01d);
        AUC2 restrictToMaxCriterion = auc2.restrictToMaxCriterion();
        Assert.assertEquals(0L, restrictToMaxCriterion._max_idx);
        Assert.assertEquals(0.6d, restrictToMaxCriterion.defaultThreshold(), 0.0d);
        Assert.assertEquals(0.17d, restrictToMaxCriterion.defaultErr(), 0.01d);
        Assert.assertEquals(auc2._auc, restrictToMaxCriterion._auc, 0.0d);
        Assert.assertEquals(auc2._pr_auc, restrictToMaxCriterion._pr_auc, 0.0d);
        Assert.assertEquals(auc2._gini, restrictToMaxCriterion._gini, 0.0d);
        for (AUC2.ThresholdCriterion thresholdCriterion : AUC2.ThresholdCriterion.values()) {
            Assert.assertEquals("Value of metric " + thresholdCriterion + " expected to be the same", thresholdCriterion.exec(auc2, auc2._max_idx), thresholdCriterion.exec(restrictToMaxCriterion, 0), 0.0d);
        }
    }
}
