package org.apache.beam.runners.fnexecution.control;

import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import org.apache.beam.fn.harness.FnHarness;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.repackaged.beam_runners_java_fn_execution.com.google.common.base.Optional;
import org.apache.beam.repackaged.beam_runners_java_fn_execution.com.google.common.base.Preconditions;
import org.apache.beam.repackaged.beam_runners_java_fn_execution.com.google.common.collect.ImmutableMap;
import org.apache.beam.repackaged.beam_runners_java_fn_execution.com.google.common.collect.Iterables;
import org.apache.beam.repackaged.beam_runners_java_fn_execution.com.google.common.collect.Iterators;
import org.apache.beam.repackaged.beam_runners_java_fn_execution.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.FusedPipeline;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.runners.fnexecution.GrpcContextHeaderAccessorProvider;
import org.apache.beam.runners.fnexecution.GrpcFnServer;
import org.apache.beam.runners.fnexecution.InProcessServerFactory;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.control.SdkHarnessClient;
import org.apache.beam.runners.fnexecution.data.GrpcDataService;
import org.apache.beam.runners.fnexecution.logging.GrpcLoggingService;
import org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter;
import org.apache.beam.runners.fnexecution.state.GrpcStateService;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.ByteString;
import org.hamcrest.Matchers;
import org.hamcrest.collection.IsEmptyIterable;
import org.hamcrest.collection.IsIterableContainingInOrder;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.class */
public class RemoteExecutionTest implements Serializable {
    // Per-test fixtures: every field below is assigned in setup() and released in tearDown().
    // NOTE(review): the fields are transient, presumably so that anonymous DoFns capturing this
    // Serializable test instance do not try to serialize servers/executors — confirm.

    /** In-process gRPC server hosting the control service that the SDK harness connects to. */
    private transient GrpcFnServer<FnApiControlClientPoolService> controlServer;
    /** In-process gRPC server hosting the data service; its descriptor is handed to each stage. */
    private transient GrpcFnServer<GrpcDataService> dataServer;
    /** In-process gRPC server hosting {@link #stateDelegator}. */
    private transient GrpcFnServer<GrpcStateService> stateServer;
    /** In-process gRPC server hosting the logging service passed to the SDK harness. */
    private transient GrpcFnServer<GrpcLoggingService> loggingServer;
    /** State service instance served by {@link #stateServer} and passed to getProcessor(...). */
    private transient GrpcStateService stateDelegator;
    /** Client used by the tests to obtain bundle processors from the connected SDK harness. */
    private transient SdkHarnessClient controlClient;
    /** Cached daemon-thread pool handed to the data service. */
    private transient ExecutorService serverExecutor;
    /** Single daemon thread that runs FnHarness.main for the duration of a test. */
    private transient ExecutorService sdkHarnessExecutor;
    /** Handle on the running harness task; tearDown() waits on it to surface harness failures. */
    private transient Future<?> sdkHarnessExecutorFuture;

