package org.apache.beam.runners.flink.streaming;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContext;
import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
import org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.Struct;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.util.OutputTag;
import org.hamcrest.collection.IsIterableContainingInOrder;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.internal.util.reflection.Whitebox;

/**
 * Tests for {@link ExecutableStageDoFnOperator}.
 *
 * <p>The operator is driven through Flink's {@link OneInputStreamOperatorTestHarness}. The stage
 * context, bundle factory and process bundle descriptor are Mockito mocks wired together in
 * {@link #setUpMocks()}, so no SDK harness is started by these tests.
 */
@RunWith(JUnit4.class)
public class ExecutableStageDoFnOperatorTest {

    @Mock
    private RuntimeContext runtimeContext;

    @Mock
    private DistributedCache distributedCache;

    @Mock
    private FlinkExecutableStageContext stageContext;

    @Mock
    private StageBundleFactory stageBundleFactory;

    @Mock
    private StateRequestHandler stateRequestHandler;

    @Mock
    private ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor;

    @Rule
    public ExpectedException thrown = ExpectedException.none();

    /** Stage payload whose input id is "input" and whose components hold one default PCollection under that id. */
    private final RunnerApi.ExecutableStagePayload stagePayload = RunnerApi.ExecutableStagePayload.newBuilder().setInput("input").setComponents(RunnerApi.Components.newBuilder().putPcollections("input", RunnerApi.PCollection.getDefaultInstance()).build()).build();

    /** Job info handed to every operator built by these tests. */
    private final JobInfo jobInfo = JobInfo.create("job-id", "job-name", "retrieval-token", Struct.getDefaultInstance());

    /**
     * Initializes the {@code @Mock} fields and stubs the default collaborator chain:
     * stage context -> bundle factory -> process bundle descriptor (with no timer specs).
     */
    @Before
    public void setUpMocks() {
        MockitoAnnotations.initMocks(this);
        Mockito.when(this.runtimeContext.getDistributedCache()).thenReturn(this.distributedCache);
        Mockito.when(this.stageContext.getStageBundleFactory((ExecutableStage) Matchers.any())).thenReturn(this.stageBundleFactory);
        Mockito.when(this.processBundleDescriptor.getTimerSpecs()).thenReturn(Collections.emptyMap());
        Mockito.when(this.stageBundleFactory.getProcessBundleDescriptor()).thenReturn(this.processBundleDescriptor);
    }

    /**
     * Stubs {@link RemoteBundle#close()} to throw and expects that exception to appear as the
     * cause of whatever is thrown by the element processing / harness close that follows.
     */
    @Test
    public void sdkErrorsSurfaceOnClose() throws Exception {
        TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
        ExecutableStageDoFnOperator<Integer, Integer> operator = getOperator(mainOutput, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>(mainOutput, VoidCoder.of()));
        OneInputStreamOperatorTestHarness testHarness = new OneInputStreamOperatorTestHarness(operator);
        testHarness.open();

        RemoteBundle bundle = (RemoteBundle) Mockito.mock(RemoteBundle.class);
        Mockito.when(this.stageBundleFactory.getBundle((OutputReceiverFactory) Matchers.any(), (StateRequestHandler) Matchers.any(), (BundleProgressHandler) Matchers.any())).thenReturn(bundle);
        Mockito.when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", (FnDataReceiver) Mockito.mock(FnDataReceiver.class)));

        RuntimeException expected = new RuntimeException(new Exception());
        ((RemoteBundle) Mockito.doThrow(expected).when(bundle)).close();
        // The expectation has to be registered before the calls that are supposed to throw.
        this.thrown.expectCause(org.hamcrest.Matchers.is(expected));

        operator.processElement(new StreamRecord(WindowedValue.valueInGlobalWindow(0)));
        testHarness.close();
    }

