package org.apache.commons.rng.sampling.distribution;

import java.util.Arrays;
import java.util.Locale;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.core.source64.SplitMix64;
import org.apache.commons.rng.sampling.ObjectSampler;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/commons/rng/sampling/distribution/DirichletSamplerTest.class */
class DirichletSamplerTest {
    DirichletSamplerTest() {
    }

    @Test
    void testDistributionThrowsWithInvalidNumberOfCategories() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.of(create, new double[]{1.0d});
        });
    }

    @Test
    void testDistributionThrowsWithZeroConcentration() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.of(create, new double[]{1.0d, 0.0d});
        });
    }

    @Test
    void testDistributionThrowsWithNaNConcentration() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.of(create, new double[]{1.0d, Double.NaN});
        });
    }

    @Test
    void testDistributionThrowsWithInfiniteConcentration() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.of(create, new double[]{1.0d, Double.POSITIVE_INFINITY});
        });
    }

    @Test
    void testSymmetricDistributionThrowsWithInvalidNumberOfCategories() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.symmetric(create, 1, 1.0d);
        });
    }

    @Test
    void testSymmetricDistributionThrowsWithZeroConcentration() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.symmetric(create, 2, 0.0d);
        });
    }

    @Test
    void testSymmetricDistributionThrowsWithNaNConcentration() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.symmetric(create, 2, Double.NaN);
        });
    }

    @Test
    void testSymmetricDistributionThrowsWithInfiniteConcentration() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            DirichletSampler.symmetric(create, 2, Double.POSITIVE_INFINITY);
        });
    }

    @Test
    void testInvalidSampleIsIgnored() {
        assertSample(2, DirichletSampler.symmetric(new SplitMix64(0L) { // from class: org.apache.commons.rng.sampling.distribution.DirichletSamplerTest.1
            private int i;

            public long next() {
                int i = this.i;
                this.i = i + 1;
                if (i < 10) {
                    return 0L;
                }
                return super.next();
            }
        }, 2, 1.0d).sample());
    }

    @Test
    void testSharedStateSampler() {
        RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
        byte[] createSeed = randomSource.createSeed();
        RestorableUniformRandomProvider create = randomSource.create(createSeed, new Object[0]);
        RestorableUniformRandomProvider create2 = randomSource.create(createSeed, new Object[0]);
        DirichletSampler of = DirichletSampler.of(create, new double[]{1.0d, 2.0d, 3.0d});
        RandomAssert.assertProduceSameSequence((ObjectSampler) of, (ObjectSampler) of.withUniformRandomProvider(create2));
    }

    @Test
    void testSharedStateSamplerForSymmetricCase() {
        RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
        byte[] createSeed = randomSource.createSeed();
        RestorableUniformRandomProvider create = randomSource.create(createSeed, new Object[0]);
        RestorableUniformRandomProvider create2 = randomSource.create(createSeed, new Object[0]);
        DirichletSampler symmetric = DirichletSampler.symmetric(create, 2, 1.5d);
        RandomAssert.assertProduceSameSequence((ObjectSampler) symmetric, (ObjectSampler) symmetric.withUniformRandomProvider(create2));
    }

    @Test
    void testSymmetricCaseMatchesGeneralCase() {
        RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
        byte[] createSeed = randomSource.createSeed();
        RestorableUniformRandomProvider create = randomSource.create(createSeed, new Object[0]);
        RestorableUniformRandomProvider create2 = randomSource.create(createSeed, new Object[0]);
        double[] dArr = new double[3];
        for (double d : new double[]{0.5d, 1.0d, 1.5d}) {
            Arrays.fill(dArr, d);
            RandomAssert.assertProduceSameSequence((ObjectSampler) DirichletSampler.symmetric(create, 3, d), (ObjectSampler) DirichletSampler.of(create2, dArr));
        }
    }

    @Test
    void testToString() {
        RestorableUniformRandomProvider create = RandomSource.SPLIT_MIX_64.create(0L, new Object[0]);
        DirichletSampler symmetric = DirichletSampler.symmetric(create, 2, 1.0d);
        DirichletSampler of = DirichletSampler.of(create, new double[]{0.5d, 1.0d, 1.5d});
        Assertions.assertTrue(symmetric.toString().toLowerCase().contains("dirichlet"));
        Assertions.assertTrue(of.toString().toLowerCase().contains("dirichlet"));
    }

    @Test
    void testSampling1() {
        assertSamples(1.0d, 2.0d, 3.0d);
    }

    @Test
    void testSampling2() {
        assertSamples(1.0d, 1.0d, 1.0d);
    }

    @Test
    void testSampling3() {
        assertSamples(0.5d, 1.0d, 1.5d);
    }

    @Test
    void testSampling4() {
        assertSamples(1.0d, 3.0d);
    }

    @Test
    void testSampling5() {
        assertSamples(1.0d, 2.0d, 3.0d, 4.0d);
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    private static void assertSamples(double... dArr) {
        DirichletSampler of = DirichletSampler.of(RandomSource.XO_RO_SHI_RO_128_PP.create(), dArr);
        int length = dArr.length;
        ?? r0 = new double[100000];
        for (int i = 0; i < r0.length; i++) {
            double[] sample = of.sample();
            assertSample(length, sample);
            r0[i] = sample;
        }
        double d = 0.0d;
        for (double d2 : dArr) {
            d += d2;
        }
        double[] columnMeans = getColumnMeans(r0);
        for (int i2 = 0; i2 < length; i2++) {
            double d3 = dArr[i2] / d;
            Assertions.assertEquals(d3, columnMeans[i2], d3 * 0.05d, "Mean");
        }
        double[][] covariance = getCovariance(r0);
        double d4 = d * d * (d + 1.0d);
        for (int i3 = 0; i3 < length; i3++) {
            double d5 = (dArr[i3] * (d - dArr[i3])) / d4;
            Assertions.assertEquals(d5, covariance[i3][i3], d5 * 0.05d, "Variance");
            for (int i4 = i3 + 1; i4 < length; i4++) {
                double d6 = ((-dArr[i3]) * dArr[i4]) / d4;
                Assertions.assertEquals(d6, covariance[i3][i4], Math.abs(d6) * 0.05d, "Covariance");
            }
        }
    }

    private static void assertSample(int i, double[] dArr) {
        Assertions.assertEquals(i, dArr.length, "Number of categories");
        double d = dArr[0] + dArr[1];
        for (int i2 = 2; i2 < dArr.length; i2++) {
            d += dArr[i2];
        }
        Assertions.assertEquals(1.0d, d, 1.0E-10d, "Invalid sum");
    }

    private static double[] getColumnMeans(double[][] dArr) {
        Array2DRowRealMatrix array2DRowRealMatrix = new Array2DRowRealMatrix(dArr, false);
        Mean mean = new Mean();
        double[] dArr2 = new double[array2DRowRealMatrix.getColumnDimension()];
        for (int i = 0; i < dArr2.length; i++) {
            dArr2[i] = mean.evaluate(array2DRowRealMatrix.getColumn(i));
        }
        return dArr2;
    }

    private static double[][] getCovariance(double[][] dArr) {
        return new Covariance(new Array2DRowRealMatrix(dArr, false)).getCovarianceMatrix().getData();
    }
}
