package org.apache.commons.rng.sampling.distribution;

import java.util.Arrays;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.ExponentialDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.core.source64.SplitMix64;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Tests for {@link ZigguratSampler}.
 *
 * <p>The tests cover construction, the shared-state copies of the samplers, the
 * behaviour of the exponential sampler for a controlled sequence of random bits, and
 * chi-square goodness-of-fit tests of the sampled values against reference
 * distributions.
 */
class ZigguratSamplerTest {
    /** Seed for the generator used by the statistical tests. */
    private static final Long SEED = 11259375L;

    /** Number of histogram bins used by the statistical tests. */
    private static final int BINS = 2000;

    /** Number of samples drawn to fill the histogram. */
    private static final int SAMPLES = 10000000;

    /** A chi-square p-value below this level fails the test. */
    private static final double SIGNIFICANCE = 0.001;

    /** Failure message format: lower limit, upper limit, p-value. */
    private static final String RANGE_MESSAGE = "(%s <= x < %s) Chi-square p-value = %s";

    /**
     * Creating an exponential sampler with a mean of zero must be rejected.
     */
    @Test
    void testExponentialConstructorThrowsWithZeroMean() {
        final RestorableUniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create(0L);
        Assertions.assertThrows(IllegalArgumentException.class,
            () -> ZigguratSampler.Exponential.of(rng, 0.0));
    }