    /**
     * Feeds three elements through the harness and verifies that the bundle's "input" receiver
     * accepted exactly those three values and nothing else.
     */
    @Test
    public void expectedInputsAreSent() throws Exception {
        TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
        ExecutableStageDoFnOperator<Integer, Integer> operator = getOperator(mainOutput, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>(mainOutput, VoidCoder.of()));

        RemoteBundle bundle = (RemoteBundle) Mockito.mock(RemoteBundle.class);
        Mockito.when(this.stageBundleFactory.getBundle((OutputReceiverFactory) Matchers.any(), (StateRequestHandler) Matchers.any(), (BundleProgressHandler) Matchers.any())).thenReturn(bundle);
        FnDataReceiver receiver = (FnDataReceiver) Mockito.mock(FnDataReceiver.class);
        Mockito.when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", receiver));

        WindowedValue one = WindowedValue.valueInGlobalWindow(1);
        WindowedValue two = WindowedValue.valueInGlobalWindow(2);
        WindowedValue three = WindowedValue.valueInGlobalWindow(3);

        OneInputStreamOperatorTestHarness testHarness = new OneInputStreamOperatorTestHarness(operator);
        testHarness.open();
        testHarness.processElement(new StreamRecord(one));
        testHarness.processElement(new StreamRecord(two));
        testHarness.processElement(new StreamRecord(three));

        ((FnDataReceiver) Mockito.verify(receiver)).accept(one);
        ((FnDataReceiver) Mockito.verify(receiver)).accept(two);
        ((FnDataReceiver) Mockito.verify(receiver)).accept(three);
        Mockito.verifyNoMoreInteractions(receiver);
        testHarness.close();
    }

    /**
     * Uses a hand-written {@link StageBundleFactory} whose bundle emits one value per output tag
     * the first time it is closed, then checks that each value shows up on the matching harness
     * output: the main output for "main-output" and the respective side outputs for "output-1"
     * and "output-2".
     */
    @Test
    public void outputsAreTaggedCorrectly() throws Exception {
        WindowedValue.ValueOnlyWindowedValueCoder coder = WindowedValue.getValueOnlyCoder(VarIntCoder.of());
        final TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
        final TupleTag additionalOutput1 = new TupleTag("output-1");
        final TupleTag additionalOutput2 = new TupleTag("output-2");
        // Anonymous OutputTag subclasses, one per additional Beam output tag.
        ImmutableMap tagsToOutputTags = ImmutableMap.builder().put(additionalOutput1, new OutputTag<String>(additionalOutput1.getId()) {
        }).put(additionalOutput2, new OutputTag<String>(additionalOutput2.getId()) {
        }).build();
        DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory = new DoFnOperator.MultiOutputOutputManagerFactory<>(mainOutput, tagsToOutputTags, ImmutableMap.builder().put(mainOutput, coder).put(additionalOutput1, coder).put(additionalOutput2, coder).build(), ImmutableMap.builder().put(mainOutput, 0).put(additionalOutput1, 1).put(additionalOutput2, 2).build());

        WindowedValue input = WindowedValue.valueInGlobalWindow(0);
        final WindowedValue three = WindowedValue.valueInGlobalWindow(3);
        final WindowedValue four = WindowedValue.valueInGlobalWindow(4);
        final WindowedValue five = WindowedValue.valueInGlobalWindow(5);

        // Replace the default mock factory with a real implementation that produces output.
        Mockito.when(this.stageContext.getStageBundleFactory((ExecutableStage) Matchers.any())).thenReturn(new StageBundleFactory() {
            /** Set once the three values have been emitted, so later bundle closes emit nothing. */
            private boolean onceEmitted;

            @Override
            public RemoteBundle getBundle(final OutputReceiverFactory outputReceiverFactory, StateRequestHandler stateRequestHandler, BundleProgressHandler bundleProgressHandler) {
                return new RemoteBundle() {
                    @Override
                    public String getId() {
                        return "bundle-id";
                    }

                    @Override
                    public Map<String, FnDataReceiver<WindowedValue<?>>> getInputReceivers() {
                        return ImmutableMap.of("input", windowedValue -> {
                            // Input is ignored; output is produced on close() instead.
                        });
                    }

                    @Override
                    public void close() throws Exception {
                        // onceEmitted belongs to the enclosing anonymous StageBundleFactory. An
                        // anonymous class has no name to qualify with, so use the simple name.
                        if (onceEmitted) {
                            return;
                        }
                        outputReceiverFactory.create(mainOutput.getId()).accept(three);
                        outputReceiverFactory.create(additionalOutput1.getId()).accept(four);
                        outputReceiverFactory.create(additionalOutput2.getId()).accept(five);
                        onceEmitted = true;
                    }
                };
            }

            @Override
            public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor getProcessBundleDescriptor() {
                return ExecutableStageDoFnOperatorTest.this.processBundleDescriptor;
            }

            @Override
            public void close() {
            }
        });

        OneInputStreamOperatorTestHarness testHarness = new OneInputStreamOperatorTestHarness(getOperator(mainOutput, ImmutableList.of(additionalOutput1, additionalOutput2), outputManagerFactory));
        long watermark = testHarness.getCurrentWatermark() + 1;
        testHarness.open();
        testHarness.processElement(new StreamRecord(input));
        testHarness.processWatermark(watermark);
        long nextWatermark = watermark + 1;
        testHarness.processWatermark(nextWatermark);

        Assert.assertEquals(nextWatermark, testHarness.getCurrentWatermark());
        // Nothing has been emitted downstream before the harness is closed.
        Assert.assertEquals(0L, testHarness.getOutput().size());

        testHarness.close();

        Assert.assertThat(testHarness.getOutput(), IsIterableContainingInOrder.contains(new Object[]{new StreamRecord(three), new Watermark(nextWatermark), new Watermark(Long.MAX_VALUE)}));
        Assert.assertThat(testHarness.getSideOutput((OutputTag) tagsToOutputTags.get(additionalOutput1)), IsIterableContainingInOrder.contains(new StreamRecord[]{new StreamRecord(four)}));
        Assert.assertThat(testHarness.getSideOutput((OutputTag) tagsToOutputTags.get(additionalOutput2)), IsIterableContainingInOrder.contains(new StreamRecord[]{new StreamRecord(five)}));
    }

