/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.driver.jdbc;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayDeque;
import java.util.Properties;
import org.apache.arrow.driver.jdbc.ArrowFlightJdbcConnectionPoolDataSource;
import org.apache.arrow.driver.jdbc.ArrowFlightJdbcDataSource;
import org.apache.arrow.driver.jdbc.authentication.Authentication;
import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl;
import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JUnit 4 {@link TestRule} that runs a {@link FlightServer} backed by a {@link FlightSqlProducer}
 * for the duration of the wrapped statement, and offers helpers to open JDBC connections and
 * data sources against that server.
 *
 * <p>The server is bound to an ephemeral port on {@code localhost}. The chosen port is written to
 * the rule's connection properties before the wrapped statement runs.
 *
 * <p>The rule owns a {@link RootAllocator}. That allocator, together with any remaining child
 * allocators, is closed when the wrapped statement completes and is not recreated. An instance is
 * therefore meant to wrap a single statement.
 */
public class FlightServerTestRule implements TestRule, AutoCloseable {
    /** User name registered by {@link #createStandardTestRule(FlightSqlProducer)}. */
    public static final String DEFAULT_USER = "flight-test-user";
    /** Password registered by {@link #createStandardTestRule(FlightSqlProducer)}. */
    public static final String DEFAULT_PASSWORD = "flight-test-password";
    private static final Logger LOGGER = LoggerFactory.getLogger(FlightServerTestRule.class);
    /** How many times {@link #apply} tries to start the server before giving up. */
    private static final int START_ATTEMPTS = 3;

    // Connection properties handed to the JDBC data sources. The rule mutates them at runtime
    // ("port", "useEncryption", "token").
    private final Properties properties;
    // NOTE(review): getPort()/getHost() delegate here; this assumes the config reads `properties`
    // lazily so the port set in apply() is visible -- confirm.
    private final ArrowFlightConnectionConfigImpl config;
    private final BufferAllocator allocator;
    private final FlightSqlProducer producer;
    private final Authentication authentication;
    // Optional server TLS certificate/key; null means no TLS.
    private final FlightSqlTestCertificates.CertKeyPair certKeyPair;
    // Optional CA used to verify client certificates (mTLS); null disables client verification.
    private final File mTlsCACert;
    private final MiddlewareCookie.Factory middlewareCookieFactory = new MiddlewareCookie.Factory();

    private FlightServerTestRule(Properties properties, ArrowFlightConnectionConfigImpl config,
            BufferAllocator allocator, FlightSqlProducer producer, Authentication authentication,
            FlightSqlTestCertificates.CertKeyPair certKeyPair, File mTlsCACert) {
        this.properties = Preconditions.checkNotNull(properties);
        this.config = Preconditions.checkNotNull(config);
        this.allocator = Preconditions.checkNotNull(allocator);
        this.producer = Preconditions.checkNotNull(producer, "producer must be set");
        this.authentication = authentication;
        this.certKeyPair = certKeyPair;
        this.mTlsCACert = mTlsCACert;
    }

    /**
     * Creates a rule for {@code producer} that authenticates with {@link #DEFAULT_USER} /
     * {@link #DEFAULT_PASSWORD}, without TLS.
     *
     * @param producer the Flight SQL producer the server delegates to
     * @return a new rule
     */
    public static FlightServerTestRule createStandardTestRule(FlightSqlProducer producer) {
        UserPasswordAuthentication authentication = new UserPasswordAuthentication.Builder()
                .user(DEFAULT_USER, DEFAULT_PASSWORD)
                .build();
        return new Builder().authentication(authentication).producer(producer).build();
    }

    /** Creates a data source over the rule's current connection properties. */
    ArrowFlightJdbcDataSource createDataSource() {
        return ArrowFlightJdbcDataSource.createNewDataSource(this.properties);
    }

    /** Creates a pooled data source over the rule's current connection properties. */
    public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource() {
        return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(this.properties);
    }

    /**
     * Sets the {@code useEncryption} property, then creates a pooled data source.
     *
     * @param useEncryption value stored under the {@code useEncryption} property
     * @return a new pooled data source
     */
    public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource(boolean useEncryption) {
        this.setUseEncryption(useEncryption);
        return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(this.properties);
    }

    /**
     * Stores {@code token} under the {@code token} property and opens a connection.
     *
     * <p>The token remains in the properties for later connections. {@link Properties} rejects
     * {@code null}, so {@code token} must be non-null.
     *
     * @param useEncryption value stored under the {@code useEncryption} property
     * @param token value stored under the {@code token} property
     * @return a new connection
     * @throws SQLException if the connection cannot be opened
     */
    public Connection getConnection(boolean useEncryption, String token) throws SQLException {
        this.properties.put("token", token);
        return this.getConnection(useEncryption);
    }

