/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.flight;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.flight.TestServerMiddleware;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestClientMiddleware {
    static Map<String, List<byte[]>> EXPECTED_BINARY_HEADERS = new HashMap<String, List<byte[]>>();
    static Map<String, List<String>> EXPECTED_TEXT_HEADERS = new HashMap<String, List<String>>();

    @Test
    public void clientMiddleware_failCallBeforeSending() {
        TestClientMiddleware.test((FlightProducer)new NoOpFlightProducer(), null, Collections.singletonList(new CallRejector.Factory()), (allocator, client) -> FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> client.listActions(new CallOption[0])));
    }

    @Test
    public void middleware_propagateHeader() {
        Context context = new Context("span id");
        TestClientMiddleware.test((FlightProducer)new NoOpFlightProducer(), new TestServerMiddleware.ServerMiddlewarePair<ServerSpanInjector>(FlightServerMiddleware.Key.of((String)"test"), new ServerSpanInjector.Factory()), Collections.singletonList(new ClientSpanInjector.Factory(context)), (allocator, client) -> FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions(new CallOption[0]).forEach(actionType -> {})));
        Assertions.assertEquals((Object)context.outgoingSpanId, (Object)context.incomingSpanId);
        Assertions.assertNotNull((Object)context.finalStatus);
        Assertions.assertEquals((Object)FlightStatusCode.UNIMPLEMENTED, (Object)context.finalStatus.code());
    }

    @Test
    public void testMultiValuedHeaders() {
        MultiHeaderClientMiddlewareFactory clientFactory = new MultiHeaderClientMiddlewareFactory();
        TestClientMiddleware.test((FlightProducer)new NoOpFlightProducer(), new TestServerMiddleware.ServerMiddlewarePair<MultiHeaderServerMiddleware>(FlightServerMiddleware.Key.of((String)"test"), new MultiHeaderServerMiddlewareFactory()), Collections.singletonList(clientFactory), (allocator, client) -> FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions(new CallOption[0]).forEach(actionType -> {})));
        for (Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) {
            List<byte[]> receivedValues = clientFactory.lastBinaryHeaders.get(entry.getKey());
            Assertions.assertNotNull(receivedValues, (String)("Missing for header: " + entry.getKey()));
            Assertions.assertEquals((int)entry.getValue().size(), (int)receivedValues.size(), (String)("Missing or wrong value for header: " + entry.getKey()));
            for (int i = 0; i < entry.getValue().size(); ++i) {
                Assertions.assertArrayEquals((byte[])entry.getValue().get(i), (byte[])receivedValues.get(i));
            }
        }
        for (Map.Entry<String, List<Object>> entry : EXPECTED_TEXT_HEADERS.entrySet()) {
            Assertions.assertEquals(entry.getValue(), clientFactory.lastTextHeaders.get(entry.getKey()), (String)("Missing or wrong value for header: " + entry.getKey()));
        }
    }

    private static <T extends FlightServerMiddleware> void test(FlightProducer producer, TestServerMiddleware.ServerMiddlewarePair<T> serverMiddleware, List<FlightClientMiddleware.Factory> clientMiddleware, BiConsumer<BufferAllocator, FlightClient> body) {
        try (RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);){
            FlightServer.Builder serverBuilder = FlightServer.builder((BufferAllocator)allocator, (Location)Location.forGrpcInsecure((String)"localhost", (int)0), (FlightProducer)producer);
            if (serverMiddleware != null) {
                serverBuilder.middleware(serverMiddleware.key, serverMiddleware.factory);
            }
            FlightServer server = serverBuilder.build().start();
            FlightClient.Builder clientBuilder = FlightClient.builder((BufferAllocator)allocator, (Location)server.getLocation());
            clientMiddleware.forEach(arg_0 -> ((FlightClient.Builder)clientBuilder).intercept(arg_0));
            try (FlightServer ignored = server;
                 FlightClient client = clientBuilder.build();){
                body.accept((BufferAllocator)allocator, client);
            }
        }
        catch (IOException | InterruptedException e) {
            throw new RuntimeException(e);
        }
    }

    static {
        EXPECTED_BINARY_HEADERS.put("x-binary-bin", Arrays.asList({0}, {1}));
        EXPECTED_TEXT_HEADERS.put("x-text", Arrays.asList("foo", "bar"));
    }

    static class CallRejector
    implements FlightClientMiddleware {
        CallRejector() {
        }

        public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
        }

        public void onHeadersReceived(CallHeaders incomingHeaders) {
        }

        public void onCallCompleted(CallStatus status) {
        }

        static class Factory
        implements FlightClientMiddleware.Factory {
            Factory() {
            }

            public FlightClientMiddleware onCallStarted(CallInfo info) {
                throw CallStatus.UNAVAILABLE.withDescription("Rejecting call.").toRuntimeException();
            }
        }
    }

    static class Context {
        final String outgoingSpanId;
        String incomingSpanId;
        CallStatus finalStatus;

        Context(String spanId) {
            this.outgoingSpanId = spanId;
        }
    }

    static class ServerSpanInjector
    implements FlightServerMiddleware {
        private final String spanId;

        public ServerSpanInjector(String spanId) {
            this.spanId = spanId;
        }

        public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
            outgoingHeaders.insert("x-span", this.spanId);
        }

        public void onCallCompleted(CallStatus status) {
        }

        public void onCallErrored(Throwable err) {
        }

        static class Factory
        implements FlightServerMiddleware.Factory<ServerSpanInjector> {
            Factory() {
            }

            public ServerSpanInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
                return new ServerSpanInjector(incomingHeaders.get("x-span"));
            }
        }
    }

    static class ClientSpanInjector
    implements FlightClientMiddleware {
        private final Context context;

        public ClientSpanInjector(Context context) {
            this.context = context;
        }

        public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
            outgoingHeaders.insert("x-span", this.context.outgoingSpanId);
        }

        public void onHeadersReceived(CallHeaders incomingHeaders) {
            this.context.incomingSpanId = incomingHeaders.get("x-span");
        }

        public void onCallCompleted(CallStatus status) {
            this.context.finalStatus = status;
        }

        static class Factory
        implements FlightClientMiddleware.Factory {
            private final Context context;

            Factory(Context context) {
                this.context = context;
            }

            public FlightClientMiddleware onCallStarted(CallInfo info) {
                return new ClientSpanInjector(this.context);
            }
        }
    }

    static class MultiHeaderClientMiddlewareFactory
    implements FlightClientMiddleware.Factory {
        Map<String, List<byte[]>> lastBinaryHeaders = null;
        Map<String, List<String>> lastTextHeaders = null;

        MultiHeaderClientMiddlewareFactory() {
        }

        public FlightClientMiddleware onCallStarted(CallInfo info) {
            return new MultiHeaderClientMiddleware(this);
        }
    }

    static class MultiHeaderServerMiddlewareFactory
    implements FlightServerMiddleware.Factory<MultiHeaderServerMiddleware> {
        MultiHeaderServerMiddlewareFactory() {
        }

        public MultiHeaderServerMiddleware onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
            HashMap<String, List<byte[]>> binaryHeaders = new HashMap<String, List<byte[]>>();
            HashMap<String, List<String>> textHeaders = new HashMap<String, List<String>>();
            for (String key : incomingHeaders.keys()) {
                if (key.endsWith("-bin")) {
                    binaryHeaders.compute(key, (ignored, values) -> {
                        if (values == null) {
                            values = new ArrayList();
                        }
                        incomingHeaders.getAllByte(key).forEach(values::add);
                        return values;
                    });
                    continue;
                }
                textHeaders.compute(key, (ignored, values) -> {
                    if (values == null) {
                        values = new ArrayList();
                    }
                    incomingHeaders.getAll(key).forEach(values::add);
                    return values;
                });
            }
            return new MultiHeaderServerMiddleware(binaryHeaders, textHeaders);
        }
    }

    static class MultiHeaderClientMiddleware
    implements FlightClientMiddleware {
        private final MultiHeaderClientMiddlewareFactory factory;

        public MultiHeaderClientMiddleware(MultiHeaderClientMiddlewareFactory factory) {
            this.factory = factory;
        }

        public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
            for (Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) {
                entry.getValue().forEach(value -> outgoingHeaders.insert((String)entry.getKey(), value));
                Assertions.assertTrue((boolean)outgoingHeaders.containsKey(entry.getKey()));
            }
            for (Map.Entry<String, List<Object>> entry : EXPECTED_TEXT_HEADERS.entrySet()) {
                entry.getValue().forEach(value -> outgoingHeaders.insert((String)entry.getKey(), value));
                Assertions.assertTrue((boolean)outgoingHeaders.containsKey(entry.getKey()));
            }
        }

        public void onHeadersReceived(CallHeaders incomingHeaders) {
            this.factory.lastBinaryHeaders = new HashMap<String, List<byte[]>>();
            this.factory.lastTextHeaders = new HashMap<String, List<String>>();
            incomingHeaders.keys().forEach(header -> {
                if (header.endsWith("-bin")) {
                    ArrayList values = new ArrayList();
                    incomingHeaders.getAllByte(header).forEach(values::add);
                    this.factory.lastBinaryHeaders.put((String)header, values);
                } else {
                    ArrayList values = new ArrayList();
                    incomingHeaders.getAll(header).forEach(values::add);
                    this.factory.lastTextHeaders.put((String)header, values);
                }
            });
        }

        public void onCallCompleted(CallStatus status) {
        }
    }

    static class MultiHeaderServerMiddleware
    implements FlightServerMiddleware {
        private final Map<String, List<byte[]>> binaryHeaders;
        private final Map<String, List<String>> textHeaders;

        MultiHeaderServerMiddleware(Map<String, List<byte[]>> binaryHeaders, Map<String, List<String>> textHeaders) {
            this.binaryHeaders = binaryHeaders;
            this.textHeaders = textHeaders;
        }

        public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
            this.binaryHeaders.forEach((key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value)));
            this.textHeaders.forEach((key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value)));
        }

        public void onCallCompleted(CallStatus status) {
        }

        public void onCallErrored(Throwable err) {
        }
    }
}