    /**
     * Opens and closes the harness without processing elements and verifies the exact set of
     * interactions with the bundle factory, stage context and bundle, including that
     * {@code dispose()} afterwards causes no further interaction with the bundle.
     */
    @Test
    public void testStageBundleClosed() throws Exception {
        TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
        ExecutableStageDoFnOperator<Integer, Integer> operator = getOperator(mainOutput, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>(mainOutput, VoidCoder.of()));
        OneInputStreamOperatorTestHarness testHarness = new OneInputStreamOperatorTestHarness(operator);

        RemoteBundle bundle = (RemoteBundle) Mockito.mock(RemoteBundle.class);
        Mockito.when(bundle.getInputReceivers()).thenReturn(ImmutableMap.builder().put("input", (FnDataReceiver) Mockito.mock(FnDataReceiver.class)).build());
        Mockito.when(this.stageBundleFactory.getBundle((OutputReceiverFactory) Matchers.any(), (StateRequestHandler) Matchers.any(), (BundleProgressHandler) Matchers.any())).thenReturn(bundle);

        testHarness.open();
        testHarness.close();

        ((StageBundleFactory) Mockito.verify(this.stageBundleFactory)).getProcessBundleDescriptor();
        ((StageBundleFactory) Mockito.verify(this.stageBundleFactory)).close();
        ((FlinkExecutableStageContext) Mockito.verify(this.stageContext)).close();
        ((StageBundleFactory) Mockito.verify(this.stageBundleFactory)).getBundle((OutputReceiverFactory) Matchers.any(), (StateRequestHandler) Matchers.any(), (BundleProgressHandler) Matchers.any());
        ((RemoteBundle) Mockito.verify(bundle)).getInputReceivers();
        ((RemoteBundle) Mockito.verify(bundle)).close();
        Mockito.verifyNoMoreInteractions(this.stageBundleFactory);

        operator.dispose();
        Mockito.verifyNoMoreInteractions(bundle);
    }