    /**
     * Starts the in-process data, logging, state and control servers, launches the SDK harness on
     * its own daemon thread, and waits (up to two seconds) for the harness to register with the
     * control client pool before building {@link #controlClient}.
     *
     * @throws Exception if any server cannot be created or no harness connects in time
     */
    @Before
    public void setup() throws Exception {
        ThreadFactory daemonThreads = new ThreadFactoryBuilder().setDaemon(true).build();
        serverExecutor = Executors.newCachedThreadPool(daemonThreads);
        InProcessServerFactory serverFactory = InProcessServerFactory.create();

        GrpcDataService dataService =
                GrpcDataService.create(serverExecutor, OutboundObserverFactory.serverDirect());
        dataServer = GrpcFnServer.allocatePortAndCreateFor(dataService, serverFactory);

        GrpcLoggingService loggingService = GrpcLoggingService.forWriter(Slf4jLogWriter.getDefault());
        loggingServer = GrpcFnServer.allocatePortAndCreateFor(loggingService, serverFactory);

        stateDelegator = GrpcStateService.create();
        stateServer = GrpcFnServer.allocatePortAndCreateFor(stateDelegator, serverFactory);

        MapControlClientPool clientPool = MapControlClientPool.create();
        FnApiControlClientPoolService controlService =
                FnApiControlClientPoolService.offeringClientsToPool(
                        clientPool.getSink(), GrpcContextHeaderAccessorProvider.getHeaderAccessor());
        controlServer = GrpcFnServer.allocatePortAndCreateFor(controlService, serverFactory);

        // The harness blocks inside main(), so it gets a dedicated single-thread executor.
        sdkHarnessExecutor = Executors.newSingleThreadExecutor(daemonThreads);
        sdkHarnessExecutorFuture = sdkHarnessExecutor.submit(() -> {
            try {
                FnHarness.main(
                        "id",
                        PipelineOptionsFactory.create(),
                        loggingServer.getApiServiceDescriptor(),
                        controlServer.getApiServiceDescriptor(),
                        InProcessManagedChannelFactory.create(),
                        OutboundObserverFactory.clientDirect());
            } catch (Exception harnessFailure) {
                throw new RuntimeException(harnessFailure);
            }
        });

        controlClient = SdkHarnessClient.usingFnApiClient(
                clientPool.getSource().take("", Duration.ofSeconds(2)), dataServer.getService());
    }

    /**
     * Closes the four servers and the control client, stops both executors, then waits for the
     * harness task to finish.
     *
     * <p>A harness failure is rethrown unless it is a {@link RuntimeException} wrapping an
     * {@link InterruptedException}, which is tolerated here.
     *
     * @throws Exception if closing a resource fails or the harness ended with an unexpected error
     */
    @After
    public void tearDown() throws Exception {
        controlServer.close();
        stateServer.close();
        dataServer.close();
        loggingServer.close();
        controlClient.close();
        sdkHarnessExecutor.shutdownNow();
        serverExecutor.shutdownNow();
        try {
            sdkHarnessExecutorFuture.get();
        } catch (ExecutionException harnessFailure) {
            Throwable cause = harnessFailure.getCause();
            boolean wrappedInterrupt =
                    cause instanceof RuntimeException
                            && cause.getCause() instanceof InterruptedException;
            if (!wrappedInterrupt) {
                throw harnessFailure;
            }
        }
    }

