package org.apache.commons.rng.sampling;

import java.util.Arrays;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.core.source64.SplitMix64;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/commons/rng/sampling/UnitSphereSamplerTest.class */
/**
 * Tests for {@link UnitSphereSampler}.
 *
 * <p>Each distribution test runs against both the public constructor and the
 * {@link UnitSphereSampler#of(UniformRandomProvider, int)} factory method, using
 * fixed seeds so results are reproducible.
 */
class UnitSphereSamplerTest {
    /** 2&pi;, used to map an angle in [-&pi;, &pi;] onto [0, 1]. */
    private static final double TWO_PI = 2 * Math.PI;

    /** The constructor must reject a dimension of zero. */
    @Test
    void testInvalidDimensionThrows() {
        Assertions.assertThrows(IllegalArgumentException.class,
            () -> new UnitSphereSampler(0, (UniformRandomProvider) null));
    }

    /** The factory method must reject a dimension of zero. */
    @Test
    void testInvalidDimensionThrowsWithFactoryConstructor() {
        Assertions.assertThrows(IllegalArgumentException.class,
            () -> UnitSphereSampler.of((UniformRandomProvider) null, 0));
    }

    @Test
    void testDistribution1D() {
        testDistribution1D(false);
    }

    @Test
    void testDistribution1DWithFactoryConstructor() {
        testDistribution1D(true);
    }

    /**
     * Checks that 1D samples are exactly -1 or +1 and that the two values occur
     * with roughly equal frequency.
     *
     * @param factoryConstructor true to build the sampler with {@code of}, false for the constructor
     */
    private static void testDistribution1D(boolean factoryConstructor) {
        final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create(1715004L);
        final UnitSphereSampler sampler = createUnitSphereSampler(1, rng, factoryConstructor);
        final int samples = 10000;
        int negatives = 0;
        for (int i = 0; i < samples; i++) {
            final double[] v = sampler.nextVector();
            Assertions.assertEquals(1, v.length);
            final double x = v[0];
            if (x == -1.0) {
                negatives++;
            } else if (x != 1.0) {
                Assertions.fail("Invalid unit length: " + x);
            }
        }
        assertMonobit(negatives, samples);
    }

    /**
     * Asserts that {@code count} outcomes out of {@code total} is consistent with a fair coin.
     *
     * <p>Treats each trial as a +1/-1 step of a random walk. The final displacement
     * {@code 2 * count - total} must lie within 2.576 standard deviations
     * ({@code sqrt(total)}), which is the two-sided 99% bound.
     *
     * @param count number of trials with the counted outcome
     * @param total total number of trials
     */
    private static void assertMonobit(int count, int total) {
        final double displacement = Math.abs(2.0 * count - total);
        final double limit = Math.sqrt(total) * 2.576;
        Assertions.assertTrue(displacement <= limit,
            () -> "Walked too far astray: " + displacement + " > " + limit +
                " (test will fail randomly about 1 in 100 times)");
    }

    @Test
    void testDistribution2D() {
        testDistribution2D(false);
    }

    @Test
    void testDistribution2DWithFactoryConstructor() {
        testDistribution2D(true);
    }

    /**
     * Checks that 2D samples have unit length and that their angle is uniformly
     * distributed over 200 equal-width bins (chi-square test).
     *
     * @param factoryConstructor true to build the sampler with {@code of}, false for the constructor
     */
    private static void testDistribution2D(boolean factoryConstructor) {
        final UniformRandomProvider rng = RandomSource.XOR_SHIFT_1024_S_PHI.create(17399225432L);
        final UnitSphereSampler sampler = createUnitSphereSampler(2, rng, factoryConstructor);
        final int bins = 200;
        final int steps = 100000;
        final long[] observed = new long[bins];
        for (int i = 0; i < steps; i++) {
            final double[] v = sampler.sample();
            Assertions.assertEquals(2, v.length);
            Assertions.assertEquals(1.0, length(v), 1e-10);
            observed[angleBin(bins, v[0], v[1])]++;
        }
        assertUniformCounts(observed, steps, "p-value too small: ");
    }

    @Test
    void testDistribution3D() {
        testDistribution3D(false);
    }

    @Test
    void testDistribution3DWithFactoryConstructor() {
        testDistribution3D(true);
    }

    /**
     * Checks that 3D samples have unit length and are uniform on the sphere.
     *
     * <p>For a uniform point on the 2-sphere, z is uniform on [-1, 1] and the xy
     * angle is uniform and independent of z. So 10 equal-width z bins times 20
     * equal-width angle bins should be equiprobable.
     *
     * @param factoryConstructor true to build the sampler with {@code of}, false for the constructor
     */
    private static void testDistribution3D(boolean factoryConstructor) {
        final UniformRandomProvider rng = RandomSource.XO_SHI_RO_256_PP.create(11259375L);
        final UnitSphereSampler sampler = createUnitSphereSampler(3, rng, factoryConstructor);
        final int binsZ = 10;
        final int binsXY = 20;
        final int steps = 1000000;
        final long[] observed = new long[binsZ * binsXY];
        for (int i = 0; i < steps; i++) {
            final double[] v = sampler.sample();
            Assertions.assertEquals(3, v.length);
            Assertions.assertEquals(1.0, length(v), 1e-10);
            // Map z in [-1, 1] onto [0, binsZ]. Clamp so z == 1 (or 1 + ulp from
            // rounding) is counted in the top bin instead of indexing past the array.
            final int zBin = Math.min(binsZ - 1, (int) (binsZ * (v[2] + 1.0) / 2.0));
            observed[zBin * binsXY + angleBin(binsXY, v[0], v[1])]++;
        }
        assertUniformCounts(observed, steps, "p-value too small: ");
    }