    /**
     * A copy of the default exponential sampler bound to an identically seeded
     * generator must produce the same sequence as the source sampler.
     */
    @Test
    void testExponentialSharedStateSampler() {
        final RestorableUniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L);
        final RestorableUniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L);
        final ZigguratSampler.Exponential sampler1 = ZigguratSampler.Exponential.of(rng1);
        final ContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }

    /**
     * A copy of an exponential sampler created with an explicit mean, bound to an
     * identically seeded generator, must produce the same sequence as the source sampler.
     */
    @Test
    void testExponentialSharedStateSamplerWithMean() {
        final RestorableUniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L);
        final RestorableUniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L);
        final ZigguratSampler.Exponential sampler1 = ZigguratSampler.Exponential.of(rng1, 1.23);
        final ContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }

    /**
     * A copy of the Gaussian sampler bound to an identically seeded generator must
     * produce the same sequence as the source sampler.
     */
    @Test
    void testGaussianSharedStateSampler() {
        final RestorableUniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L);
        final RestorableUniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L);
        final ZigguratSampler.NormalizedGaussian sampler1 = ZigguratSampler.NormalizedGaussian.of(rng1);
        final ContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }

    /**
     * Tests the exponential sampler when the first outputs of the generator are fixed.
     *
     * <p>Each additional pair of leading {@code -1} outputs is expected to add a fixed
     * offset to the sample that would otherwise have been returned.
     */
    @Test
    void testExponentialRecursion() {
        // NOTE(review): presumably the start of the tail of the exponential ziggurat,
        // reached when the generator outputs -1 (all bits set) twice - confirm against
        // ZigguratSampler.Exponential.
        final double tail = 7.569274694148063;
        // NOTE(review): looks like twice the tail value; the literal is kept verbatim so
        // the exact comparisons below are unchanged.
        final double twoTails = 15.138549388296125;

        // A first output of zero yields a sample of zero.
        Assertions.assertEquals(0.0, expSample(42L, 0));
        // Two -1 outputs followed by a zero yield exactly the offset.
        Assertions.assertEquals(tail, expSample(42L, -1, -1, 0));

        for (final long seed : new long[] {42, -2136612838, 2340923842L, -1263746817818681L}) {
            // Baseline samples with zero and one leading -1 output.
            final double noPrefix = expSample(seed);
            final double onePrefix = expSample(seed, -1);
            // Every further pair of -1 outputs adds the offset once more.
            Assertions.assertEquals(noPrefix + tail, expSample(seed, -1, -1));
            Assertions.assertEquals(onePrefix + tail, expSample(seed, -1, -1, -1));
            Assertions.assertEquals(twoTails + noPrefix, expSample(seed, -1, -1, -1, -1));
            Assertions.assertEquals(twoTails + onePrefix, expSample(seed, -1, -1, -1, -1, -1));
        }
    }

    /**
     * Creates a single exponential sample using a generator that first returns the
     * given values and then continues with the output of a seeded {@link SplitMix64}.
     *
     * @param j seed for the underlying generator
     * @param jArr values returned by the generator before its own output is used
     * @return the sample
     */
    private static double expSample(long j, final long... jArr) {
        final UniformRandomProvider rng = new SplitMix64(j) {
            /** Position of the next fixed value to return. */
            private int index;

            @Override
            public long next() {
                if (index == jArr.length) {
                    // Fixed values exhausted: defer to the seeded generator.
                    return super.next();
                }
                return jArr[index++];
            }
        };
        return ZigguratSampler.Exponential.of(rng).sample();
    }

    /**
     * Provides the Gaussian sampler factories for the parameterized tests.
     *
     * @return arguments of (display name, sampler factory)
     */
    private static Stream<Arguments> gaussianSamplers() {
        // A method reference needs a functional-interface target type; it cannot be
        // passed directly as an Object vararg.
        final Function<UniformRandomProvider, ContinuousSampler> factory =
            ZigguratSampler.NormalizedGaussian::of;
        return Stream.of(Arguments.of("NormalizedGaussian", factory));
    }

    /**
     * Provides the exponential sampler factories for the parameterized tests.
     *
     * @return arguments of (display name, sampler factory)
     */
    private static Stream<Arguments> exponentialSamplers() {
        // A method reference needs a functional-interface target type; it cannot be
        // passed directly as an Object vararg.
        final Function<UniformRandomProvider, ContinuousSampler> factory =
            ZigguratSampler.Exponential::of;
        return Stream.of(Arguments.of("Exponential", factory));
    }

    /**
     * Creates the reference Gaussian distribution (mean 0, standard deviation 1).
     * No random generator is supplied; this class only queries probabilities,
     * quantiles and the support bound.
     *
     * @return the distribution
     */
    private static AbstractRealDistribution createGaussianDistribution() {
        return new NormalDistribution((RandomGenerator) null, 0.0, 1.0);
    }

    /**
     * Creates the reference exponential distribution (mean 1).
     * No random generator is supplied; this class only queries probabilities,
     * quantiles and the support bound.
     *
     * @return the distribution
     */
    private static AbstractRealDistribution createExponentialDistribution() {
        return new ExponentialDistribution((RandomGenerator) null, 1.0);
    }

    /**
     * Tests Gaussian samples using bins whose upper limits are the quantiles of the
     * reference distribution, i.e. bins of equal expected probability.
     *
     * @param name display name of the sampler (used only in the test name)
     * @param factory factory to create the sampler
     */
    @ParameterizedTest(name = "{index} => {0}")
    @MethodSource("gaussianSamplers")
    void testGaussianSamplesWithQuantiles(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
        testSamples(quantileBinLimits(createGaussianDistribution()), factory,
            ZigguratSamplerTest::createGaussianDistribution,
            new double[] {0.0, 0.2},
            new double[] {-0.35, -0.1},
            new double[] {-0.1, 0.1},
            new double[] {-0.4, 0.6},
            new double[] {-1.1, -0.9},
            new double[] {2.1, 2.5},
            new double[] {2.5, 8.0});
    }

    /**
     * Tests Gaussian samples using bins of equal width over [-8, 8), with the final
     * bin extended to positive infinity.
     *
     * @param name display name of the sampler (used only in the test name)
     * @param factory factory to create the sampler
     */
    @ParameterizedTest(name = "{index} => {0}")
    @MethodSource("gaussianSamplers")
    void testGaussianSamplesWithUniformValues(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
        testSamples(uniformBinLimits(-8.0, 8.0), factory,
            ZigguratSamplerTest::createGaussianDistribution,
            new double[] {0.0, 0.2},
            new double[] {-0.35, -0.1},
            new double[] {-0.1, 0.1},
            new double[] {-0.4, 0.6},
            new double[] {-1.01, -0.99},
            new double[] {0.98, 1.03},
            new double[] {1.03, 1.05},
            new double[] {3.6, 3.8},
            new double[] {3.7, 8.0});
    }

    /**
     * Tests exponential samples using bins whose upper limits are the quantiles of the
     * reference distribution, i.e. bins of equal expected probability.
     *
     * @param name display name of the sampler (used only in the test name)
     * @param factory factory to create the sampler
     */
    @ParameterizedTest(name = "{index} => {0}")
    @MethodSource("exponentialSamplers")
    void testExponentialSamplesWithQuantiles(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
        testSamples(quantileBinLimits(createExponentialDistribution()), factory,
            ZigguratSamplerTest::createExponentialDistribution,
            new double[] {0.0, 0.1},
            new double[] {0.05, 0.15},
            new double[] {0.9, 1.1},
            new double[] {1.5, 12.0});
    }

    /**
     * Tests exponential samples using bins of equal width over [0, 12), with the final
     * bin extended to positive infinity.
     *
     * @param name display name of the sampler (used only in the test name)
     * @param factory factory to create the sampler
     */
    @ParameterizedTest(name = "{index} => {0}")
    @MethodSource("exponentialSamplers")
    void testExponentialSamplesWithUniformValues(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
        testSamples(uniformBinLimits(0.0, 12.0), factory,
            ZigguratSamplerTest::createExponentialDistribution,
            new double[] {0.0, 0.1},
            new double[] {0.05, 0.15},
            new double[] {0.9, 1.1},
            new double[] {7.5, 7.7},
            new double[] {7.7, 12.0});
    }

    /**
     * Creates the upper limits of {@link #BINS} bins of equal probability. The limit
     * of bin {@code i} is the quantile at probability {@code (i + 1) / BINS}, so the
     * final limit is the quantile at probability 1.
     *
     * @param distribution reference distribution
     * @return the upper limit of each bin
     */
    private static double[] quantileBinLimits(AbstractRealDistribution distribution) {
        final double[] limits = new double[BINS];
        for (int i = 0; i < BINS; i++) {
            limits[i] = distribution.inverseCumulativeProbability((i + 1.0) / BINS);
        }
        return limits;
    }

    /**
     * Creates the upper limits of {@link #BINS} bins of equal width spanning
     * {@code [lower, upper)}. The final limit is replaced by positive infinity so that
     * every sample above {@code lower} falls in some bin.
     *
     * @param lower lower limit of the first bin
     * @param upper nominal upper limit of the last bin
     * @return the upper limit of each bin
     */
    private static double[] uniformBinLimits(double lower, double upper) {
        final double width = upper - lower;
        final double[] limits = new double[BINS];
        for (int i = 0; i < BINS; i++) {
            limits[i] = lower + (width * (i + 1.0)) / BINS;
        }
        limits[BINS - 1] = Double.POSITIVE_INFINITY;
        return limits;
    }

    /**
     * Tests the samples against the reference distribution using a chi-square test.
     *
     * <p>A histogram of {@link #SAMPLES} samples is built using the bin limits and
     * compared with the expected bin probabilities over the whole histogram, and then
     * again over each of the given sub-ranges. A p-value below {@link #SIGNIFICANCE}
     * fails the test.
     *
     * @param dArr upper limit of each histogram bin, in ascending order; the first bin
     * is unbounded below
     * @param function factory to create the sampler
     * @param supplier supplier of the reference distribution
     * @param dArr2 sub-ranges {@code {lower, upper}} to test separately; each must span
     * at least two bins
     */
    private static void testSamples(double[] dArr,
                                    Function<UniformRandomProvider, ContinuousSampler> function,
                                    Supplier<AbstractRealDistribution> supplier,
                                    double[]... dArr2) {
        final int bins = dArr.length;

        // Build the histogram of observed samples.
        final long[] observed = new long[bins];
        final ContinuousSampler sampler = function.apply(RandomSource.XO_SHI_RO_128_PP.create(SEED));
        for (int i = 0; i < SAMPLES; i++) {
            observed[findIndex(dArr, sampler.sample())]++;
        }

        // Expected probability of each bin; the first bin starts at negative infinity.
        final AbstractRealDistribution distribution = supplier.get();
        final double[] expected = new double[bins];
        double lower = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < bins; i++) {
            final double upper = dArr[i];
            expected[i] = distribution.probability(lower, upper);
            lower = upper;
        }

        // Only used to report the lower limit of the first bin in failure messages.
        final double supportLower = distribution.getSupportLowerBound();

        // Test the entire histogram.
        final ChiSquareTest chiSquare = new ChiSquareTest();
        final double pValue = chiSquare.chiSquareTest(expected, observed);
        Assertions.assertFalse(pValue < SIGNIFICANCE,
            () -> String.format(RANGE_MESSAGE, supportLower, dArr[bins - 1], pValue));

        // Test each sub-range of the histogram.
        for (final double[] range : dArr2) {
            final int first = findIndex(dArr, range[0]);
            final int last = findIndex(dArr, range[1]);
            // A chi-square test needs at least two bins.
            if (last - first + 1 < 2) {
                Assertions.fail("Invalid range: " + Arrays.toString(range));
            }
            final double rangePValue = chiSquare.chiSquareTest(
                Arrays.copyOfRange(expected, first, last + 1),
                Arrays.copyOfRange(observed, first, last + 1));
            Assertions.assertFalse(rangePValue < SIGNIFICANCE,
                () -> String.format(RANGE_MESSAGE,
                    first == 0 ? supportLower : dArr[first - 1], dArr[last], rangePValue));
        }
    }

    /**
     * Finds the index of the bin containing the value using a binary search. Bin
     * {@code i} contains values in {@code [dArr[i - 1], dArr[i])}; the first bin has no
     * lower limit.
     *
     * <p>The test fails if the value is not below the last limit (including when the
     * value is NaN).
     *
     * @param dArr upper limit of each bin, in ascending order
     * @param d value to locate
     * @return the bin index
     */
    private static int findIndex(double[] dArr, double d) {
        int low = 0;
        int high = dArr.length - 1;
        while (low <= high) {
            final int mid = (low + high) >>> 1;
            if (d < dArr[mid]) {
                // Reduce the search range.
                high = mid - 1;
            } else {
                // The value belongs in a bin above this limit.
                low = mid + 1;
            }
        }
        // Report an unbinnable value as a test failure rather than an
        // ArrayIndexOutOfBoundsException from the verification below.
        if (low == dArr.length) {
            Assertions.fail("Value is not below the upper limit of any bin: " + d);
        }
        // Verify the result.
        Assertions.assertTrue(d < dArr[low]);
        if (low != 0) {
            Assertions.assertTrue(d >= dArr[low - 1]);
        }
        return low;
    }
}
