package org.apache.beam.fn.harness.data;

import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.ServerImpl;
import io.grpc.stub.CallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
import org.apache.beam.sdk.fn.data.InboundDataClient;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
import org.apache.beam.sdk.fn.stream.StreamObserverFactory;
import org.apache.beam.sdk.fn.test.TestStreams;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.hamcrest.Matchers;
import org.hamcrest.collection.IsEmptyCollection;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.class */
public class BeamFnDataGrpcClientTest {
    private static final Coder<WindowedValue<String>> CODER = LengthPrefixCoder.of(WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE));
    private static final LogicalEndpoint ENDPOINT_A = LogicalEndpoint.of("12L", BeamFnApi.Target.newBuilder().setPrimitiveTransformReference("34L").setName("targetA").build());
    private static final LogicalEndpoint ENDPOINT_B = LogicalEndpoint.of("56L", BeamFnApi.Target.newBuilder().setPrimitiveTransformReference("78L").setName("targetB").build());
    private static final BeamFnApi.Elements ELEMENTS_A_1;
    private static final BeamFnApi.Elements ELEMENTS_A_2;
    private static final BeamFnApi.Elements ELEMENTS_B_1;

    @Test
    public void testForInboundConsumer() throws Exception {
        final CountDownLatch countDownLatch = new CountDownLatch(1);
        ConcurrentLinkedQueue concurrentLinkedQueue = new ConcurrentLinkedQueue();
        ConcurrentLinkedQueue concurrentLinkedQueue2 = new ConcurrentLinkedQueue();
        ConcurrentLinkedQueue concurrentLinkedQueue3 = new ConcurrentLinkedQueue();
        final AtomicReference atomicReference = new AtomicReference();
        concurrentLinkedQueue3.getClass();
        final CallStreamObserver build = TestStreams.withOnNext((v1) -> {
            r0.add(v1);
        }).build();
        Endpoints.ApiServiceDescriptor build2 = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getClass().getName() + "-" + UUID.randomUUID().toString()).build();
        ServerImpl build3 = InProcessServerBuilder.forName(build2.getUrl()).addService(new BeamFnDataGrpc.BeamFnDataImplBase() { // from class: org.apache.beam.fn.harness.data.BeamFnDataGrpcClientTest.1
            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> streamObserver) {
                atomicReference.set(streamObserver);
                countDownLatch.countDown();
                return build;
            }
        }).build();
        build3.start();
        try {
            ManagedChannel build4 = InProcessChannelBuilder.forName(build2.getUrl()).build();
            BeamFnDataGrpcClient beamFnDataGrpcClient = new BeamFnDataGrpcClient(PipelineOptionsFactory.create(), apiServiceDescriptor -> {
                return build4;
            }, this::createStreamForTest);
            LogicalEndpoint logicalEndpoint = ENDPOINT_A;
            Coder<WindowedValue<String>> coder = CODER;
            concurrentLinkedQueue.getClass();
            InboundDataClient receive = beamFnDataGrpcClient.receive(build2, logicalEndpoint, coder, (v1) -> {
                r4.add(v1);
            });
            countDownLatch.await();
            ((StreamObserver) atomicReference.get()).onNext(ELEMENTS_A_1);
            ((StreamObserver) atomicReference.get()).onNext(ELEMENTS_B_1);
            Thread.sleep(100L);
            LogicalEndpoint logicalEndpoint2 = ENDPOINT_B;
            Coder<WindowedValue<String>> coder2 = CODER;
            concurrentLinkedQueue2.getClass();
            beamFnDataGrpcClient.receive(build2, logicalEndpoint2, coder2, (v1) -> {
                r4.add(v1);
            }).awaitCompletion();
            Assert.assertThat(concurrentLinkedQueue2, Matchers.contains(new WindowedValue[]{WindowedValue.valueInGlobalWindow("JKL"), WindowedValue.valueInGlobalWindow("MNO")}));
            ((StreamObserver) atomicReference.get()).onNext(ELEMENTS_A_2);
            receive.awaitCompletion();
            Assert.assertThat(concurrentLinkedQueue, Matchers.contains(new WindowedValue[]{WindowedValue.valueInGlobalWindow("ABC"), WindowedValue.valueInGlobalWindow("DEF"), WindowedValue.valueInGlobalWindow("GHI")}));
            build3.shutdownNow();
        } catch (Throwable th) {
            build3.shutdownNow();
            throw th;
        }
    }

    @Test
    public void testForInboundConsumerThatThrows() throws Exception {
        final CountDownLatch countDownLatch = new CountDownLatch(1);
        AtomicInteger atomicInteger = new AtomicInteger();
        ConcurrentLinkedQueue concurrentLinkedQueue = new ConcurrentLinkedQueue();
        final AtomicReference atomicReference = new AtomicReference();
        concurrentLinkedQueue.getClass();
        final CallStreamObserver build = TestStreams.withOnNext((v1) -> {
            r0.add(v1);
        }).build();
        Endpoints.ApiServiceDescriptor build2 = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getClass().getName() + "-" + UUID.randomUUID().toString()).build();
        ServerImpl build3 = InProcessServerBuilder.forName(build2.getUrl()).addService(new BeamFnDataGrpc.BeamFnDataImplBase() { // from class: org.apache.beam.fn.harness.data.BeamFnDataGrpcClientTest.2
            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> streamObserver) {
                atomicReference.set(streamObserver);
                countDownLatch.countDown();
                return build;
            }
        }).build();
        build3.start();
        RuntimeException runtimeException = new RuntimeException("TestFailure");
        try {
            ManagedChannel build4 = InProcessChannelBuilder.forName(build2.getUrl()).build();
            InboundDataClient receive = new BeamFnDataGrpcClient(PipelineOptionsFactory.create(), apiServiceDescriptor -> {
                return build4;
            }, this::createStreamForTest).receive(build2, ENDPOINT_A, CODER, windowedValue -> {
                atomicInteger.incrementAndGet();
                throw runtimeException;
            });
            countDownLatch.await();
            ((StreamObserver) atomicReference.get()).onNext(ELEMENTS_A_1);
            ((StreamObserver) atomicReference.get()).onNext(ELEMENTS_A_2);
            try {
                receive.awaitCompletion();
                Assert.fail("Expected channel to fail");
            } catch (ExecutionException e) {
                Assert.assertEquals(runtimeException, e.getCause());
            }
            Assert.assertThat(concurrentLinkedQueue, IsEmptyCollection.empty());
            Assert.assertEquals(1L, atomicInteger.get());
            build3.shutdownNow();
        } catch (Throwable th) {
            build3.shutdownNow();
            throw th;
        }
    }

    @Test
    public void testForOutboundConsumer() throws Exception {
        CountDownLatch countDownLatch = new CountDownLatch(2);
        ConcurrentLinkedQueue concurrentLinkedQueue = new ConcurrentLinkedQueue();
        final CallStreamObserver build = TestStreams.withOnNext(elements -> {
            concurrentLinkedQueue.add(elements);
            countDownLatch.countDown();
        }).build();
        Endpoints.ApiServiceDescriptor build2 = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getClass().getName() + "-" + UUID.randomUUID().toString()).build();
        ServerImpl build3 = InProcessServerBuilder.forName(build2.getUrl()).addService(new BeamFnDataGrpc.BeamFnDataImplBase() { // from class: org.apache.beam.fn.harness.data.BeamFnDataGrpcClientTest.3
            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> streamObserver) {
                return build;
            }
        }).build();
        build3.start();
        try {
            ManagedChannel build4 = InProcessChannelBuilder.forName(build2.getUrl()).build();
            CloseableFnDataReceiver send = new BeamFnDataGrpcClient(PipelineOptionsFactory.fromArgs(new String[]{"--experiments=beam_fn_api_data_buffer_limit=20"}).create(), apiServiceDescriptor -> {
                return build4;
            }, this::createStreamForTest).send(build2, ENDPOINT_A, CODER);
            Throwable th = null;
            try {
                try {
                    send.accept(WindowedValue.valueInGlobalWindow("ABC"));
                    send.accept(WindowedValue.valueInGlobalWindow("DEF"));
                    send.accept(WindowedValue.valueInGlobalWindow("GHI"));
                    if (send != null) {
                        if (0 != 0) {
                            try {
                                send.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            send.close();
                        }
                    }
                    countDownLatch.await();
                    Assert.assertThat(concurrentLinkedQueue, Matchers.contains(new BeamFnApi.Elements[]{ELEMENTS_A_1, ELEMENTS_A_2}));
                    build3.shutdownNow();
                } finally {
                }
            } finally {
            }
        } catch (Throwable th3) {
            build3.shutdownNow();
            throw th3;
        }
    }

    private <ReqT, RespT> StreamObserver<RespT> createStreamForTest(StreamObserverFactory.StreamObserverClientFactory<ReqT, RespT> streamObserverClientFactory, StreamObserver<ReqT> streamObserver) {
        return streamObserverClientFactory.outboundObserverFor(streamObserver);
    }

    static {
        try {
            ELEMENTS_A_1 = BeamFnApi.Elements.newBuilder().addData(BeamFnApi.Elements.Data.newBuilder().setInstructionReference(ENDPOINT_A.getInstructionId()).setTarget(ENDPOINT_A.getTarget()).setData(ByteString.copyFrom(CoderUtils.encodeToByteArray(CODER, WindowedValue.valueInGlobalWindow("ABC"))).concat(ByteString.copyFrom(CoderUtils.encodeToByteArray(CODER, WindowedValue.valueInGlobalWindow("DEF")))))).build();
            ELEMENTS_A_2 = BeamFnApi.Elements.newBuilder().addData(BeamFnApi.Elements.Data.newBuilder().setInstructionReference(ENDPOINT_A.getInstructionId()).setTarget(ENDPOINT_A.getTarget()).setData(ByteString.copyFrom(CoderUtils.encodeToByteArray(CODER, WindowedValue.valueInGlobalWindow("GHI"))))).addData(BeamFnApi.Elements.Data.newBuilder().setInstructionReference(ENDPOINT_A.getInstructionId()).setTarget(ENDPOINT_A.getTarget())).build();
            ELEMENTS_B_1 = BeamFnApi.Elements.newBuilder().addData(BeamFnApi.Elements.Data.newBuilder().setInstructionReference(ENDPOINT_B.getInstructionId()).setTarget(ENDPOINT_B.getTarget()).setData(ByteString.copyFrom(CoderUtils.encodeToByteArray(CODER, WindowedValue.valueInGlobalWindow("JKL"))).concat(ByteString.copyFrom(CoderUtils.encodeToByteArray(CODER, WindowedValue.valueInGlobalWindow("MNO")))))).addData(BeamFnApi.Elements.Data.newBuilder().setInstructionReference(ENDPOINT_B.getInstructionId()).setTarget(ENDPOINT_B.getTarget())).build();
        } catch (Exception e) {
            throw new ExceptionInInitializerError(e);
        }
    }
}
