package org.apache.beam.sdk.transforms;

import com.google.common.collect.testing.SampleElements;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesParDoLifecycle;
import org.apache.beam.sdk.testing.UsesStatefulParDo;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/sdk/transforms/ParDoLifecycleTest.class */
public class ParDoLifecycleTest implements Serializable {

    /**
     * Pipeline under test, managed as a JUnit {@code @Rule}. Marked {@code transient} because
     * this test class implements {@link Serializable}.
     */
    @Rule
    public final transient TestPipeline p = TestPipeline.create();

    /**
     * A {@link DoFn} that emits nothing and instead asserts, in each lifecycle callback, that the
     * runner has invoked {@code @Setup}, {@code @StartBundle}, {@code @ProcessElement},
     * {@code @FinishBundle} and {@code @Teardown} in a legal order.
     *
     * @param <T> element type; elements are inspected only for the call sequence, never output
     */
    private static class CallSequenceEnforcingFn<T> extends DoFn<T, T> {
        /** Whether {@code @Setup} has run on this instance. */
        private boolean setupCalled = false;
        /** Number of {@code @StartBundle} invocations so far. */
        private int startBundleCalls = 0;
        /** Number of {@code @FinishBundle} invocations so far. */
        private int finishBundleCalls = 0;
        /** Whether {@code @Teardown} has run on this instance. */
        private boolean teardownCalled = false;

        private CallSequenceEnforcingFn() {
            // All state starts at its initial value; see the field initializers.
        }

        /** Requires that setup is the very first lifecycle call and happens only once. */
        @DoFn.Setup
        public void before() {
            MatcherAssert.assertThat("setup should not be called twice", this.setupCalled, Matchers.is(false));
            MatcherAssert.assertThat("setup should be called before startBundle", this.startBundleCalls, Matchers.equalTo(0));
            MatcherAssert.assertThat("setup should be called before finishBundle", this.finishBundleCalls, Matchers.equalTo(0));
            MatcherAssert.assertThat("setup should be called before teardown", this.teardownCalled, Matchers.is(false));
            this.setupCalled = true;
        }

        /** Requires setup to have run and every previously started bundle to have been finished. */
        @DoFn.StartBundle
        public void begin() {
            MatcherAssert.assertThat("setup should have been called", this.setupCalled, Matchers.is(true));
            MatcherAssert.assertThat("Equal number of startBundle and finishBundle calls in startBundle", this.startBundleCalls, Matchers.equalTo(this.finishBundleCalls));
            MatcherAssert.assertThat("teardown should not have been called", this.teardownCalled, Matchers.is(false));
            this.startBundleCalls++;
        }

        /** Requires exactly one bundle to be open (started but not yet finished). */
        @DoFn.ProcessElement
        public void process(DoFn<T, T>.ProcessContext processContext) throws Exception {
            MatcherAssert.assertThat("startBundle should have been called", this.startBundleCalls, Matchers.greaterThan(0));
            MatcherAssert.assertThat("there should be one startBundle call with no call to finishBundle", this.startBundleCalls, Matchers.equalTo(this.finishBundleCalls + 1));
            MatcherAssert.assertThat("teardown should not have been called", this.teardownCalled, Matchers.is(false));
        }

        /** Requires exactly one open bundle, which this call then closes. */
        @DoFn.FinishBundle
        public void end() {
            MatcherAssert.assertThat("startBundle should have been called", this.startBundleCalls, Matchers.greaterThan(0));
            MatcherAssert.assertThat("there should be one bundle that has been started but not finished", this.startBundleCalls, Matchers.equalTo(this.finishBundleCalls + 1));
            MatcherAssert.assertThat("teardown should not have been called", this.teardownCalled, Matchers.is(false));
            this.finishBundleCalls++;
        }

        /** Requires setup to have run, no bundle to be open, and teardown to happen only once. */
        @DoFn.Teardown
        public void after() {
            MatcherAssert.assertThat("setup should have been called before teardown", this.setupCalled, Matchers.is(true));
            MatcherAssert.assertThat("every started bundle should be finished before teardown", this.startBundleCalls, Matchers.equalTo(this.finishBundleCalls));
            MatcherAssert.assertThat("teardown should not be called twice", this.teardownCalled, Matchers.is(false));
            this.teardownCalled = true;
        }
    }

