package org.apache.flink.api.java.sampling;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
import org.apache.flink.shaded.com.google.common.collect.Lists;
import org.apache.flink.shaded.com.google.common.collect.Sets;
import org.apache.flink.testutils.junit.RetryOnFailure;
import org.apache.flink.testutils.junit.RetryRule;
import org.apache.flink.util.Preconditions;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/api/java/sampling/RandomSamplerTest.class */
/**
 * Tests for the {@link RandomSampler} implementations {@link BernoulliSampler}, {@link PoissonSampler},
 * {@link ReservoirSamplerWithReplacement} and {@link ReservoirSamplerWithoutReplacement}.
 *
 * <p>The source data set is the sequence {@code 0.0, 1.0, ..., SOURCE_SIZE - 1}. The tests check:
 * <ul>
 *     <li>that invalid fractions / sample sizes are rejected with {@link IllegalArgumentException};</li>
 *     <li>that the observed sample fraction or sample size matches the configured one;</li>
 *     <li>that samplers without replacement never emit duplicates;</li>
 *     <li>that the sampled output passes a two-sample Kolmogorov-Smirnov test against an evenly spaced
 *     reference sample, and fails it against a skewed reference sample.</li>
 * </ul>
 * Because these checks are statistical, the randomized tests are retried on failure.
 */
public class RandomSamplerTest {
    private static final int DEFAULT_PARTITION_NUMBER = 10;

    private static final int SOURCE_SIZE = 10000;

    /** Number of independent sampling rounds whose sizes are summed in {@link #verifySamplerFraction}. */
    private static final int FRACTION_SAMPLING_ROUNDS = 20;

    /** Maximum relative deviation tolerated between the configured and the observed fraction. */
    private static final double FRACTION_RELATIVE_TOLERANCE = 0.2d;

    /**
     * Coefficient c(alpha) of the two-sample KS critical value; 1.95 is approximately
     * sqrt(-ln(alpha / 2) / 2) for alpha = 0.001.
     */
    private static final double KS_CRITICAL_COEFFICIENT = 1.95d;

    private static final KolmogorovSmirnovTest ksTest = new KolmogorovSmirnovTest();

    /** The values {@code 0.0 .. SOURCE_SIZE - 1} in ascending order; filled once in {@link #init()}. */
    private static final List<Double> source = new ArrayList<Double>(SOURCE_SIZE);

    @Rule
    public final RetryRule retryRule = new RetryRule();

    // Java forbids generic array creation, so a raw List[] is created and used as List<Double>[].
    @SuppressWarnings({"unchecked", "rawtypes"})
    private final List<Double>[] sourcePartitions = new List[DEFAULT_PARTITION_NUMBER];

    @BeforeClass
    public static void init() {
        for (int value = 0; value < SOURCE_SIZE; value++) {
            source.add((double) value);
        }
    }

    /**
     * Distributes the values {@code 0 .. SOURCE_SIZE - 1} round-robin over
     * {@link #DEFAULT_PARTITION_NUMBER} partitions, for the tests that use the two-phase
     * (partition + coordinator) sampling path.
     */
    private void initSourcePartition() {
        int partitionCapacity = (int) Math.ceil((double) SOURCE_SIZE / DEFAULT_PARTITION_NUMBER);
        for (int partition = 0; partition < DEFAULT_PARTITION_NUMBER; partition++) {
            sourcePartitions[partition] = new ArrayList<Double>(partitionCapacity);
        }
        for (int value = 0; value < SOURCE_SIZE; value++) {
            sourcePartitions[value % DEFAULT_PARTITION_NUMBER].add((double) value);
        }
    }