    /**
     * Sets the {@code useEncryption} property and opens a connection via a fresh data source.
     *
     * @param useEncryption value stored under the {@code useEncryption} property
     * @return a new connection
     * @throws SQLException if the connection cannot be opened
     */
    public Connection getConnection(boolean useEncryption) throws SQLException {
        this.setUseEncryption(useEncryption);
        return this.createDataSource().getConnection();
    }

    private void setUseEncryption(boolean useEncryption) {
        this.properties.put("useEncryption", useEncryption);
    }

    /** Returns the factory of the cookie middleware installed on the server. */
    public MiddlewareCookie.Factory getMiddlewareCookieFactory() {
        return this.middlewareCookieFactory;
    }

    /**
     * Builds (but does not start) a server at {@code location}.
     *
     * <p>The server uses the rule's header authenticator and cookie middleware. TLS and mTLS are
     * enabled only when a certificate/key pair or a CA certificate was configured.
     */
    private FlightServer initiateServer(Location location) throws IOException {
        FlightServer.Builder builder = FlightServer.builder(this.allocator, location, this.producer)
                .headerAuthenticator(this.authentication.authenticate())
                .middleware(FlightServerMiddleware.Key.of("KEY"), this.middlewareCookieFactory);
        if (this.certKeyPair != null) {
            builder.useTls(this.certKeyPair.cert, this.certKeyPair.key);
        }
        if (this.mTlsCACert != null) {
            builder.useMTlsClientVerification(this.mTlsCACert);
        }
        return builder.build();
    }