    /**
     * Stateful variant of {@link CallSequenceEnforcingFn}. It declares a {@link ValueState} cell,
     * which is never read or written, so that the fn is a stateful {@link DoFn} over keyed input.
     * All lifecycle-order checking is inherited.
     */
    private static class CallSequenceEnforcingStatefulFn<K, V> extends CallSequenceEnforcingFn<KV<K, V>> {
        private static final String STATE_ID = "foo";

        @DoFn.StateId(STATE_ID)
        private final StateSpec<ValueState<String>> valueSpec = StateSpecs.value();

        private CallSequenceEnforcingStatefulFn() {
            // Nothing beyond the state declaration above; the superclass does the checking.
        }
    }

    /** Lifecycle callbacks of {@link ExceptionThrowingFn} recorded by {@link DelayedCallStateTracker}. */
    public enum CallState {
        /** {@code @Setup} was invoked. */
        SETUP,
        /** {@code @StartBundle} was invoked. */
        START_BUNDLE,
        /** {@code @ProcessElement} was invoked. */
        PROCESS_ELEMENT,
        /** {@code @FinishBundle} was invoked. */
        FINISH_BUNDLE,
        /** {@code @Teardown} was invoked. */
        TEARDOWN;
    }

    /**
     * Records the lifecycle calls made on one {@link ExceptionThrowingFn} instance.
     *
     * <p>It keeps the most recent {@link CallState} in an {@link AtomicReference}, and the distinct
     * states visited, in first-visit order, in a synchronized list. A latch is released once
     * {@link CallState#TEARDOWN} is reached.
     */
    public static class DelayedCallStateTracker {
        /** Counted down when {@link CallState#TEARDOWN} is recorded. */
        private final CountDownLatch latch;
        /** The most recently recorded state. */
        private final AtomicReference<CallState> callState;
        /** Distinct states recorded so far, in the order each was first seen. */
        private final List<CallState> callStateVisited;

        private DelayedCallStateTracker(CallState callState) {
            this.callStateVisited = Collections.synchronizedList(new ArrayList<>());
            this.latch = new CountDownLatch(1);
            this.callState = new AtomicReference<>(callState);
            this.callStateVisited.add(callState);
        }

        /**
         * Records a transition to {@code callState}.
         *
         * @param callState the lifecycle callback that just ran
         * @return this tracker
         * @throws AssertionError if any state other than TEARDOWN follows TEARDOWN
         */
        DelayedCallStateTracker update(CallState callState) {
            // Capture the previous state: after getAndSet, this.callState already holds the new one.
            CallState previousState = this.callState.getAndSet(callState);
            if (previousState == CallState.TEARDOWN && callState != CallState.TEARDOWN) {
                Assert.fail("illegal state change from " + previousState + " to " + callState);
            }
            if (CallState.TEARDOWN == callState) {
                this.latch.countDown();
            }
            // contains+add must be atomic so concurrent updates cannot record a state twice.
            synchronized (this.callStateVisited) {
                if (!this.callStateVisited.contains(callState)) {
                    this.callStateVisited.add(callState);
                }
            }
            return this;
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this).add("latch", this.latch).add("callState", this.callState).add("callStateVisited", this.callStateVisited).toString();
        }

        /** Returns the most recently recorded state. */
        CallState callState() {
            return this.callState.get();
        }