    @Test(expected = IllegalArgumentException.class)
    public void testBernoulliSamplerWithUnexpectedFraction1() {
        verifySamplerFraction(-1.0d, false);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testBernoulliSamplerWithUnexpectedFraction2() {
        verifySamplerFraction(2.0d, false);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testBernoulliSamplerFraction() {
        verifySamplerFraction(0.01d, false);
        verifySamplerFraction(0.05d, false);
        verifySamplerFraction(0.1d, false);
        verifySamplerFraction(0.3d, false);
        verifySamplerFraction(0.5d, false);
        verifySamplerFraction(0.854d, false);
        verifySamplerFraction(0.99d, false);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testBernoulliSamplerDuplicateElements() {
        verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.01d));
        verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.1d));
        verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.5d));
    }

    @Test(expected = IllegalArgumentException.class)
    public void testPoissonSamplerWithUnexpectedFraction1() {
        verifySamplerFraction(-1.0d, true);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testPoissonSamplerFraction() {
        verifySamplerFraction(0.01d, true);
        verifySamplerFraction(0.05d, true);
        verifySamplerFraction(0.1d, true);
        verifySamplerFraction(0.5d, true);
        verifySamplerFraction(0.854d, true);
        verifySamplerFraction(0.99d, true);
        verifySamplerFraction(1.5d, true);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testReservoirSamplerUnexpectedSize1() {
        verifySamplerFixedSampleSize(-1, true);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testReservoirSamplerUnexpectedSize2() {
        verifySamplerFixedSampleSize(-1, false);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testBernoulliSamplerDistribution() {
        verifyBernoulliSampler(0.01d);
        verifyBernoulliSampler(0.05d);
        verifyBernoulliSampler(0.1d);
        verifyBernoulliSampler(0.5d);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testPoissonSamplerDistribution() {
        verifyPoissonSampler(0.01d);
        verifyPoissonSampler(0.05d);
        verifyPoissonSampler(0.1d);
        verifyPoissonSampler(0.5d);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerSampledSize() {
        verifySamplerFixedSampleSize(1, true);
        verifySamplerFixedSampleSize(10, true);
        verifySamplerFixedSampleSize(100, true);
        verifySamplerFixedSampleSize(1234, true);
        verifySamplerFixedSampleSize(9999, true);
        // With replacement the sample may be larger than the source.
        verifySamplerFixedSampleSize(20000, true);
        verifySamplerFixedSampleSize(1, false);
        verifySamplerFixedSampleSize(10, false);
        verifySamplerFixedSampleSize(100, false);
        verifySamplerFixedSampleSize(1234, false);
        verifySamplerFixedSampleSize(9999, false);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerSampledSize2() {
        // Without replacement, asking for more samples than the source holds yields the whole source.
        RandomSampler<Double> sampler = new ReservoirSamplerWithoutReplacement<Double>(20000);
        Assert.assertEquals(
            "ReservoirSamplerWithoutReplacement sampled output size should not beyond the source size.",
            SOURCE_SIZE,
            getSize(sampler.sample(source.iterator())));
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerDuplicateElements() {
        verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(100));
        verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(1000));
        verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(5000));
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerWithoutReplacement() {
        verifyReservoirSamplerWithoutReplacement(100, false);
        verifyReservoirSamplerWithoutReplacement(500, false);
        verifyReservoirSamplerWithoutReplacement(1000, false);
        verifyReservoirSamplerWithoutReplacement(5000, false);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerWithReplacement() {
        verifyReservoirSamplerWithReplacement(100, false);
        verifyReservoirSamplerWithReplacement(500, false);
        verifyReservoirSamplerWithReplacement(1000, false);
        verifyReservoirSamplerWithReplacement(5000, false);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerWithMultiSourcePartitions1() {
        initSourcePartition();
        verifyReservoirSamplerWithoutReplacement(100, true);
        verifyReservoirSamplerWithoutReplacement(500, true);
        verifyReservoirSamplerWithoutReplacement(1000, true);
        verifyReservoirSamplerWithoutReplacement(5000, true);
    }

    @Test
    @RetryOnFailure(times = 3)
    public void testReservoirSamplerWithMultiSourcePartitions2() {
        initSourcePartition();
        verifyReservoirSamplerWithReplacement(100, true);
        verifyReservoirSamplerWithReplacement(500, true);
        verifyReservoirSamplerWithReplacement(1000, true);
        verifyReservoirSamplerWithReplacement(5000, true);
    }

    /**
     * Asserts that a reservoir sampler configured for {@code numSamples} emits exactly that many
     * elements from {@link #source}.
     */
    private void verifySamplerFixedSampleSize(int numSamples, boolean withReplacement) {
        RandomSampler<Double> sampler;
        if (withReplacement) {
            sampler = new ReservoirSamplerWithReplacement<Double>(numSamples);
        } else {
            sampler = new ReservoirSamplerWithoutReplacement<Double>(numSamples);
        }
        Assert.assertEquals(numSamples, getSize(sampler.sample(source.iterator())));
    }

    /**
     * Samples {@link #source} {@link #FRACTION_SAMPLING_ROUNDS} times and asserts that the overall
     * observed fraction deviates from {@code fraction} by less than
     * {@link #FRACTION_RELATIVE_TOLERANCE} (relative error).
     *
     * @param fraction        the configured sample fraction
     * @param withReplacement {@code true} for {@link PoissonSampler}, {@code false} for {@link BernoulliSampler}
     */
    private void verifySamplerFraction(double fraction, boolean withReplacement) {
        // Typed as the common interface: the two branches create unrelated sampler classes.
        RandomSampler<Double> sampler;
        if (withReplacement) {
            sampler = new PoissonSampler<Double>(fraction);
        } else {
            sampler = new BernoulliSampler<Double>(fraction);
        }

        // Summing several rounds smooths out the variance of a single sampling pass.
        int totalSampled = 0;
        for (int round = 0; round < FRACTION_SAMPLING_ROUNDS; round++) {
            totalSampled += getSize(sampler.sample(source.iterator()));
        }
        double resultFraction = totalSampled / ((double) SOURCE_SIZE * FRACTION_SAMPLING_ROUNDS);
        Assert.assertTrue(
            String.format("expected fraction: %f, result fraction: %f", fraction, resultFraction),
            Math.abs((resultFraction - fraction) / fraction) < FRACTION_RELATIVE_TOLERANCE);
    }

    /** Asserts that one sampling pass over {@link #source} contains no repeated element. */
    private void verifyRandomSamplerDuplicateElements(RandomSampler<Double> sampler) {
        List<Double> sampled = drainInto(sampler.sample(source.iterator()), new ArrayList<Double>());
        Assert.assertEquals(
            "There should not have duplicate element for sampler without replacement.",
            sampled.size(),
            Sets.newHashSet(sampled).size());
    }

    /** Consumes {@code it} and returns the number of elements it produced. */
    private int getSize(Iterator<?> it) {
        int count = 0;
        while (it.hasNext()) {
            it.next();
            count++;
        }
        return count;
    }

    /** Runs the positive and the negative KS check for a {@link BernoulliSampler} with {@code fraction}. */
    private void verifyBernoulliSampler(double fraction) {
        RandomSampler<Double> sampler = new BernoulliSampler<Double>(fraction);
        verifyRandomSamplerWithFraction(fraction, sampler, true);
        verifyRandomSamplerWithFraction(fraction, sampler, false);
    }

    /** Runs the positive and the negative KS check for a {@link PoissonSampler} with {@code fraction}. */
    private void verifyPoissonSampler(double fraction) {
        RandomSampler<Double> sampler = new PoissonSampler<Double>(fraction);
        verifyRandomSamplerWithFraction(fraction, sampler, true);
        verifyRandomSamplerWithFraction(fraction, sampler, false);
    }

    /**
     * Runs the positive and the negative KS check for a {@link ReservoirSamplerWithReplacement}.
     *
     * @param sampleOnPartitions whether to sample through {@link #sourcePartitions} (two-phase) instead of {@link #source}
     */
    private void verifyReservoirSamplerWithReplacement(int numSamples, boolean sampleOnPartitions) {
        RandomSampler<Double> sampler = new ReservoirSamplerWithReplacement<Double>(numSamples);
        verifyRandomSamplerWithSampleSize(numSamples, sampler, true, sampleOnPartitions);
        verifyRandomSamplerWithSampleSize(numSamples, sampler, false, sampleOnPartitions);
    }

    /**
     * Runs the positive and the negative KS check for a {@link ReservoirSamplerWithoutReplacement}.
     *
     * @param sampleOnPartitions whether to sample through {@link #sourcePartitions} (two-phase) instead of {@link #source}
     */
    private void verifyReservoirSamplerWithoutReplacement(int numSamples, boolean sampleOnPartitions) {
        RandomSampler<Double> sampler = new ReservoirSamplerWithoutReplacement<Double>(numSamples);
        verifyRandomSamplerWithSampleSize(numSamples, sampler, true, sampleOnPartitions);
        verifyRandomSamplerWithSampleSize(numSamples, sampler, false, sampleOnPartitions);
    }

    /**
     * KS-checks a fraction-based sampler against the evenly spaced reference ({@code expectSuccess})
     * or the skewed reference ({@code !expectSuccess}).
     */
    private void verifyRandomSamplerWithFraction(double fraction, RandomSampler<Double> sampler, boolean expectSuccess) {
        double[] reference = expectSuccess ? getDefaultSampler(fraction) : getWrongSampler(fraction);
        verifyKSTest(sampler, reference, expectSuccess);
    }

    /**
     * KS-checks a size-based sampler against the evenly spaced reference ({@code expectSuccess})
     * or the skewed reference ({@code !expectSuccess}).
     */
    private void verifyRandomSamplerWithSampleSize(
            int sampleSize, RandomSampler<Double> sampler, boolean expectSuccess, boolean sampleOnPartitions) {
        double[] reference = expectSuccess ? getDefaultSampler(sampleSize) : getWrongSampler(sampleSize);
        verifyKSTest(sampler, reference, expectSuccess, sampleOnPartitions);
    }

    /** KS check that samples directly from {@link #source}. */
    private void verifyKSTest(RandomSampler<Double> sampler, double[] reference, boolean expectSuccess) {
        verifyKSTest(sampler, reference, expectSuccess, false);
    }

    /**
     * Samples once, computes the two-sample KS statistic against {@code reference} and compares it
     * with the critical value from {@link #getDValue}.
     *
     * @param expectSuccess if {@code true}, asserts the statistic is within the critical value (same
     *                      distribution not rejected); otherwise asserts it exceeds it (rejected)
     */
    private void verifyKSTest(
            RandomSampler<Double> sampler, double[] reference, boolean expectSuccess, boolean sampleOnPartitions) {
        double[] sampled = getSampledOutput(sampler, sampleOnPartitions);
        double statistic = ksTest.kolmogorovSmirnovStatistic(sampled, reference);
        double criticalValue = getDValue(sampled.length, reference.length);
        String message = String.format(
            "KS test result with D statistic(%f), critical value(%f)", statistic, criticalValue);
        if (expectSuccess) {
            Assert.assertTrue(message, statistic <= criticalValue);
        } else {
            Assert.assertTrue(message, statistic > criticalValue);
        }
    }

    /**
     * Draws one sample and returns it as an ascending {@code double[]}.
     *
     * @param sampleOnPartitions if {@code true}, {@code sampler} must be a {@link DistributedRandomSampler}:
     *                           each of {@link #sourcePartitions} is sampled with {@code sampleInPartition}
     *                           and the combined intermediate results go through {@code sampleInCoordinator};
     *                           otherwise {@link #source} is sampled directly
     */
    private double[] getSampledOutput(RandomSampler<Double> sampler, boolean sampleOnPartitions) {
        Iterator<Double> sampled;
        if (sampleOnPartitions) {
            DistributedRandomSampler<Double> distributedSampler = (DistributedRandomSampler<Double>) sampler;
            List<IntermediateSampleData<Double>> partitionSamples = new ArrayList<IntermediateSampleData<Double>>();
            for (List<Double> partition : sourcePartitions) {
                drainInto(distributedSampler.sampleInPartition(partition.iterator()), partitionSamples);
            }
            sampled = distributedSampler.sampleInCoordinator(partitionSamples.iterator());
        } else {
            sampled = sampler.sample(source.iterator());
        }
        return transferFromListToArrayWithOrder(drainInto(sampled, new ArrayList<Double>()));
    }

    /** Appends every remaining element of {@code iterator} to {@code target} and returns {@code target}. */
    private static <T> List<T> drainInto(Iterator<? extends T> iterator, List<T> target) {
        while (iterator.hasNext()) {
            target.add(iterator.next());
        }
        return target;
    }

    /** Sorts {@code list} in place (ascending) and copies it into a primitive array. */
    private double[] transferFromListToArrayWithOrder(List<Double> list) {
        Collections.sort(list);
        double[] result = new double[list.size()];
        for (int index = 0; index < result.length; index++) {
            result[index] = list.get(index);
        }
        return result;
    }

    /**
     * Builds the expected sample for a fraction: {@code (int) (SOURCE_SIZE * fraction)} values spread
     * evenly over the source range with step {@code 1 / fraction}.
     */
    private double[] getDefaultSampler(double fraction) {
        Preconditions.checkArgument(fraction > 0.0d, "Sample fraction should be positive.");
        int size = (int) (SOURCE_SIZE * fraction);
        double step = 1.0d / fraction;
        double[] reference = new double[size];
        for (int k = 0; k < size; k++) {
            reference[k] = Math.round(step * k);
        }
        return reference;
    }

    /** Builds the expected sample for a fixed size: {@code sampleSize} values spread evenly over the source range. */
    private double[] getDefaultSampler(int sampleSize) {
        Preconditions.checkArgument(sampleSize > 0, "Sample size should be positive.");
        double step = (double) SOURCE_SIZE / sampleSize;
        double[] reference = new double[sampleSize];
        for (int k = 0; k < sampleSize; k++) {
            reference[k] = Math.round(step * k);
        }
        return reference;
    }

    /**
     * Builds a deliberately skewed reference of {@code (int) (SOURCE_SIZE * fraction)} values
     * ({@code k % 5000}), all in the lower half of the source range, which a correct sampler should not match.
     */
    private double[] getWrongSampler(double fraction) {
        Preconditions.checkArgument(fraction > 0.0d, "Sample fraction should be positive.");
        return skewedReference((int) (SOURCE_SIZE * fraction));
    }

    /**
     * Builds a deliberately skewed reference of {@code sampleSize} values ({@code k % 5000}),
     * which a correct sampler should not match.
     */
    private double[] getWrongSampler(int sampleSize) {
        Preconditions.checkArgument(sampleSize > 0, "Sample size should be positive.");
        return skewedReference(sampleSize);
    }

    /** Returns {@code size} values {@code k % 5000} for {@code k = 0 .. size - 1}. */
    private static double[] skewedReference(int size) {
        double[] reference = new double[size];
        for (int k = 0; k < size; k++) {
            reference[k] = k % 5000;
        }
        return reference;
    }

    /**
     * Returns the two-sample KS critical value {@code c * sqrt((n + m) / (n * m))} for sample sizes
     * {@code n} and {@code m}, with {@code c = KS_CRITICAL_COEFFICIENT}.
     *
     * @throws IllegalArgumentException if either size is not positive
     */
    private double getDValue(int sampleSize1, int sampleSize2) {
        Preconditions.checkArgument(sampleSize1 > 0, "input sample size should be positive.");
        Preconditions.checkArgument(sampleSize2 > 0, "input sample size should be positive.");
        double n = sampleSize1;
        double m = sampleSize2;
        return KS_CRITICAL_COEFFICIENT * Math.sqrt((n + m) / (n * m));
    }
}
