package hex;

import hex.ScoreKeeper;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import water.TestUtil;
import water.util.Log;

/* loaded from: input_file:hex/ScoreKeeperTest.class */
/**
 * Tests for {@link ScoreKeeper}'s early-stopping logic and its convergence strategies.
 *
 * <p>The randomized tests cross-check {@link ScoreKeeper#stopEarly} against the simple reference
 * implementation {@link #stopEarly(double[], int, double, boolean, boolean)} defined in this class.
 */
public class ScoreKeeperTest extends TestUtil {
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /**
     * Reference moving-average early-stopping rule used to cross-check ScoreKeeper.
     *
     * <p>Computes {@code k + 1} moving averages of width {@code k} over the last {@code 2k} values,
     * each window shifted by one position. The first average is the reference; the series is still
     * improving if any of the {@code k} later averages beats it by more than {@code relTol}
     * (relative to the reference).
     *
     * @param values       metric history, oldest first
     * @param k            moving-average window width (also the number of later windows compared)
     * @param relTol       relative improvement tolerance
     * @param moreIsBetter true if larger values are better (AUC in these tests), false otherwise
     *                     (logloss in these tests)
     * @param verbose      whether to log the moving averages and the decision
     * @return true if no later window improved enough (i.e. stop); false if some window improved,
     *     or if fewer than {@code 2k + 1} values are available
     */
    private static boolean stopEarly(double[] values, int k, double relTol, boolean moreIsBetter, boolean verbose) {
        // Shorter histories never trigger stopping.
        if (values.length - 1 < 2 * k) {
            return false;
        }

        double[] movingAvg = new double[k + 1];
        for (int w = 0; w <= k; w++) {
            // Window w covers values[start .. start + k - 1]; the last window ends at the newest value.
            int start = values.length - 2 * k + w;
            double sum = 0.0d;
            for (int j = 0; j < k; j++) {
                sum += values[start + j];
            }
            movingAvg[w] = sum / k;
        }
        if (verbose) {
            Log.info(new Object[]{"JUnit: moving averages: " + Arrays.toString(movingAvg)});
        }

        double reference = movingAvg[0];
        double threshold = reference * (moreIsBetter ? 1.0d + relTol : 1.0d - relTol);
        boolean improved = false;
        for (int w = 1; w <= k; w++) {
            boolean windowImproved = moreIsBetter ? movingAvg[w] > threshold : movingAvg[w] < threshold;
            if (windowImproved) {
                improved = true;
                // Log only the windows that actually beat the reference, not every window after the first hit.
                if (verbose) {
                    Log.info(new Object[]{"JUnit: improved from " + reference + " to " + movingAvg[w] + " by at least " + relTol + " relative tolerance"});
                }
            }
        }

        if (verbose) {
            Log.info(new Object[]{improved ? "JUnit: Still improving." : "JUnit: Stopped."});
        }
        return !improved;
    }

    /**
     * Wraps each value in its own {@link ScoreKeeper}.
     *
     * @param values metric values, one per scoring event
     * @param isAuc  true to store each value in {@code _AUC}, false to store it in {@code _logloss}
     * @return a new array with one ScoreKeeper per value, in the same order
     */
    private static ScoreKeeper[] fillScoreKeeperArray(double[] values, boolean isAuc) {
        ScoreKeeper[] scoreKeepers = new ScoreKeeper[values.length];
        for (int i = 0; i < values.length; i++) {
            ScoreKeeper sk = new ScoreKeeper();
            if (isAuc) {
                sk._AUC = values[i];
            } else {
                sk._logloss = values[i];
            }
            scoreKeepers[i] = sk;
        }
        return scoreKeepers;
    }

