package org.apache.beam.sdk.transforms;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.SystemNanoTimeSleeper;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayDataMatchers;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/sdk/transforms/IntraBundleParallelizationTest.class */
public class IntraBundleParallelizationTest {
    private static final int PARALLELISM_FACTOR = 16;
    private static final AtomicInteger numSuccesses = new AtomicInteger();
    private static final AtomicInteger numProcessed = new AtomicInteger();
    private static final AtomicInteger numFailures = new AtomicInteger();
    private static int concurrentElements = 0;
    private static int maxDownstreamConcurrency = 0;
    private static final AtomicInteger maxFnConcurrency = new AtomicInteger();
    private static final AtomicInteger currentFnConcurrency = new AtomicInteger();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/beam/sdk/transforms/IntraBundleParallelizationTest$ConcurrencyMeasuringFn.class */
    public static class ConcurrencyMeasuringFn<T> extends DoFn<T, T> {
        private ConcurrencyMeasuringFn() {
        }

        public void processElement(DoFn<T, T>.ProcessContext processContext) {
            synchronized (ConcurrencyMeasuringFn.class) {
                IntraBundleParallelizationTest.access$508();
                if (IntraBundleParallelizationTest.concurrentElements > IntraBundleParallelizationTest.maxDownstreamConcurrency) {
                    int unused = IntraBundleParallelizationTest.maxDownstreamConcurrency = IntraBundleParallelizationTest.concurrentElements;
                }
            }
            processContext.output(processContext.element());
            synchronized (ConcurrencyMeasuringFn.class) {
                IntraBundleParallelizationTest.access$510();
            }
        }
    }

    /* loaded from: input_file:org/apache/beam/sdk/transforms/IntraBundleParallelizationTest$DelayFn.class */
    private static class DelayFn<T> extends DoFn<T, T> {
        public static final long DELAY_MS = 25;

        private DelayFn() {
        }

        public void processElement(DoFn<T, T>.ProcessContext processContext) {
            IntraBundleParallelizationTest.startConcurrentCall();
            try {
                SystemNanoTimeSleeper.sleepMillis(25L);
                processContext.output(processContext.element());
                IntraBundleParallelizationTest.finishConcurrentCall();
            } catch (InterruptedException e) {
                e.printStackTrace();
                throw new RuntimeException("Interrupted");
            }
        }
    }

    /* loaded from: input_file:org/apache/beam/sdk/transforms/IntraBundleParallelizationTest$ExceptionThrowingFn.class */
    private static class ExceptionThrowingFn<T> extends DoFn<T, T> {
        private ExceptionThrowingFn(int i) {
            IntraBundleParallelizationTest.numSuccesses.set(i);
        }

        public void processElement(DoFn<T, T>.ProcessContext processContext) {
            IntraBundleParallelizationTest.startConcurrentCall();
            try {
                IntraBundleParallelizationTest.numProcessed.incrementAndGet();
                if (IntraBundleParallelizationTest.numSuccesses.decrementAndGet() >= 0) {
                    processContext.output(processContext.element());
                } else {
                    IntraBundleParallelizationTest.numFailures.incrementAndGet();
                    throw new RuntimeException("Expected failure");
                }
            } finally {
                IntraBundleParallelizationTest.finishConcurrentCall();
            }
        }
    }

