/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.flight;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.function.Consumer;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Result;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestTls {
    @Test
    public void connectTls() {
        this.test(builder -> {
            try (FileInputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
                 FlightClient client = builder.trustedCertificates((InputStream)roots).build();){
                Iterator responses = client.doAction(new Action("hello-world"), new CallOption[0]);
                byte[] response = ((Result)responses.next()).getBody();
                Assertions.assertEquals((Object)"Hello, world!", (Object)new String(response, StandardCharsets.UTF_8));
                Assertions.assertFalse((boolean)responses.hasNext());
            }
            catch (IOException | InterruptedException e) {
                throw new RuntimeException(e);
            }
        });
    }

    @Test
    public void rejectInvalidCert() {
        this.test(builder -> {
            try (FlightClient client = builder.build();){
                Iterator responses = client.doAction(new Action("hello-world"), new CallOption[0]);
                FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> ((Result)responses.next()).getBody());
            }
            catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        });
    }

    @Test
    public void rejectHostname() {
        this.test(builder -> {
            try (FileInputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
                 FlightClient client = builder.trustedCertificates((InputStream)roots).overrideHostname("fakehostname").build();){
                Iterator responses = client.doAction(new Action("hello-world"), new CallOption[0]);
                FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> ((Result)responses.next()).getBody());
            }
            catch (IOException | InterruptedException e) {
                throw new RuntimeException(e);
            }
        });
    }

    @Test
    public void connectTlsDisableServerVerification() {
        this.test(builder -> {
            try (FlightClient client = builder.verifyServer(false).build();){
                Iterator responses = client.doAction(new Action("hello-world"), new CallOption[0]);
                byte[] response = ((Result)responses.next()).getBody();
                Assertions.assertEquals((Object)"Hello, world!", (Object)new String(response, StandardCharsets.UTF_8));
                Assertions.assertFalse((boolean)responses.hasNext());
            }
            catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        });
    }

    void test(Consumer<FlightClient.Builder> testFn) {
        FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0);
        try (RootAllocator a = new RootAllocator(Long.MAX_VALUE);
             Producer producer = new Producer();
             FlightServer s = FlightServer.builder((BufferAllocator)a, (Location)Location.forGrpcInsecure((String)"localhost", (int)0), (FlightProducer)producer).useTls(certKey.cert, certKey.key).build().start();){
            FlightClient.Builder builder = FlightClient.builder((BufferAllocator)a, (Location)Location.forGrpcTls((String)"localhost", (int)s.getPort()));
            testFn.accept(builder);
        }
        catch (IOException | InterruptedException e) {
            throw new RuntimeException(e);
        }
    }

    static class Producer
    extends NoOpFlightProducer
    implements AutoCloseable {
        Producer() {
        }

        public void doAction(FlightProducer.CallContext context, Action action, FlightProducer.StreamListener<Result> listener) {
            if (action.getType().equals("hello-world")) {
                listener.onNext((Object)new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
                listener.onCompleted();
                return;
            }
            listener.onError((Throwable)CallStatus.UNIMPLEMENTED.withDescription("Invalid action " + action.getType()).toRuntimeException());
        }

        @Override
        public void close() {
        }
    }
}