    /** Random noisy, trending histories: ScoreKeeper must agree with the reference rule for every k. */
    @Test
    public void testConvergenceScoringHistory() {
        Random rng = new Random(0xC0FFEE);
        for (int trial = 0; trial < 100; trial++) {
            boolean moreIsBetter = rng.nextBoolean();
            ScoreKeeper.StoppingMetric metric = moreIsBetter ? ScoreKeeper.StoppingMetric.AUC : ScoreKeeper.StoppingMetric.logloss;
            double relTol = rng.nextFloat() * 0.1d;
            int n = 5 + rng.nextInt(10);
            double[] values = new double[n];
            for (int i = 0; i < n; i++) {
                // Linear trend away from 10 in the "better" direction plus Gaussian noise.
                // The cast matters: with integer division i / n is always 0 and the trend vanishes.
                double trend = (double) i / n;
                values[i] = (moreIsBetter ? 10.0d + trend : 10.0d - trend) + rng.nextGaussian() * 0.33d;
            }
            ScoreKeeper[] scoreKeepers = fillScoreKeeperArray(values, moreIsBetter);

            Log.info(new Object[0]);
            Log.info(new Object[]{"series: " + Arrays.toString(values)});
            Log.info(new Object[]{"moreIsBetter: " + moreIsBetter});
            Log.info(new Object[]{"relative tolerance: " + relTol});

            for (int k = values.length - 1; k > 0; k--) {
                boolean expected = stopEarly(values, k, relTol, moreIsBetter, false);
                boolean actual = ScoreKeeper.stopEarly(scoreKeepers, k, ScoreKeeper.ProblemType.classification, metric, relTol, "JUnit's", false);
                if (expected || actual) {
                    Log.info(new Object[]{"Stopped for k=" + k});
                }
                Assert.assertEquals("For k=" + k + ", JUnit: " + expected + ", ScoreKeeper: " + actual, expected, actual);
            }
        }
    }

    /** Sorted "leaderboard" histories (best value rightmost): ScoreKeeper must agree with the reference rule. */
    @Test
    public void testGridSearch() {
        Random rng = new Random(0xDECAF);
        for (int trial = 0; trial < 100; trial++) {
            final boolean moreIsBetter = rng.nextBoolean();
            int n = 5 + rng.nextInt(10);
            double relTol = rng.nextDouble() * 0.1d;

            Double[] boxed = new Double[n];
            for (int i = 0; i < n; i++) {
                boxed[i] = 10.0d + rng.nextDouble();
            }
            // Ascending when more is better, descending otherwise, so the best value ends up rightmost.
            Arrays.sort(boxed, new Comparator<Double>() {
                @Override
                public int compare(Double a, Double b) {
                    return moreIsBetter ? Double.compare(a, b) : Double.compare(b, a);
                }
            });
            double[] leaderboard = new double[n];
            for (int i = 0; i < n; i++) {
                leaderboard[i] = boxed[i];
            }
            Log.info(new Object[]{"Sorted values (leaderboard) - rightmost is best: " + Arrays.toString(leaderboard)});

            ScoreKeeper.StoppingMetric metric = moreIsBetter ? ScoreKeeper.StoppingMetric.AUC : ScoreKeeper.StoppingMetric.logloss;
            ScoreKeeper[] scoreKeepers = fillScoreKeeperArray(leaderboard, moreIsBetter);
            for (int k = 1; k < leaderboard.length; k++) {
                Log.info(new Object[]{"Testing k=" + k});
                boolean expected = stopEarly(leaderboard, k, relTol, moreIsBetter, true);
                boolean actual = ScoreKeeper.stopEarly(scoreKeepers, k, ScoreKeeper.ProblemType.classification, metric, relTol, "JUnit's", true);
                Assert.assertEquals("For k=" + k + ", JUnit: " + expected + ", ScoreKeeper: " + actual, expected, actual);
            }
        }
    }

    /** A regression MSE history that reaches 0 must stop even with k=1. */
    @Test
    public void testHitLowerBound() {
        ScoreKeeper[] history = {
            new ScoreKeeper(0.3d), new ScoreKeeper(0.2d), new ScoreKeeper(0.1d),
            new ScoreKeeper(0.0d), new ScoreKeeper(0.0d)
        };
        Assert.assertTrue(ScoreKeeper.stopEarly(history, 1, ScoreKeeper.ProblemType.regression, ScoreKeeper.StoppingMetric.MSE, 0.01d, "MSE", true));
    }

