package org.apache.beam.runners.spark.translation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.Timer;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.control.TimerReceiverFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.hamcrest.MatcherAssert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/* loaded from: input_file:org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.class */
/**
 * Unit tests for {@link SparkExecutableStageFunction}.
 *
 * <p>Every collaborator the function reaches through the {@link SparkExecutableStageContextFactory}
 * is replaced with a Mockito mock (or, in one test, a hand-written fake), so the tests only observe
 * how the function drives the {@link StageBundleFactory} and {@link RemoteBundle} APIs.
 */
public class SparkExecutableStageFunctionTest {

    @Mock
    private SparkExecutableStageContextFactory contextFactory;

    @Mock
    private ExecutableStageContext stageContext;

    @Mock
    private StageBundleFactory stageBundleFactory;

    @Mock
    private RemoteBundle remoteBundle;

    @Mock
    private MetricsContainerStepMapAccumulator metricsAccumulator;

    @Mock
    private MetricsContainerStepMap stepMap;

    @Mock
    private MetricsContainerImpl container;

    /** Default pipeline options handed to every function under test. */
    private final SerializablePipelineOptions pipelineOptions =
            new SerializablePipelineOptions(PipelineOptionsFactory.create());

    /** Id of the stage's single input PCollection; also used as the receiver key in tests. */
    private final String inputId = "input-id";

    /** Minimal stage payload: one ParDo transform that consumes {@link #inputId}. */
    private final RunnerApi.ExecutableStagePayload stagePayload =
            RunnerApi.ExecutableStagePayload.newBuilder()
                    .setInput(inputId)
                    .setComponents(
                            RunnerApi.Components.newBuilder()
                                    .putTransforms(
                                            "transform-id",
                                            RunnerApi.PTransform.newBuilder()
                                                    .putInputs("input-name", inputId)
                                                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:transform:pardo:v1"))
                                                    .build())
                                    .putPcollections(inputId, RunnerApi.PCollection.getDefaultInstance())
                                    .build())
                    .build();

    /**
     * Initializes the {@code @Mock} fields and wires them into the chain
     * context factory -> stage context -> bundle factory -> remote bundle, plus the metrics
     * accumulator -> step map -> container chain.
     *
     * @throws Exception declared because the stubbed APIs declare checked exceptions
     */
    @Before
    public void setUpMocks() throws Exception {
        MockitoAnnotations.initMocks(this);
        // Untyped any() is used deliberately for the JobInfo: getFunction() passes a null JobInfo.
        // The casts only pin the parameter type of the method being stubbed.
        Mockito.when(this.contextFactory.get((JobInfo) Matchers.any())).thenReturn(this.stageContext);
        Mockito.when(this.stageContext.getStageBundleFactory((ExecutableStage) Matchers.any()))
                .thenReturn(this.stageBundleFactory);
        Mockito.when(
                        this.stageBundleFactory.getBundle(
                                (OutputReceiverFactory) Matchers.any(),
                                (TimerReceiverFactory) Matchers.any(),
                                (StateRequestHandler) Matchers.any(),
                                (BundleProgressHandler) Matchers.any(BundleProgressHandler.class)))
                .thenReturn(this.remoteBundle);
        Mockito.when(this.remoteBundle.getInputReceivers())
                .thenReturn(ImmutableMap.of("input", Mockito.mock(FnDataReceiver.class)));
        Mockito.when(this.metricsAccumulator.value()).thenReturn(this.stepMap);
        Mockito.when(this.stepMap.getContainer((String) Matchers.any())).thenReturn(this.container);
    }

    /**
     * An exception thrown while closing the remote bundle must propagate out of
     * {@link SparkExecutableStageFunction#call}.
     */
    @Test(expected = Exception.class)
    public void sdkErrorsSurfaceOnClose() throws Exception {
        SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
        Mockito.doThrow(new Exception()).when(this.remoteBundle).close();
        function.call(singleElementInput());
    }

    /**
     * Each input element is handed to the receiver registered under {@link #inputId}, and that
     * receiver sees no other interaction.
     */
    @Test
    public void expectedInputsAreSent() throws Exception {
        SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());

        // Replace the default bundle so this test controls the input receiver.
        RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
        Mockito.when(
                        this.stageBundleFactory.getBundle(
                                (OutputReceiverFactory) Matchers.any(),
                                (TimerReceiverFactory) Matchers.any(),
                                (StateRequestHandler) Matchers.any(),
                                (BundleProgressHandler) Matchers.any(BundleProgressHandler.class)))
                .thenReturn(bundle);