    @Before
    public void setUp() {
        numSuccesses.set(0);
        numProcessed.set(0);
        numFailures.set(0);
        concurrentElements = 0;
        maxDownstreamConcurrency = 0;
        maxFnConcurrency.set(0);
        currentFnConcurrency.set(0);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static void startConcurrentCall() {
        int i;
        int incrementAndGet = currentFnConcurrency.incrementAndGet();
        do {
            i = maxFnConcurrency.get();
            if (i >= incrementAndGet) {
                return;
            }
        } while (!maxFnConcurrency.compareAndSet(i, incrementAndGet));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static void finishConcurrentCall() {
        currentFnConcurrency.decrementAndGet();
    }

    @Test
    @Category({NeedsRunner.class})
    public void testParallelization() {
        int i = Integer.MIN_VALUE;
        for (int i2 = 0; i2 < 5; i2++) {
            i = Math.max(i, run(32, PARALLELISM_FACTOR, new DelayFn()));
        }
        Assert.assertThat(Integer.valueOf(i), Matchers.greaterThanOrEqualTo(2));
        Assert.assertThat(Integer.valueOf(i), Matchers.lessThanOrEqualTo(Integer.valueOf(PARALLELISM_FACTOR)));
    }

    @Test(timeout = 5000)
    @Category({NeedsRunner.class})
    public void testExceptionHandling() {
        try {
            run(100, PARALLELISM_FACTOR, new ExceptionThrowingFn(10));
            Assert.fail("Expected exception to propagate");
        } catch (RuntimeException e) {
            Assert.assertThat(e.getMessage(), Matchers.containsString("Expected failure"));
        }
        Assert.assertThat(Integer.valueOf(numProcessed.get()), Matchers.is(Matchers.both(Matchers.greaterThanOrEqualTo(10)).and(Matchers.lessThan(100))));
        Assert.assertThat(Integer.valueOf(numFailures.get()), Matchers.is(Matchers.both(Matchers.greaterThanOrEqualTo(1)).and(Matchers.lessThanOrEqualTo(Integer.valueOf(PARALLELISM_FACTOR)))));
    }

    @Test(timeout = 5000)
    @Category({NeedsRunner.class})
    public void testExceptionHandlingOnLastElement() {
        try {
            run(10, PARALLELISM_FACTOR, new ExceptionThrowingFn(9));
            Assert.fail("Expected exception to propagate");
        } catch (RuntimeException e) {
            Assert.assertThat(e.getMessage(), Matchers.containsString("Expected failure"));
        }
        Assert.assertEquals(10L, numProcessed.get());
        Assert.assertEquals(1L, numFailures.get());
    }

    @Test
    public void testIntraBundleParallelizationGetName() {
        Assert.assertEquals("IntraBundleParallelization", IntraBundleParallelization.of(new DelayFn()).withMaxParallelism(1).getName());
    }

    @Test
    public void testDisplayData() {
        DoFn<String, String> doFn = new DoFn<String, String>() { // from class: org.apache.beam.sdk.transforms.IntraBundleParallelizationTest.1
            public void processElement(DoFn<String, String>.ProcessContext processContext) throws Exception {
            }

            public void populateDisplayData(DisplayData.Builder builder) {
                builder.add(DisplayData.item("foo", "bar"));
            }
        };
        DisplayData from = DisplayData.from(IntraBundleParallelization.withMaxParallelism(1234).of(doFn));
        Assert.assertThat(from, DisplayDataMatchers.includesDisplayDataFrom(doFn));
        Assert.assertThat(from, DisplayDataMatchers.hasDisplayItem("fn", doFn.getClass()));
        Assert.assertThat(from, DisplayDataMatchers.hasDisplayItem("maxParallelism", 1234L));
    }

    private int run(int i, int i2, DoFn<Integer, Integer> doFn) {
        TestPipeline create = TestPipeline.create();
        ArrayList arrayList = new ArrayList(i);
        for (int i3 = 0; i3 < i; i3++) {
            arrayList.add(Integer.valueOf(i3));
        }
        create.apply(Create.of(arrayList)).apply(IntraBundleParallelization.of(doFn).withMaxParallelism(i2)).apply(ParDo.of(new ConcurrencyMeasuringFn()));
        create.run();
        Assert.assertEquals(0L, currentFnConcurrency.get());
        Assert.assertEquals(1L, maxDownstreamConcurrency);
        return maxFnConcurrency.get();
    }

    static /* synthetic */ int access$508() {
        int i = concurrentElements;
        concurrentElements = i + 1;
        return i;
    }

    static /* synthetic */ int access$510() {
        int i = concurrentElements;
        concurrentElements = i - 1;
        return i;
    }
}