    /**
     * Builds an operator with the real {@link FlinkExecutableStageContext} factory and checks that
     * a Java-serialization round trip yields a non-null instance that is not equal to the original.
     */
    @Test
    public void testSerialization() {
        WindowedValue.ValueOnlyWindowedValueCoder coder = WindowedValue.getValueOnlyCoder(VarIntCoder.of());
        TupleTag mainOutput = new TupleTag("main-output");
        TupleTag additionalOutput = new TupleTag("additional-output");
        DoFnOperator.MultiOutputOutputManagerFactory outputManagerFactory = new DoFnOperator.MultiOutputOutputManagerFactory(mainOutput, ImmutableMap.builder().put(additionalOutput, new OutputTag(additionalOutput.getId(), TypeInformation.of(Integer.class))).build(), ImmutableMap.builder().put(mainOutput, coder).put(additionalOutput, coder).build(), ImmutableMap.builder().put(mainOutput, 0).put(additionalOutput, 1).build());
        FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class);
        ExecutableStageDoFnOperator operator = new ExecutableStageDoFnOperator("transform", (Coder) null, (Coder) null, Collections.emptyMap(), mainOutput, ImmutableList.of(additionalOutput), outputManagerFactory, Collections.emptyMap(), Collections.emptyList(), Collections.emptyMap(), options, this.stagePayload, this.jobInfo, FlinkExecutableStageContext.factory(options), createOutputMap(mainOutput, ImmutableList.of(additionalOutput)), WindowingStrategy.globalDefault(), (Coder) null, (KeySelector) null);

        ExecutableStageDoFnOperator clone = SerializationUtils.clone(operator);

        Assert.assertNotNull(clone);
        Assert.assertNotEquals(operator, clone);
    }

    /**
     * Creates an operator for the "transform" step whose context factory is a mock returning
     * {@link #stageContext}, and injects the mocked {@link #stateRequestHandler} into the
     * operator's private {@code stateRequestHandler} field via reflection.
     *
     * @param tupleTag main output tag
     * @param list additional output tags
     * @param multiOutputOutputManagerFactory output manager factory handed to the operator
     * @return the configured operator (not yet opened)
     */
    private ExecutableStageDoFnOperator<Integer, Integer> getOperator(TupleTag<Integer> tupleTag, List<TupleTag<?>> list, DoFnOperator.MultiOutputOutputManagerFactory<Integer> multiOutputOutputManagerFactory) {
        FlinkExecutableStageContext.Factory contextFactory = (FlinkExecutableStageContext.Factory) Mockito.mock(FlinkExecutableStageContext.Factory.class);
        Mockito.when(contextFactory.get((JobInfo) Matchers.any())).thenReturn(this.stageContext);
        ExecutableStageDoFnOperator<Integer, Integer> operator = new ExecutableStageDoFnOperator<>("transform", (Coder) null, (Coder) null, Collections.emptyMap(), tupleTag, list, multiOutputOutputManagerFactory, Collections.emptyMap(), Collections.emptyList(), Collections.emptyMap(), PipelineOptionsFactory.as(FlinkPipelineOptions.class), this.stagePayload, this.jobInfo, contextFactory, createOutputMap(tupleTag, list), WindowingStrategy.globalDefault(), (Coder) null, (KeySelector) null);
        Whitebox.setInternalState(operator, "stateRequestHandler", this.stateRequestHandler);
        return operator;
    }

    /**
     * Builds a mutable map from tag id to tag, containing the main tag (when non-null) followed by
     * every additional tag. A later tag with the same id replaces an earlier one.
     *
     * @param tupleTag main output tag, may be {@code null}
     * @param list additional output tags
     * @return map keyed by {@code TupleTag#getId()}
     */
    private static Map<String, TupleTag<?>> createOutputMap(TupleTag tupleTag, List<TupleTag<?>> list) {
        Map<String, TupleTag<?>> outputMap = new HashMap<>(list.size() + 1);
        if (tupleTag != null) {
            outputMap.put(tupleTag.getId(), tupleTag);
        }
        for (TupleTag<?> additionalTag : list) {
            outputMap.put(additionalTag.getId(), additionalTag);
        }
        return outputMap;
    }
}