    /**
     * Runs the single fused stage of an impulse -> create -> len -> addKeys pipeline on the SDK
     * harness and checks what arrives on each output target.
     *
     * <p>One empty {@code byte[]} element in the global window is sent as input; every output
     * target is expected to receive the encoded pairs ("foo", 4), ("foo", 3) and ("foo", 3) in
     * any order.
     *
     * @throws Exception if the stage cannot be built or bundle processing fails
     */
    @Test
    public void testExecution() throws Exception {
        Pipeline pipeline = Pipeline.create();
        pipeline
                .apply("impulse", Impulse.create())
                .apply("create", ParDo.of(new DoFn<byte[], String>() {
                    @DoFn.ProcessElement
                    public void process(ProcessContext context) {
                        context.output("zero");
                        context.output("one");
                        context.output("two");
                    }
                }))
                .apply("len", ParDo.of(new DoFn<String, Long>() {
                    @DoFn.ProcessElement
                    public void process(ProcessContext context) {
                        context.output((long) context.element().length());
                    }
                }))
                .apply("addKeys", WithKeys.of("foo"))
                .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianLongCoder.of()))
                .apply("gbk", GroupByKey.create());

        FusedPipeline fused = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
        Preconditions.checkState(fused.getFusedStages().size() == 1, "Expected exactly one fused stage");
        ExecutableStage stage = fused.getFusedStages().iterator().next();

        ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
                ProcessBundleDescriptors.fromExecutableStage(
                        "my_stage", stage, this.dataServer.getApiServiceDescriptor());
        SdkHarnessClient.BundleProcessor processor =
                this.controlClient.getProcessor(
                        descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestination());

        // One synchronized sink per output target; the receiver appends to it from gRPC threads.
        Map<BeamFnApi.Target, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
        Map<BeamFnApi.Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
        for (Map.Entry<BeamFnApi.Target, Coder<WindowedValue<?>>> targetCoder :
                descriptor.getOutputTargetCoders().entrySet()) {
            List<WindowedValue<?>> received = Collections.synchronizedList(new ArrayList<>());
            outputValues.put(targetCoder.getKey(), received);
            outputReceivers.put(
                    targetCoder.getKey(), RemoteOutputReceiver.of(targetCoder.getValue(), received::add));
        }

        // Closing the bundle exactly once is what flushes it; assertions run only after the close.
        try (SdkHarnessClient.ActiveBundle bundle =
                processor.newBundle(outputReceivers, BundleProgressHandler.unsupported())) {
            bundle.getInputReceiver().accept(WindowedValue.valueInGlobalWindow(new byte[0]));
        }

        for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
            Assert.assertThat(
                    windowedValues,
                    Matchers.containsInAnyOrder(
                            WindowedValue.valueInGlobalWindow(kvBytes("foo", 4L)),
                            WindowedValue.valueInGlobalWindow(kvBytes("foo", 3L)),
                            WindowedValue.valueInGlobalWindow(kvBytes("foo", 3L))));
        }
    }

    /**
     * Runs the fused stage that reads an iterable side input and checks its output.
     *
     * <p>The side input is served by a stub handler that always returns the encoded strings "A",
     * "B" and "C". Two main-input elements, "X" and "Y", are sent; every output target is expected
     * to receive the six encoded pairs X/Y x A/B/C in any order.
     *
     * @throws Exception if the stage cannot be built or bundle processing fails
     */
    @Test
    public void testExecutionWithSideInput() throws Exception {
        Pipeline pipeline = Pipeline.create();
        PCollection<String> input =
                pipeline
                        .apply("impulse", Impulse.create())
                        .apply("create", ParDo.of(new DoFn<byte[], String>() {
                            @DoFn.ProcessElement
                            public void process(ProcessContext context) {
                                context.output("zero");
                                context.output("one");
                                context.output("two");
                            }
                        }))
                        .setCoder(StringUtf8Coder.of());
        PCollectionView<Iterable<String>> view = input.apply("createSideInput", View.asIterable());

        input
                .apply("readSideInput", ParDo.of(new DoFn<String, KV<String, String>>() {
                    @DoFn.ProcessElement
                    public void processElement(ProcessContext context) {
                        for (String sideInputValue : context.sideInput(view)) {
                            context.output(KV.of(context.element(), sideInputValue));
                        }
                    }
                }).withSideInputs(view))
                .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
                .apply("gbk", GroupByKey.create());

        FusedPipeline fused = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
        Optional<ExecutableStage> optionalStage =
                Iterables.tryFind(fused.getFusedStages(), stage -> !stage.getSideInputs().isEmpty());
        Preconditions.checkState(optionalStage.isPresent(), "Expected a stage with side inputs.");

        ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
                ProcessBundleDescriptors.fromExecutableStage(
                        "test_stage",
                        optionalStage.get(),
                        this.dataServer.getApiServiceDescriptor(),
                        this.stateServer.getApiServiceDescriptor());
        SdkHarnessClient.BundleProcessor processor =
                this.controlClient.getProcessor(
                        descriptor.getProcessBundleDescriptor(),
                        descriptor.getRemoteInputDestination(),
                        this.stateDelegator);

        // One synchronized sink per output target; the receiver appends to it from gRPC threads.
        Map<BeamFnApi.Target, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
        Map<BeamFnApi.Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
        for (Map.Entry<BeamFnApi.Target, Coder<WindowedValue<?>>> targetCoder :
                descriptor.getOutputTargetCoders().entrySet()) {
            List<WindowedValue<?>> received = Collections.synchronizedList(new ArrayList<>());
            outputValues.put(targetCoder.getKey(), received);
            outputReceivers.put(
                    targetCoder.getKey(), RemoteOutputReceiver.of(targetCoder.getValue(), received::add));
        }

        // Each side-input element is its own encoded byte[]; the list holds three of them.
        Iterable<byte[]> sideInputData =
                Arrays.asList(
                        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A"),
                        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B"),
                        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C"));

        StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
                new StateRequestHandlers.SideInputHandlerFactory() {
                    @Override
                    public <T, V, W extends BoundedWindow>
                            StateRequestHandlers.SideInputHandler<V, W> forSideInput(
                                    String pTransformId,
                                    String sideInputId,
                                    RunnerApi.FunctionSpec accessPattern,
                                    Coder<T> elementCoder,
                                    Coder<W> windowCoder) {
                        return new StateRequestHandlers.SideInputHandler<V, W>() {
                            // The stub ignores key and window and serves the same canned data;
                            // the raw cast bridges byte[] to the handler's unconstrained V.
                            @Override
                            @SuppressWarnings({"unchecked", "rawtypes"})
                            public Iterable<V> get(byte[] key, W window) {
                                return (Iterable) sideInputData;
                            }

                            // The element coder is treated as a KV coder whose value half
                            // encodes the handler's results.
                            @Override
                            @SuppressWarnings({"unchecked", "rawtypes"})
                            public Coder<V> resultCoder() {
                                return ((KvCoder) elementCoder).getValueCoder();
                            }
                        };
                    }
                };

        // Closing the bundle exactly once is what flushes it; assertions run only after the close.
        try (SdkHarnessClient.ActiveBundle bundle =
                processor.newBundle(
                        outputReceivers,
                        StateRequestHandlers.forSideInputHandlerFactory(
                                descriptor.getSideInputSpecs(), sideInputHandlerFactory),
                        BundleProgressHandler.unsupported())) {
            bundle.getInputReceiver()
                    .accept(WindowedValue.valueInGlobalWindow(
                            CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
            bundle.getInputReceiver()
                    .accept(WindowedValue.valueInGlobalWindow(
                            CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y")));
        }

        for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
            Assert.assertThat(
                    windowedValues,
                    Matchers.containsInAnyOrder(
                            WindowedValue.valueInGlobalWindow(kvBytes("X", "A")),
                            WindowedValue.valueInGlobalWindow(kvBytes("X", "B")),
                            WindowedValue.valueInGlobalWindow(kvBytes("X", "C")),
                            WindowedValue.valueInGlobalWindow(kvBytes("Y", "A")),
                            WindowedValue.valueInGlobalWindow(kvBytes("Y", "B")),
                            WindowedValue.valueInGlobalWindow(kvBytes("Y", "C"))));
        }
    }

    /**
     * Runs the fused stage containing a stateful DoFn against an in-memory bag-state stub.
     *
     * <p>State "foo" starts as [A, B, C] and "foo2" as [D]. One element ("X", "Y") is sent. The
     * DoFn emits a pair for every value currently in "foo", appends the element's value to "foo"
     * and clears "foo2". Afterwards every output target must hold X/A, X/B and X/C in any order,
     * "foo" must be [A, B, C, Y] in order, and "foo2" must be empty.
     *
     * @throws Exception if the stage cannot be built or bundle processing fails
     */
    @Test
    public void testExecutionWithUserState() throws Exception {
        Pipeline pipeline = Pipeline.create();
        pipeline
                .apply("impulse", Impulse.create())
                .apply("create", ParDo.of(new DoFn<byte[], KV<String, String>>() {
                    @DoFn.ProcessElement
                    public void process(ProcessContext context) {}
                }))
                .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
                .apply("userState", ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {

                    @DoFn.StateId("foo")
                    private final StateSpec<BagState<String>> bufferState =
                            StateSpecs.bag(StringUtf8Coder.of());

                    @DoFn.StateId("foo2")
                    private final StateSpec<BagState<String>> bufferState2 =
                            StateSpecs.bag(StringUtf8Coder.of());

                    @DoFn.ProcessElement
                    public void processElement(
                            @DoFn.Element KV<String, String> element,
                            @DoFn.StateId("foo") BagState<String> bufferState,
                            @DoFn.StateId("foo2") BagState<String> bufferState2,
                            DoFn.OutputReceiver<KV<String, String>> receiver) {
                        // Result intentionally unused; the call is kept from the original test.
                        bufferState.isEmpty();
                        for (String value : bufferState.read()) {
                            receiver.output(KV.of(element.getKey(), value));
                        }
                        bufferState.add(element.getValue());
                        bufferState2.clear();
                    }
                }))
                .apply("gbk", GroupByKey.create());

        FusedPipeline fused = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
        Optional<ExecutableStage> optionalStage =
                Iterables.tryFind(fused.getFusedStages(), stage -> !stage.getUserStates().isEmpty());
        Preconditions.checkState(optionalStage.isPresent(), "Expected a stage with user state.");

        ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
                ProcessBundleDescriptors.fromExecutableStage(
                        "test_stage",
                        optionalStage.get(),
                        this.dataServer.getApiServiceDescriptor(),
                        this.stateServer.getApiServiceDescriptor());
        SdkHarnessClient.BundleProcessor processor =
                this.controlClient.getProcessor(
                        descriptor.getProcessBundleDescriptor(),
                        descriptor.getRemoteInputDestination(),
                        this.stateDelegator);

        // One synchronized sink per output target; the receiver appends to it from gRPC threads.
        Map<BeamFnApi.Target, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
        Map<BeamFnApi.Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
        for (Map.Entry<BeamFnApi.Target, Coder<WindowedValue<?>>> targetCoder :
                descriptor.getOutputTargetCoders().entrySet()) {
            List<WindowedValue<?>> received = Collections.synchronizedList(new ArrayList<>());
            outputValues.put(targetCoder.getKey(), received);
            outputReceivers.put(
                    targetCoder.getKey(), RemoteOutputReceiver.of(targetCoder.getValue(), received::add));
        }

        // The map itself is immutable, but each per-state list is mutable so the stub can
        // append to and clear it.
        List<ByteString> fooState =
                new ArrayList<>(
                        Arrays.asList(
                                encodeNestedUtf8("A"), encodeNestedUtf8("B"), encodeNestedUtf8("C")));
        List<ByteString> foo2State = new ArrayList<>(Arrays.asList(encodeNestedUtf8("D")));
        Map<String, List<ByteString>> userStateData =
                ImmutableMap.of("foo", fooState, "foo2", foo2State);

        StateRequestHandlers.BagUserStateHandlerFactory userStateHandlerFactory =
                new StateRequestHandlers.BagUserStateHandlerFactory() {
                    @Override
                    public <K, V, W extends BoundedWindow>
                            StateRequestHandlers.BagUserStateHandler<K, V, W> forUserState(
                                    String pTransformId,
                                    String userStateId,
                                    Coder<K> keyCoder,
                                    Coder<V> valueCoder,
                                    Coder<W> windowCoder) {
                        // Key and window are ignored: state is looked up by user state id only.
                        return new StateRequestHandlers.BagUserStateHandler<K, V, W>() {
                            @Override
                            @SuppressWarnings({"unchecked", "rawtypes"})
                            public Iterable<V> get(K key, W window) {
                                return (Iterable) userStateData.get(userStateId);
                            }

                            @Override
                            @SuppressWarnings({"unchecked", "rawtypes"})
                            public void append(K key, W window, Iterator<V> values) {
                                Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                            }

                            @Override
                            public void clear(K key, W window) {
                                userStateData.get(userStateId).clear();
                            }
                        };
                    }
                };

        // Closing the bundle exactly once is what flushes it; assertions run only after the close.
        try (SdkHarnessClient.ActiveBundle bundle =
                processor.newBundle(
                        outputReceivers,
                        StateRequestHandlers.forBagUserStateHandlerFactory(
                                descriptor, userStateHandlerFactory),
                        BundleProgressHandler.unsupported())) {
            bundle.getInputReceiver().accept(WindowedValue.valueInGlobalWindow(kvBytes("X", "Y")));
        }

        for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
            Assert.assertThat(
                    windowedValues,
                    Matchers.containsInAnyOrder(
                            WindowedValue.valueInGlobalWindow(kvBytes("X", "A")),
                            WindowedValue.valueInGlobalWindow(kvBytes("X", "B")),
                            WindowedValue.valueInGlobalWindow(kvBytes("X", "C"))));
        }
        Assert.assertThat(
                userStateData.get("foo"),
                IsIterableContainingInOrder.contains(
                        encodeNestedUtf8("A"),
                        encodeNestedUtf8("B"),
                        encodeNestedUtf8("C"),
                        encodeNestedUtf8("Y")));
        Assert.assertThat(userStateData.get("foo2"), IsEmptyIterable.emptyIterable());
    }

    /** Encodes {@code value} with the UTF-8 string coder in the NESTED context as a ByteString. */
    private static ByteString encodeNestedUtf8(String value) throws CoderException {
        return ByteString.copyFrom(
                CoderUtils.encodeToByteArray(StringUtf8Coder.of(), value, Coder.Context.NESTED));
    }

    /**
     * Builds a key/value pair of encoded bytes: the key via the UTF-8 string coder and the value
     * via the big-endian long coder.
     *
     * @param str key to encode
     * @param j value to encode
     * @return the pair of encoded byte arrays
     * @throws CoderException if either encoding fails
     */
    private KV<byte[], byte[]> kvBytes(String str, long j) throws CoderException {
        byte[] encodedKey = CoderUtils.encodeToByteArray(StringUtf8Coder.of(), str);
        byte[] encodedValue = CoderUtils.encodeToByteArray(BigEndianLongCoder.of(), Long.valueOf(j));
        return KV.of(encodedKey, encodedValue);
    }

    /**
     * Builds a key/value pair of encoded bytes, encoding both strings with the UTF-8 string coder.
     *
     * @param str key to encode
     * @param str2 value to encode
     * @return the pair of encoded byte arrays
     * @throws CoderException if either encoding fails
     */
    private KV<byte[], byte[]> kvBytes(String str, String str2) throws CoderException {
        StringUtf8Coder utf8 = StringUtf8Coder.of();
        byte[] encodedKey = CoderUtils.encodeToByteArray(utf8, str);
        byte[] encodedValue = CoderUtils.encodeToByteArray(utf8, str2);
        return KV.of(encodedKey, encodedValue);
    }

    /**
     * Closes a resource the way a compiler-generated try-with-resources epilogue does.
     *
     * <p>With no primary exception ({@code th == null}) a failure from {@code close()} propagates
     * to the caller. With a primary exception, a failure from {@code close()} is attached to it as
     * a suppressed exception so the primary one is not masked.
     *
     * @param th the exception already in flight, or {@code null} if the body completed normally
     * @param autoCloseable the resource to close; callers are expected to pass a non-null value
     * @throws Exception if {@code close()} fails and there is no primary exception to attach it to
     */
    private static /* synthetic */ void $closeResource(Throwable th, AutoCloseable autoCloseable) throws Exception {
        if (th == null) {
            // close() declares a checked Exception, so it has to surface in this signature.
            autoCloseable.close();
            return;
        }
        try {
            autoCloseable.close();
        } catch (Throwable closeFailure) {
            th.addSuppressed(closeFailure);
        }
    }
}
