/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.flight;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.SyncPutListener;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests that a {@link FlightServerMiddleware} observes the final {@link CallStatus} of a call, and
 * any exception a producer throws after it has already completed the call.
 */
public class TestServerMiddleware {
    /** Upper bound, in seconds, that a verifier waits for a middleware callback before failing. */
    private static final long CALLBACK_TIMEOUT_SECONDS = 10;

    /** A plain RuntimeException from acceptPut is reported as INTERNAL, with the cause kept for middleware. */
    @Test
    public void doPutErrors() {
        test(new ErrorProducer(new RuntimeException("test")), (allocator, client) -> {
            FlightDescriptor descriptor = FlightDescriptor.path("test");
            try (VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
                FlightClient.ClientStreamListener listener =
                        client.startPut(descriptor, root, new SyncPutListener());
                listener.completed();
                FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, listener::getResult);
            }
        }, recorder -> {
            CallStatus status = awaitCallback(recorder.statusFuture);
            Assertions.assertNotNull(status);
            Assertions.assertNotNull(status.cause());
            Assertions.assertEquals(FlightStatusCode.INTERNAL, status.code());
        });
    }

    /** An exception built from a CallStatus keeps its UNAVAILABLE code and description; middleware sees no cause. */
    @Test
    public void doPutCustomCode() {
        RuntimeException unavailable = CallStatus.UNAVAILABLE.withDescription("description").toRuntimeException();
        test(new ErrorProducer(unavailable), (allocator, client) -> {
            FlightDescriptor descriptor = FlightDescriptor.path("test");
            try (VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
                FlightClient.ClientStreamListener listener =
                        client.startPut(descriptor, root, new SyncPutListener());
                listener.completed();
                FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, listener::getResult);
            }
        }, recorder -> {
            CallStatus status = awaitCallback(recorder.statusFuture);
            Assertions.assertNotNull(status);
            Assertions.assertNull(status.cause());
            Assertions.assertEquals(FlightStatusCode.UNAVAILABLE, status.code());
            Assertions.assertEquals("description", status.description());
        });
    }

    /** acceptPut completes the ack stream and then throws; the throw must not change the OK status. */
    @Test
    public void doPutUncaught() {
        test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> {
            FlightDescriptor descriptor = FlightDescriptor.path("test");
            try (VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
                FlightClient.ClientStreamListener listener =
                        client.startPut(descriptor, root, new SyncPutListener());
                listener.completed();
                listener.getResult();
            }
        }, TestServerMiddleware::assertUncaughtError);
    }

    /** listFlights completes its listener and then throws. */
    @Test
    public void listFlightsUncaught() {
        test(new ServerErrorProducer(new RuntimeException("test")),
                (allocator, client) -> client.listFlights(new Criteria(new byte[0])).forEach(info -> {}),
                TestServerMiddleware::assertUncaughtError);
    }

    /** doAction completes its listener and then throws. */
    @Test
    public void doActionUncaught() {
        test(new ServerErrorProducer(new RuntimeException("test")),
                (allocator, client) -> client.doAction(new Action("test")).forEachRemaining(result -> {}),
                TestServerMiddleware::assertUncaughtError);
    }

    /** listActions completes its listener and then throws. */
    @Test
    public void listActionsUncaught() {
        test(new ServerErrorProducer(new RuntimeException("test")),
                (allocator, client) -> client.listActions().forEach(actionType -> {}),
                TestServerMiddleware::assertUncaughtError);
    }

    /** getFlightInfo throws without producing a result: the client sees INTERNAL and middleware gets the cause. */
    @Test
    public void getFlightInfoUncaught() {
        test(new ServerErrorProducer(new RuntimeException("test")),
                (allocator, client) -> FlightTestUtil.assertCode(
                        FlightStatusCode.INTERNAL, () -> client.getInfo(FlightDescriptor.path("test"))),
                recorder -> {
                    CallStatus status = awaitCallback(recorder.statusFuture);
                    Assertions.assertNotNull(status);
                    Assertions.assertEquals(FlightStatusCode.INTERNAL, status.code());
                    Assertions.assertNotNull(status.cause());
                    Assertions.assertEquals("test", status.cause().getMessage());
                });
    }

    /** getStream starts and completes an empty stream and then throws. */
    @Test
    public void doGetUncaught() {
        test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> {
            try (FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
                while (stream.next()) {
                    // Drain the stream; the batches themselves are not inspected.
                }
            } catch (Exception e) {
                // Keep the original exception as the cause so its stack trace shows up in the failure.
                Assertions.fail(e.toString(), e);
            }
        }, TestServerMiddleware::assertUncaughtError);
    }

    /**
     * Shared verifier for the "uncaught" tests.
     *
     * <p>It checks that the status passed to onCallCompleted is OK with no cause, and that the
     * producer's exception (message "test") was passed to onCallErrored.
     */
    private static void assertUncaughtError(ErrorRecorder recorder) throws Exception {
        CallStatus status = awaitCallback(recorder.statusFuture);
        Throwable err = awaitCallback(recorder.errFuture);
        Assertions.assertNotNull(status);
        Assertions.assertEquals(FlightStatusCode.OK, status.code());
        Assertions.assertNull(status.cause());
        Assertions.assertNotNull(err);
        Assertions.assertEquals("test", err.getMessage());
    }

    /**
     * Waits a bounded time for a middleware callback.
     *
     * <p>A callback that never fires fails the test with a TimeoutException instead of hanging it.
     */
    private static <V> V awaitCallback(CompletableFuture<V> future) throws Exception {
        return future.get(CALLBACK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Starts a server on an ephemeral localhost port and connects a client to it.
     *
     * <p>The given producer and middleware are installed on the server. {@code body} runs with the
     * shared allocator and the client; the client, server and allocator are all closed afterwards.
     *
     * @param producer   producer backing the server
     * @param middleware middleware (key, factory) pairs registered on the server builder
     * @param body       test logic to run against the connected client
     * @throws RuntimeException wrapping any IOException or InterruptedException from server start/shutdown
     */
    static <T extends FlightServerMiddleware> void test(FlightProducer producer, List<ServerMiddlewarePair<T>> middleware,
            BiConsumer<BufferAllocator, FlightClient> body) {
        try (BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
            FlightServer.Builder builder =
                    FlightServer.builder(allocator, Location.forGrpcInsecure("localhost", 0), producer);
            middleware.forEach(pair -> builder.middleware(pair.key, pair.factory));
            // The server is a resource before start() is called, so it is closed even if start() fails.
            try (FlightServer server = builder.build()) {
                server.start();
                try (FlightClient client = FlightClient.builder(allocator, server.getLocation()).build()) {
                    body.accept(allocator, client);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs {@code body} against a server with a single {@link ErrorRecorder} middleware installed.
     *
     * <p>Afterwards, while the client and server are still open, {@code verify} is called with that
     * recorder.
     *
     * @throws RuntimeException wrapping any checked exception thrown by {@code verify}
     */
    static void test(FlightProducer producer, BiConsumer<BufferAllocator, FlightClient> body,
            ErrorConsumer<ErrorRecorder> verify) {
        ErrorRecorder.Factory factory = new ErrorRecorder.Factory();
        List<ServerMiddlewarePair<ErrorRecorder>> middleware = Collections.singletonList(
                new ServerMiddlewarePair<ErrorRecorder>(FlightServerMiddleware.Key.of("m"), factory));
        test(producer, middleware, (allocator, client) -> {
            body.accept(allocator, client);
            try {
                verify.accept(factory.instance);
            } catch (Exception e) {
                if (e instanceof InterruptedException) {
                    // Preserve the interrupt when the verifier was interrupted while waiting.
                    Thread.currentThread().interrupt();
                }
                throw new RuntimeException(e);
            }
        });
    }

    /** Producer whose acceptPut drains every uploaded batch and then throws the configured error. */
    static class ErrorProducer extends NoOpFlightProducer {
        final RuntimeException error;

        ErrorProducer(RuntimeException error) {
            this.error = error;
        }

        @Override
        public Runnable acceptPut(FlightProducer.CallContext context, FlightStream flightStream,
                FlightProducer.StreamListener<PutResult> ackStream) {
            return () -> {
                while (flightStream.next()) {
                    // Consume the whole upload before failing.
                }
                throw error;
            };
        }
    }

    /** Like {@link java.util.function.Consumer}, but {@code accept} may throw a checked exception. */
    @FunctionalInterface
    interface ErrorConsumer<T> {
        void accept(T value) throws Exception;
    }

    /**
     * Producer that throws the configured error from every RPC.
     *
     * <p>Wherever the RPC has a listener, the listener is completed first, so the throw happens after
     * the call has already finished normally. getFlightInfo has no listener and simply throws.
     */
    static class ServerErrorProducer extends NoOpFlightProducer {
        final RuntimeException error;

        ServerErrorProducer(RuntimeException error) {
            this.error = error;
        }

        @Override
        public void getStream(FlightProducer.CallContext context, Ticket ticket,
                FlightProducer.ServerStreamListener listener) {
            try (RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
                    VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
                listener.start(root);
                listener.completed();
            }
            throw error;
        }

        @Override
        public void listFlights(FlightProducer.CallContext context, Criteria criteria,
                FlightProducer.StreamListener<FlightInfo> listener) {
            listener.onCompleted();
            throw error;
        }

        @Override
        public FlightInfo getFlightInfo(FlightProducer.CallContext context, FlightDescriptor descriptor) {
            throw error;
        }

        @Override
        public Runnable acceptPut(FlightProducer.CallContext context, FlightStream flightStream,
                FlightProducer.StreamListener<PutResult> ackStream) {
            return () -> {
                while (flightStream.next()) {
                    // Consume the whole upload before acknowledging.
                }
                ackStream.onCompleted();
                throw error;
            };
        }

        @Override
        public void doAction(FlightProducer.CallContext context, Action action,
                FlightProducer.StreamListener<Result> listener) {
            listener.onCompleted();
            throw error;
        }

        @Override
        public void listActions(FlightProducer.CallContext context,
                FlightProducer.StreamListener<ActionType> listener) {
            listener.onCompleted();
            throw error;
        }
    }

    /**
     * Middleware that records what the server reports about a call.
     *
     * <p>The status given to onCallCompleted goes into {@link #statusFuture}; the throwable given to
     * onCallErrored goes into {@link #errFuture}.
     */
    static class ErrorRecorder implements FlightServerMiddleware {
        final CompletableFuture<CallStatus> statusFuture = new CompletableFuture<>();
        final CompletableFuture<Throwable> errFuture = new CompletableFuture<>();

        @Override
        public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
            // Outgoing headers are not inspected by these tests.
        }

        @Override
        public void onCallCompleted(CallStatus status) {
            statusFuture.complete(status);
        }

        @Override
        public void onCallErrored(Throwable err) {
            errFuture.complete(err);
        }

        /**
         * Factory that returns the same recorder for every call.
         *
         * <p>This lets the test read {@link #instance} after the call has finished.
         */
        static class Factory implements FlightServerMiddleware.Factory<ErrorRecorder> {
            final ErrorRecorder instance = new ErrorRecorder();

            @Override
            public ErrorRecorder onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
                return instance;
            }
        }
    }

    /** A middleware key paired with the factory to register under it on a server builder. */
    static class ServerMiddlewarePair<T extends FlightServerMiddleware> {
        final FlightServerMiddleware.Key<T> key;
        final FlightServerMiddleware.Factory<T> factory;

        ServerMiddlewarePair(FlightServerMiddleware.Key<T> key, FlightServerMiddleware.Factory<T> factory) {
            this.key = key;
            this.factory = factory;
        }
    }
}