    /** ScoreKeeper must consult the metric's own convergence strategy with the moving-average summary. */
    @Test
    public void testQueryConvergenceStrategy() {
        ScoreKeeper.IConvergenceStrategy strategy = Mockito.mock(ScoreKeeper.IConvergenceStrategy.class);
        ScoreKeeper.IStoppingMetric metric = Mockito.mock(ScoreKeeper.IStoppingMetric.class);
        Mockito.when(metric.getConvergenceStrategy()).thenReturn(strategy);
        // The mocked metric reports the _mse field of each ScoreKeeper
        // (assumes the ScoreKeeper(double) constructor populates _mse).
        Mockito.when(metric.metricValue(ArgumentMatchers.any(ScoreKeeper.class))).thenAnswer(new Answer<Double>() {
            @Override
            public Double answer(InvocationOnMock invocation) {
                ScoreKeeper sk = invocation.getArgument(0);
                return sk._mse;
            }
        });
        Mockito.when(strategy.extremePoint(ArgumentMatchers.anyDouble(), ArgumentMatchers.anyDouble(), ArgumentMatchers.anyDouble())).thenReturn(0.1d);

        ScoreKeeper[] history = {
            new ScoreKeeper(0.6d), new ScoreKeeper(0.5d), new ScoreKeeper(0.4d),
            new ScoreKeeper(0.3d), new ScoreKeeper(0.2d), new ScoreKeeper(0.1d)
        };
        ScoreKeeper.stopEarly(history, 2, ScoreKeeper.ProblemType.regression, metric, 0.01d, "Mock", true);

        // With k=2 the moving averages are (0.4+0.3)/2, (0.3+0.2)/2 and (0.2+0.1)/2; the last one is
        // 0.15000000000000002 in double arithmetic. Presumably the arguments are (reference, min, max).
        Mockito.verify(strategy).extremePoint(0.35d, 0.15000000000000002d, 0.25d);
        Mockito.verify(strategy).stopEarly(0.35d, 0.15000000000000002d, 0.25d, 0.01d);
    }

    @Test
    public void testConvergenceStrategy_extremePoint() {
        Assert.assertEquals(0.2d, ScoreKeeper.ConvergenceStrategy.MORE_IS_BETTER.extremePoint(0.0d, 0.1d, 0.2d), 0.0d);
        Assert.assertEquals(0.1d, ScoreKeeper.ConvergenceStrategy.LESS_IS_BETTER.extremePoint(0.0d, 0.1d, 0.2d), 0.0d);
        // NON_DIRECTIONAL picks whichever of min/max is farther from the reference point.
        Assert.assertEquals(0.1d, ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.extremePoint(0.16d, 0.1d, 0.2d), 0.0d);
        Assert.assertEquals(0.2d, ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.extremePoint(0.14d, 0.1d, 0.2d), 0.0d);
    }

    @Test
    public void testConvergenceStrategy_stopEarly() {
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.MORE_IS_BETTER.stopEarly(0.1d, 0.1d, 0.2d, 0.01d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.MORE_IS_BETTER.stopEarly(0.1d, 0.1d, 0.2d, 0.1d));
        Assert.assertTrue(ScoreKeeper.ConvergenceStrategy.MORE_IS_BETTER.stopEarly(0.1d, 0.1d, 0.2d, 1.0d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.LESS_IS_BETTER.stopEarly(0.2d, 0.1d, 0.3d, 0.01d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.LESS_IS_BETTER.stopEarly(0.2d, 0.1d, 0.3d, 0.1d));
        Assert.assertTrue(ScoreKeeper.ConvergenceStrategy.LESS_IS_BETTER.stopEarly(0.2d, 0.1d, 0.3d, 1.0d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.stopEarly(0.14d, 0.1d, 0.2d, 0.01d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.stopEarly(0.14d, 0.1d, 0.2d, 0.1d));
        Assert.assertTrue(ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.stopEarly(0.14d, 0.1d, 0.2d, 1.0d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.stopEarly(0.16d, 0.1d, 0.2d, 0.01d));
        Assert.assertFalse(ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.stopEarly(0.16d, 0.1d, 0.2d, 0.1d));
        Assert.assertTrue(ScoreKeeper.ConvergenceStrategy.NON_DIRECTIONAL.stopEarly(0.16d, 0.1d, 0.2d, 1.0d));
    }
}