    @Test
    void testDistribution4D() {
        testDistribution4D(false);
    }

    @Test
    void testDistribution4DWithFactoryConstructor() {
        testDistribution4D(true);
    }

    /**
     * Checks that 4D samples have unit length, using two 2D projections.
     *
     * <p>Each pair (x0, x1) and (x2, x3) is binned by radius and angle. Radius edges
     * are {@code sqrt(k / 10)}: for a uniform point on the 3-sphere the squared radius
     * of a 2D projection is uniform on [0, 1], so these radius bins are equiprobable.
     *
     * @param factoryConstructor true to build the sampler with {@code of}, false for the constructor
     */
    private static void testDistribution4D(boolean factoryConstructor) {
        final UniformRandomProvider rng = RandomSource.XO_SHI_RO_512_PP.create(654820258320L);
        final UnitSphereSampler sampler = createUnitSphereSampler(4, rng, factoryConstructor);
        final int binsRadius = 10;
        final int binsAngle = 20;
        final double[] radiusEdges = new double[binsRadius];
        for (int k = 1; k < binsRadius; k++) {
            radiusEdges[k - 1] = Math.sqrt((double) k / binsRadius);
        }
        // Set the outer edge explicitly rather than relying on sqrt(1.0).
        radiusEdges[binsRadius - 1] = 1.0;
        final int steps = 1000000;
        final long[] observed1 = new long[binsRadius * binsAngle];
        final long[] observed2 = new long[observed1.length];
        for (int i = 0; i < steps; i++) {
            final double[] v = sampler.sample();
            Assertions.assertEquals(4, v.length);
            Assertions.assertEquals(1.0, length(v), 1e-10);
            observed1[circleBin(binsAngle, radiusEdges, v[0], v[1])]++;
            observed2[circleBin(binsAngle, radiusEdges, v[2], v[3])]++;
        }
        assertUniformCounts(observed1, steps, "Circle 1 p-value too small: ");
        assertUniformCounts(observed2, steps, "Circle 2 p-value too small: ");
    }

    /**
     * Runs a chi-square goodness-of-fit test of {@code observed} against equal
     * expected counts, and fails if the p-value is below 0.01.
     *
     * @param observed counts per bin
     * @param total total number of samples binned
     * @param messagePrefix failure message prefix; the p-value is appended
     */
    private static void assertUniformCounts(long[] observed, int total, String messagePrefix) {
        final double[] expected = new double[observed.length];
        Arrays.fill(expected, (double) total / observed.length);
        final double p = new ChiSquareTest().chiSquareTest(expected, observed);
        Assertions.assertFalse(p < 0.01, () -> messagePrefix + p);
    }

    /**
     * Computes a combined radius/angle bin index for a 2D point.
     *
     * @param angleBins number of angle bins per radius ring
     * @param radiusEdges ascending upper edges of the radius bins
     * @param x x coordinate
     * @param y y coordinate
     * @return {@code radiusBin * angleBins + angleBin}
     */
    private static int circleBin(int angleBins, double[] radiusEdges, double x, double y) {
        return radiusBin(radiusEdges, x, y) * angleBins + angleBin(angleBins, x, y);
    }

    /**
     * Maps the angle of (x, y) to one of {@code angleBins} equal-width bins.
     *
     * @param angleBins number of bins
     * @param x x coordinate
     * @param y y coordinate
     * @return a bin index in {@code [0, angleBins)}
     */
    private static int angleBin(int angleBins, double x, double y) {
        // atan2 is in [-pi, pi]; shift to [0, 2pi] and scale to [0, angleBins].
        final int bin = (int) (angleBins * (Math.atan2(y, x) + Math.PI) / TWO_PI);
        // An angle of exactly +pi (e.g. y == +0.0 with x < 0) points the same way as -pi,
        // which falls in bin 0. Wrap it there instead of returning an out-of-range index.
        return bin == angleBins ? 0 : bin;
    }

    /**
     * Finds the first radius bin whose upper edge is at least the length of (x, y).
     *
     * @param radiusEdges ascending upper edges of the radius bins
     * @param x x coordinate
     * @param y y coordinate
     * @return the radius bin index
     * @throws AssertionError if the radius exceeds the last edge
     */
    private static int radiusBin(double[] radiusEdges, double x, double y) {
        final double radius = Math.sqrt(x * x + y * y);
        for (int k = 0; k < radiusEdges.length; k++) {
            if (radius <= radiusEdges[k]) {
                return k;
            }
        }
        throw new AssertionError("Invalid sample length: " + radius);
    }