        /**
         * Waits up to one second for TEARDOWN to be recorded, then returns the current state.
         * On timeout or interruption this returns whatever state is current, and the thread's
         * interrupt flag is restored.
         */
        CallState finalState() {
            try {
                this.latch.await(1L, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            return callState();
        }
    }

    /**
     * A {@link DoFn} that throws an {@link Exception} from the lifecycle method selected by its
     * {@link MethodForException}, at most once per instance. It also records, per instance, which
     * lifecycle methods were invoked and in what order.
     *
     * <p>The bookkeeping is held in static fields so that the test method can inspect it after
     * the pipeline has run.
     */
    public static class ExceptionThrowingFn<T> extends DoFn<T, T> {
        /** Lifecycle tracker per instance, keyed by {@link #id()}. */
        static Map<Integer, DelayedCallStateTracker> callStateMap = new ConcurrentHashMap<>();
        /** Set once any instance has thrown its configured exception. */
        static AtomicBoolean exceptionWasThrown = new AtomicBoolean(false);
        /** Number of instances that have run {@code @Setup} but not yet {@code @Teardown}. */
        static AtomicInteger noOfInstancesToTearDown = new AtomicInteger(0);

        /** The lifecycle method that should throw. */
        private final MethodForException toThrow;
        /** Whether this instance has already thrown, so that it throws at most once. */
        private boolean thrown;

        private ExceptionThrowingFn(MethodForException methodForException) {
            this.toThrow = methodForException;
        }

        @DoFn.Setup
        public void before() throws Exception {
            MatcherAssert.assertThat("lifecycle methods should not have been called", callStateMap.get(id()), Matchers.nullValue());
            initCallState();
            noOfInstancesToTearDown.incrementAndGet();
            throwIfNecessary(MethodForException.SETUP);
        }

        @DoFn.StartBundle
        public void preBundle() throws Exception {
            MatcherAssert.assertThat("lifecycle method should have been called before start bundle", getCallState(), Matchers.anyOf(Matchers.equalTo(CallState.SETUP), Matchers.equalTo(CallState.FINISH_BUNDLE)));
            updateCallState(CallState.START_BUNDLE);
            throwIfNecessary(MethodForException.START_BUNDLE);
        }

        @DoFn.ProcessElement
        public void perElement(DoFn<T, T>.ProcessContext processContext) throws Exception {
            MatcherAssert.assertThat("lifecycle method should have been called before processing bundle", getCallState(), Matchers.anyOf(Matchers.equalTo(CallState.START_BUNDLE), Matchers.equalTo(CallState.PROCESS_ELEMENT)));
            updateCallState(CallState.PROCESS_ELEMENT);
            throwIfNecessary(MethodForException.PROCESS_ELEMENT);
        }

        @DoFn.FinishBundle
        public void postBundle() throws Exception {
            MatcherAssert.assertThat("processing bundle or start bundle should have been called before finish bundle", getCallState(), Matchers.anyOf(Matchers.equalTo(CallState.PROCESS_ELEMENT), Matchers.equalTo(CallState.START_BUNDLE)));
            updateCallState(CallState.FINISH_BUNDLE);
            throwIfNecessary(MethodForException.FINISH_BUNDLE);
        }

        /** Throws once if {@code methodForException} is the configured method; otherwise no-op. */
        private void throwIfNecessary(MethodForException methodForException) throws Exception {
            if (this.toThrow != methodForException || this.thrown) {
                return;
            }
            this.thrown = true;
            exceptionWasThrown.set(true);
            throw new Exception("Hasn't yet thrown");
        }

        @DoFn.Teardown
        public void after() {
            // When no set-up instances remain, the test must have exercised a failure.
            if (noOfInstancesToTearDown.decrementAndGet() == 0 && !exceptionWasThrown.get()) {
                Assert.fail("Expected to have a processing method throw an exception");
            }
            MatcherAssert.assertThat("some lifecycle method should have been called", callStateMap.get(id()), Matchers.notNullValue());
            updateCallState(CallState.TEARDOWN);
        }

        /** Registers a fresh tracker for this instance, failing if one already exists. */
        private void initCallState() {
            if (callStateMap.put(id(), new DelayedCallStateTracker(CallState.SETUP)) != null) {
                Assert.fail(CallState.SETUP + " method called multiple times");
            }
        }

        /**
         * Key of this instance in {@link #callStateMap}.
         * NOTE(review): identity hash codes are not guaranteed unique; a collision would make two
         * instances share a tracker.
         */
        private int id() {
            return System.identityHashCode(this);
        }

        /** Returns this instance's tracker, failing with a clear message instead of an NPE. */
        private DelayedCallStateTracker tracker() {
            DelayedCallStateTracker tracker = callStateMap.get(id());
            if (tracker == null) {
                Assert.fail("no lifecycle state recorded for instance " + id() + "; was " + CallState.SETUP + " called?");
            }
            return tracker;
        }

        private void updateCallState(CallState callState) {
            tracker().update(callState);
        }

        private CallState getCallState() {
            return tracker().callState();
        }
    }

    /**
     * Stateful variant of {@link ExceptionThrowingFn}. It declares a {@link ValueState} cell,
     * which is never read or written, so that the fn is a stateful {@link DoFn} over keyed input.
     * Throwing and tracking behavior is inherited.
     */
    private static class ExceptionThrowingStatefulFn<K, V> extends ExceptionThrowingFn<KV<K, V>> {
        private static final String STATE_ID = "foo";

        @DoFn.StateId(STATE_ID)
        private final StateSpec<ValueState<String>> valueSpec = StateSpecs.value();

        private ExceptionThrowingStatefulFn(MethodForException methodForException) {
            super(methodForException);
        }
    }

    /** Selects which lifecycle method of {@link ExceptionThrowingFn} throws an exception. */
    public enum MethodForException {
        /** Throw from {@code @Setup}. */
        SETUP,
        /** Throw from {@code @StartBundle}. */
        START_BUNDLE,
        /** Throw from {@code @ProcessElement}. */
        PROCESS_ELEMENT,
        /** Throw from {@code @FinishBundle}. */
        FINISH_BUNDLE;
    }

    /**
     * Checks lifecycle ordering for a stateless single-output {@link ParDo} applied to the
     * flattened union of two {@link Create} inputs.
     */
    @Test
    @Category({ValidatesRunner.class, UsesParDoLifecycle.class})
    public void testFnCallSequence() {
        PCollection<Integer> impolite = this.p.apply("Impolite", Create.of(1, 2, 4));
        PCollection<Integer> polite = this.p.apply("Polite", Create.of(3, 5, 6, 7));
        PCollectionList.of(impolite)
                .and(polite)
                .apply(Flatten.pCollections())
                .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>()));
        this.p.run();
    }

    /**
     * Checks lifecycle ordering for a stateless multi-output {@link ParDo} applied to the
     * flattened union of two {@link Create} inputs.
     */
    @Test
    @Category({ValidatesRunner.class, UsesParDoLifecycle.class})
    public void testFnCallSequenceMulti() {
        PCollection<Integer> impolite = this.p.apply("Impolite", Create.of(1, 2, 4));
        PCollection<Integer> polite = this.p.apply("Polite", Create.of(3, 5, 6, 7));
        PCollectionList.of(impolite)
                .and(polite)
                .apply(Flatten.pCollections())
                .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>())
                        .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
        this.p.run();
    }

    /**
     * Checks lifecycle ordering for a stateful multi-output {@link ParDo} applied to the
     * flattened union of two keyed {@link Create} inputs.
     */
    @Test
    @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesParDoLifecycle.class})
    public void testFnCallSequenceStateful() {
        PCollection<KV<String, Integer>> impolite =
                this.p.apply("Impolite", Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 4)));
        PCollection<KV<String, Integer>> polite =
                this.p.apply("Polite", Create.of(KV.of("b", 3), KV.of("a", 5), KV.of("c", 6), KV.of("c", 7)));
        PCollectionList.of(impolite)
                .and(polite)
                .apply(Flatten.pCollections())
                .apply(ParDo.of(new CallSequenceEnforcingStatefulFn<String, Integer>())
                        .withOutputTags(new TupleTag<KV<String, Integer>>() {}, TupleTagList.empty()));
        this.p.run();
    }

    /** Keyed input shared by the stateful teardown-after-exception tests. */
    private PCollection<KV<String, Integer>> keyedExceptionTestInput() {
        return this.p.apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3)));
    }

    /**
     * Runs the pipeline and requires it to fail with an {@link Exception}. Once it has failed,
     * checks via {@link #validate} that every instance was torn down and that some instance
     * visited exactly {@code expectedCallStates}.
     */
    private void runExpectingFailureThenValidate(CallState... expectedCallStates) {
        try {
            this.p.run();
        } catch (Exception expected) {
            validate(expectedCallStates);
            return;
        }
        Assert.fail("Pipeline should have failed with an exception");
    }

    @Test
    @Category({ValidatesRunner.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInSetup() {
        this.p.apply(Create.of(1, 2, 3))
                .apply(ParDo.of(new ExceptionThrowingFn<Integer>(MethodForException.SETUP)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInStartBundle() {
        this.p.apply(Create.of(1, 2, 3))
                .apply(ParDo.of(new ExceptionThrowingFn<Integer>(MethodForException.START_BUNDLE)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.START_BUNDLE, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInProcessElement() {
        this.p.apply(Create.of(1, 2, 3))
                .apply(ParDo.of(new ExceptionThrowingFn<Integer>(MethodForException.PROCESS_ELEMENT)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.START_BUNDLE, CallState.PROCESS_ELEMENT, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInFinishBundle() {
        this.p.apply(Create.of(1, 2, 3))
                .apply(ParDo.of(new ExceptionThrowingFn<Integer>(MethodForException.FINISH_BUNDLE)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.START_BUNDLE, CallState.PROCESS_ELEMENT, CallState.FINISH_BUNDLE, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInSetupStateful() {
        keyedExceptionTestInput()
                .apply(ParDo.of(new ExceptionThrowingStatefulFn<String, Integer>(MethodForException.SETUP)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInStartBundleStateful() {
        keyedExceptionTestInput()
                .apply(ParDo.of(new ExceptionThrowingStatefulFn<String, Integer>(MethodForException.START_BUNDLE)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.START_BUNDLE, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInProcessElementStateful() {
        keyedExceptionTestInput()
                .apply(ParDo.of(new ExceptionThrowingStatefulFn<String, Integer>(MethodForException.PROCESS_ELEMENT)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.START_BUNDLE, CallState.PROCESS_ELEMENT, CallState.TEARDOWN);
    }

    @Test
    @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesParDoLifecycle.class})
    public void testTeardownCalledAfterExceptionInFinishBundleStateful() {
        keyedExceptionTestInput()
                .apply(ParDo.of(new ExceptionThrowingStatefulFn<String, Integer>(MethodForException.FINISH_BUNDLE)));
        runExpectingFailureThenValidate(CallState.SETUP, CallState.START_BUNDLE, CallState.PROCESS_ELEMENT, CallState.FINISH_BUNDLE, CallState.TEARDOWN);
    }

    /**
     * Asserts that at least one {@link ExceptionThrowingFn} instance was recorded, and that every
     * recorded instance reached {@link CallState#TEARDOWN} (waiting briefly for each). It also
     * asserts that at least one instance visited exactly {@code callStateArr}: the distinct
     * states in first-visit order.
     *
     * @param callStateArr the expected sequence of distinct states for some instance
     */
    private void validate(CallState... callStateArr) {
        Map<Integer, DelayedCallStateTracker> trackers = ExceptionThrowingFn.callStateMap;
        Assert.assertFalse("No ExceptionThrowingFn instance was set up", trackers.isEmpty());
        for (DelayedCallStateTracker tracker : trackers.values()) {
            MatcherAssert.assertThat("Function should have been torn down after exception", tracker.finalState(), Matchers.is(CallState.TEARDOWN));
        }
        List<CallState> expected = Arrays.asList(callStateArr);
        String message = "At least one bundle should contain " + expected + ", got " + trackers.values();
        boolean anyMatches = trackers.values().stream().anyMatch(tracker -> tracker.callStateVisited.equals(expected));
        MatcherAssert.assertThat(message, anyMatches);
    }

    /**
     * Resets the static bookkeeping shared by all {@link ExceptionThrowingFn} instances, so that
     * state left by a previous test cannot leak into the next one.
     */
    @Before
    public void setup() {
        ExceptionThrowingFn.callStateMap = new ConcurrentHashMap<>();
        ExceptionThrowingFn.exceptionWasThrown.set(false);
        // Without this reset, instances from an earlier test that were never torn down would keep
        // the counter above zero, so after() would never run its "no exception thrown" check.
        ExceptionThrowingFn.noOfInstancesToTearDown.set(0);
    }
}