        @SuppressWarnings("unchecked") // Mockito.mock cannot express the type argument.
        FnDataReceiver<WindowedValue<?>> receiver = Mockito.mock(FnDataReceiver.class);
        Mockito.when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of(this.inputId, receiver));

        WindowedValue<Integer> one = WindowedValue.valueInGlobalWindow(1);
        WindowedValue<Integer> two = WindowedValue.valueInGlobalWindow(2);
        WindowedValue<Integer> three = WindowedValue.valueInGlobalWindow(3);
        function.call(Arrays.asList(one, two, three).iterator());

        Mockito.verify(receiver).accept(one);
        Mockito.verify(receiver).accept(two);
        Mockito.verify(receiver).accept(three);
        Mockito.verifyNoMoreInteractions(receiver);
    }

    /**
     * Values emitted to the output receivers "one", "two" and "three" come back from
     * {@code call} wrapped in {@link RawUnionValue}s carrying the tags 1, 2 and 3 from the output map.
     */
    @Test
    public void outputsAreTaggedCorrectly() throws Exception {
        final WindowedValue<Integer> three = WindowedValue.valueInGlobalWindow(3);
        final WindowedValue<Integer> four = WindowedValue.valueInGlobalWindow(4);
        final WindowedValue<Integer> five = WindowedValue.valueInGlobalWindow(5);
        Map<String, Integer> outputTagMap = ImmutableMap.of("one", 1, "two", 2, "three", 3);

        // Hand-written fake: its bundles emit the three values when closed, at most once per factory.
        StageBundleFactory emittingFactory = new StageBundleFactory() {
            private boolean once;

            @Override
            public RemoteBundle getBundle(
                    final OutputReceiverFactory outputReceiverFactory,
                    TimerReceiverFactory timerReceiverFactory,
                    StateRequestHandler stateRequestHandler,
                    BundleProgressHandler bundleProgressHandler,
                    BundleFinalizationHandler bundleFinalizationHandler,
                    BundleCheckpointHandler bundleCheckpointHandler) {
                return new RemoteBundle() {
                    @Override
                    public String getId() {
                        return "bundle-id";
                    }

                    @Override
                    public Map<String, FnDataReceiver> getInputReceivers() {
                        return ImmutableMap.of("input", input -> {
                            // Inputs are ignored; outputs are produced on close().
                        });
                    }

                    @Override
                    public Map<KV<String, String>, FnDataReceiver<Timer>> getTimerReceivers() {
                        return Collections.emptyMap();
                    }

                    @Override
                    public void requestProgress() {
                        throw new UnsupportedOperationException();
                    }

                    @Override
                    public void split(double d) {
                        throw new UnsupportedOperationException();
                    }

                    @Override
                    public void close() throws Exception {
                        // 'once' is the enclosing factory's field, so a repeated close() emits nothing.
                        if (once) {
                            return;
                        }
                        outputReceiverFactory.create("one").accept(three);
                        outputReceiverFactory.create("two").accept(four);
                        outputReceiverFactory.create("three").accept(five);
                        once = true;
                    }
                };
            }

            @Override
            public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor getProcessBundleDescriptor() {
                return Mockito.mock(ProcessBundleDescriptors.ExecutableProcessBundleDescriptor.class);
            }

            @Override
            public InstructionRequestHandler getInstructionRequestHandler() {
                return null;
            }

            @Override
            public void close() {
                // Nothing to release in the fake.
            }
        };
        Mockito.when(this.stageContext.getStageBundleFactory((ExecutableStage) Matchers.any()))
                .thenReturn(emittingFactory);

        SparkExecutableStageFunction<Integer, ?> function = getFunction(outputTagMap);
        Iterator<RawUnionValue> outputs = function.call(singleElementInput());

        // Hamcrest's contains() matches an Iterable; this single-use view is traversed only once.
        Iterable<RawUnionValue> outputIterable = () -> outputs;
        MatcherAssert.assertThat(
                outputIterable,
                org.hamcrest.Matchers.contains(
                        new RawUnionValue(1, three), new RawUnionValue(2, four), new RawUnionValue(3, five)));
    }

    /**
     * Processing one element requests exactly one bundle, reads the process bundle descriptor once,
     * closes the factory once, and touches the factory in no other way.
     */
    @Test
    public void testStageBundleClosed() throws Exception {
        SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
        function.call(singleElementInput());

        Mockito.verify(this.stageBundleFactory)
                .getBundle(
                        (OutputReceiverFactory) Matchers.any(),
                        (TimerReceiverFactory) Matchers.any(),
                        (StateRequestHandler) Matchers.any(),
                        (BundleProgressHandler) Matchers.any(BundleProgressHandler.class));
        Mockito.verify(this.stageBundleFactory).getProcessBundleDescriptor();
        Mockito.verify(this.stageBundleFactory).close();
        Mockito.verifyNoMoreInteractions(this.stageBundleFactory);
    }

    /** An empty input iterator must not cause any interaction with the bundle factory. */
    @Test
    public void testNoCallOnEmptyInputIterator() throws Exception {
        SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
        function.call(Collections.emptyIterator());
        Mockito.verifyZeroInteractions(this.stageBundleFactory);
    }

    /**
     * Builds an iterator over a single globally-windowed {@code 0}, the smallest input that makes
     * the function open a bundle.
     *
     * @return a fresh iterator over one element
     */
    private static Iterator<WindowedValue<Integer>> singleElementInput() {
        ArrayList<WindowedValue<Integer>> inputs = new ArrayList<>();
        inputs.add(WindowedValue.valueInGlobalWindow(0));
        return inputs.iterator();
    }

    /**
     * Creates the function under test, wired to the mocked context factory and metrics accumulator.
     * The job info and window coder are passed as {@code null} and the side-input map is empty.
     *
     * @param map mapping from output id to the union tag expected on the produced values
     * @return a new function instance backed by this test's mocks
     */
    private <InputT, SideInputT> SparkExecutableStageFunction<InputT, SideInputT> getFunction(Map<String, Integer> map) {
        return new SparkExecutableStageFunction<>(
                this.pipelineOptions,
                this.stagePayload,
                (JobInfo) null,
                map,
                this.contextFactory,
                Collections.emptyMap(),
                this.metricsAccumulator,
                (Coder) null);
    }
}