    @Test
    void testBadProvider2D() {
        Assertions.assertThrows(StackOverflowError.class, () -> testBadProvider(2));
    }

    @Test
    void testBadProvider3D() {
        Assertions.assertThrows(StackOverflowError.class, () -> testBadProvider(3));
    }

    @Test
    void testBadProvider4D() {
        Assertions.assertThrows(StackOverflowError.class, () -> testBadProvider(4));
    }

    /**
     * Samples once from a provider whose {@code nextLong()} always returns 0.
     *
     * <p>Callers assert that this ends in a {@link StackOverflowError}, i.e. that the
     * sampler keeps retrying. Presumably this is because a constant source cannot
     * yield a usable (non-zero) vector.
     *
     * @param dimension sampler dimension
     */
    private static void testBadProvider(int dimension) {
        final UniformRandomProvider bad = new SplitMix64(0L) {
            @Override
            public long nextLong() {
                return 0L;
            }
        };
        UnitSphereSampler.of(bad, dimension).sample();
    }

    @Test
    void testInvalidInverseNormalisation2D() {
        testInvalidInverseNormalisationND(2);
    }

    @Test
    void testInvalidInverseNormalisation3D() {
        testInvalidInverseNormalisationND(3);
    }

    @Test
    void testInvalidInverseNormalisation4D() {
        testInvalidInverseNormalisationND(4);
    }

    /**
     * Checks that the sampler still returns a unit vector when its first draws are degenerate.
     *
     * <p>The provider returns 0 from {@code nextLong()} for the first
     * {@code 2 * dimension} calls and then delegates to SplitMix64.
     *
     * @param dimension sampler dimension
     */
    private static void testInvalidInverseNormalisationND(final int dimension) {
        final UniformRandomProvider rng = new SplitMix64(1715004L) {
            /** Counts up from a negative start; while negative, zero is returned. */
            private int count = -2 * dimension;

            @Override
            public long nextLong() {
                return count++ < 0 ? 0L : super.nextLong();
            }
        };
        final double[] v = UnitSphereSampler.of(rng, dimension).sample();
        Assertions.assertEquals(dimension, v.length);
        Assertions.assertEquals(1.0, length(v), 1e-10);
    }

    /**
     * Checks that replacing a zero squared norm with the smallest positive double
     * would still give a finite, positive inverse norm.
     */
    @Test
    void testNextNormSquaredAfterZeroIsValid() {
        final double normSq = Math.nextUp(0.0);
        final double inverseNorm = 1.0 / Math.sqrt(normSq);
        Assertions.assertTrue(inverseNorm > 0.0 && inverseNorm <= Double.MAX_VALUE);
    }

    @Test
    void testSharedStateSampler1D() {
        testSharedStateSampler(1, false);
    }

    @Test
    void testSharedStateSampler2D() {
        testSharedStateSampler(2, false);
    }

    @Test
    void testSharedStateSampler3D() {
        testSharedStateSampler(3, false);
    }

    @Test
    void testSharedStateSampler4D() {
        testSharedStateSampler(4, false);
    }

    @Test
    void testSharedStateSampler1DWithFactoryConstructor() {
        testSharedStateSampler(1, true);
    }

    @Test
    void testSharedStateSampler2DWithFactoryConstructor() {
        testSharedStateSampler(2, true);
    }

    @Test
    void testSharedStateSampler3DWithFactoryConstructor() {
        testSharedStateSampler(3, true);
    }

    @Test
    void testSharedStateSampler4DWithFactoryConstructor() {
        testSharedStateSampler(4, true);
    }

    /**
     * Checks that {@code withUniformRandomProvider} yields a sampler producing the
     * same sequence when given an identically seeded provider.
     *
     * @param dimension sampler dimension
     * @param factoryConstructor true to build the sampler with {@code of}, false for the constructor
     */
    private static void testSharedStateSampler(int dimension, boolean factoryConstructor) {
        final RestorableUniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L);
        final RestorableUniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L);
        final UnitSphereSampler sampler = createUnitSphereSampler(dimension, rng1, factoryConstructor);
        final ObjectSampler<double[]> sampler1 = sampler;
        final ObjectSampler<double[]> sampler2 = sampler.withUniformRandomProvider(rng2);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }

    /**
     * Creates a sampler using either the factory method or the public constructor.
     *
     * @param dimension sampler dimension
     * @param rng source of randomness
     * @param factoryConstructor true to use {@code UnitSphereSampler.of}, false for the constructor
     * @return the sampler
     */
    private static UnitSphereSampler createUnitSphereSampler(int dimension, UniformRandomProvider rng,
                                                             boolean factoryConstructor) {
        return factoryConstructor ?
            UnitSphereSampler.of(rng, dimension) :
            new UnitSphereSampler(dimension, rng);
    }

    /**
     * Computes the Euclidean length of a vector.
     *
     * @param vector the vector
     * @return the square root of the sum of squared components
     */
    private static double length(double[] vector) {
        double sumSq = 0.0;
        for (final double x : vector) {
            sumSq += x * x;
        }
        return Math.sqrt(sumSq);
    }
}