    /**
     * Wraps {@code base} so that it runs with a started server.
     *
     * <p>The wrapped statement:
     * <ol>
     *   <li>starts the server (up to {@link #START_ATTEMPTS} attempts);</li>
     *   <li>publishes the server's port under the {@code port} property;</li>
     *   <li>evaluates {@code base};</li>
     *   <li>closes the server and then this rule.</li>
     * </ol>
     * If {@code base} fails and closing the rule also fails, the close failure is attached as a
     * suppressed exception so the original test failure is what gets reported.
     */
    @Override
    public Statement apply(final Statement base, Description description) {
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                Throwable primaryFailure = null;
                try (FlightServer flightServer =
                        FlightServerTestRule.this.getStartServer(FlightServerTestRule.this::initiateServer, START_ATTEMPTS)) {
                    FlightServerTestRule.this.properties.put("port", flightServer.getPort());
                    LOGGER.info("Started {} as {}", FlightServer.class.getName(), flightServer);
                    base.evaluate();
                } catch (Throwable t) {
                    primaryFailure = t;
                    throw t;
                } finally {
                    FlightServerTestRule.this.closeAfterStatement(primaryFailure);
                }
            }
        };
    }

    /**
     * Closes this rule once the wrapped statement is done.
     *
     * <p>A close failure is rethrown only when there is no earlier failure to report. Otherwise it
     * is recorded as suppressed on {@code primaryFailure}.
     */
    private void closeAfterStatement(Throwable primaryFailure) throws Exception {
        try {
            this.close();
        } catch (Exception closeFailure) {
            if (primaryFailure == null) {
                throw closeFailure;
            }
            primaryFailure.addSuppressed(closeFailure);
        }
    }

    /**
     * Builds and starts a server on an ephemeral {@code localhost} port, retrying on
     * {@link IOException}.
     *
     * <p>Each server that fails to start is closed before the next attempt.
     *
     * @param newServerFromLocation builds an unstarted server for a location
     * @param retries total number of start attempts; must be positive
     * @return a started server
     * @throws IOException if every attempt fails; the first failure is the cause and later ones
     *     are suppressed
     */
    private FlightServer getStartServer(CheckedFunction<Location, FlightServer> newServerFromLocation, int retries)
            throws IOException {
        Preconditions.checkArgument(retries > 0, "retries must be positive");
        ArrayDeque<IOException> failures = new ArrayDeque<>();
        for (int attempt = 0; attempt < retries; attempt++) {
            FlightServer server = newServerFromLocation.apply(Location.forGrpcInsecure("localhost", 0));
            try {
                server.start();
                return server;
            } catch (IOException e) {
                failures.add(e);
                closeFailedServer(server, e);
            } catch (RuntimeException e) {
                // Not retried, but the half-built server must still be released.
                closeFailedServer(server, e);
                throw e;
            }
        }
        failures.forEach(e -> LOGGER.error("Failed to start FlightServer", e));
        IOException failure = new IOException(
                "Failed to start FlightServer after " + retries + " attempt(s)", failures.pop());
        failures.forEach(failure::addSuppressed);
        throw failure;
    }

    /**
     * Best-effort close of a server that failed to start.
     *
     * <p>Any close error is attached as suppressed to {@code failure}. The interrupt flag is
     * restored if closing was interrupted.
     */
    private static void closeFailedServer(FlightServer server, Exception failure) {
        try {
            AutoCloseables.close(server);
        } catch (Exception closeFailure) {
            if (closeFailure instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            failure.addSuppressed(closeFailure);
        }
    }

    /** Returns the port reported by the rule's connection config. */
    public int getPort() {
        return this.config.getPort();
    }

    /** Returns the host reported by the rule's connection config. */
    public String getHost() {
        return this.config.getHost();
    }

    /** Closes every child allocator still open, then the rule's root allocator. */
    @Override
    public void close() throws Exception {
        this.allocator.getChildAllocators().forEach(BufferAllocator::close);
        AutoCloseables.close(this.allocator);
    }

    /**
     * Server middleware that adds {@code Set-Cookie: k=v} to outgoing headers unless the most
     * recently started call sent a {@code Cookie} header.
     */
    static class MiddlewareCookie implements FlightServerMiddleware {
        private final Factory factory;

        public MiddlewareCookie(Factory factory) {
            this.factory = factory;
        }

        @Override
        public void onBeforeSendingHeaders(CallHeaders callHeaders) {
            if (!this.factory.receivedCookieHeader) {
                callHeaders.insert("Set-Cookie", "k=v");
            }
        }

        @Override
        public void onCallCompleted(CallStatus callStatus) {
        }

        @Override
        public void onCallErrored(Throwable throwable) {
        }

        /**
         * Creates {@link MiddlewareCookie} instances and records the {@code Cookie} header of the
         * most recently started call.
         *
         * <p>That state is shared by all calls: the last call to start wins.
         */
        static class Factory implements FlightServerMiddleware.Factory<MiddlewareCookie> {
            // volatile: written when a call starts and read by tests via getCookie(). These are
            // assumed to run on different threads (server vs. test thread).
            private volatile boolean receivedCookieHeader = false;
            private volatile String cookie;

            Factory() {
            }

            @Override
            public MiddlewareCookie onCallStarted(CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
                this.cookie = callHeaders.get("Cookie");
                this.receivedCookieHeader = null != this.cookie;
                return new MiddlewareCookie(this);
            }

            /** Returns the {@code Cookie} header of the most recent call, or {@code null} if it had none. */
            public String getCookie() {
                return this.cookie;
            }
        }
    }

    /** Builder for {@link FlightServerTestRule}; {@code producer} and {@code authentication} are required. */
    public static final class Builder {
        private final Properties properties = new Properties();
        private FlightSqlProducer producer;
        private Authentication authentication;
        private FlightSqlTestCertificates.CertKeyPair certKeyPair;
        private File mTlsCACert;

        public Builder() {
            this.properties.put("host", "localhost");
        }

        public Builder producer(FlightSqlProducer producer) {
            this.producer = producer;
            return this;
        }

        public Builder authentication(Authentication authentication) {
            this.authentication = authentication;
            return this;
        }

        /** Enables server TLS with the given certificate chain and private key. */
        public Builder useEncryption(File certChain, File key) {
            this.certKeyPair = new FlightSqlTestCertificates.CertKeyPair(certChain, key);
            return this;
        }

        /** Enables client-certificate verification against the given CA certificate. */
        public Builder useMTlsClientVerification(File mTlsCACert) {
            this.mTlsCACert = mTlsCACert;
            return this;
        }

        /**
         * Builds the rule.
         *
         * <p>Required settings are validated before the allocator is created, so a
         * misconfiguration does not leak it. Each rule gets its own copy of the properties, so
         * rules built from the same builder do not share state.
         *
         * @return a new rule
         * @throws NullPointerException if {@code authentication} or {@code producer} is unset
         */
        public FlightServerTestRule build() {
            Preconditions.checkNotNull(this.authentication, "authentication must be set");
            Preconditions.checkNotNull(this.producer, "producer must be set");
            Properties ruleProperties = new Properties();
            ruleProperties.putAll(this.properties);
            this.authentication.populateProperties(ruleProperties);
            return new FlightServerTestRule(ruleProperties, new ArrowFlightConnectionConfigImpl(ruleProperties),
                    new RootAllocator(Long.MAX_VALUE), this.producer, this.authentication,
                    this.certKeyPair, this.mTlsCACert);
        }
    }

    /** A {@code Function} whose {@code apply} may throw {@link IOException}. */
    @FunctionalInterface
    public static interface CheckedFunction<T, R> {
        public R apply(T var1) throws IOException;
    }
}

