package org.apache.ratis.netty;

import java.io.Closeable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.ratis.security.SecurityTestUtils;
import org.apache.ratis.security.TlsConf;
import org.apache.ratis.thirdparty.io.netty.bootstrap.Bootstrap;
import org.apache.ratis.thirdparty.io.netty.bootstrap.ServerBootstrap;
import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.ratis.thirdparty.io.netty.buffer.Unpooled;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelFuture;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandler;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelInboundHandler;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelInitializer;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.ratis.thirdparty.io.netty.channel.EventLoopGroup;
import org.apache.ratis.thirdparty.io.netty.channel.nio.NioEventLoopGroup;
import org.apache.ratis.thirdparty.io.netty.channel.socket.SocketChannel;
import org.apache.ratis.thirdparty.io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.ratis.thirdparty.io.netty.handler.logging.LogLevel;
import org.apache.ratis.thirdparty.io.netty.handler.logging.LoggingHandler;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
import org.apache.ratis.util.JavaUtils;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/ratis/netty/TestTlsConfWithNetty.class */
public class TestTlsConfWithNetty {
    /** Logger of the outer test class; used by {@code randomPort} and {@code runTest}. */
    private static final Logger LOG = LoggerFactory.getLogger(TestTlsConfWithNetty.class);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/ratis/netty/TestTlsConfWithNetty$NettyTestClient.class */
    public static class NettyTestClient implements Closeable {
        private static final Logger LOG = LoggerFactory.getLogger(NettyTestClient.class);
        private final ChannelFuture channelFuture;
        private final EventLoopGroup workerGroup = new NioEventLoopGroup(3);
        private final Queue<CompletableFuture<String>> queue = new LinkedList();

        public NettyTestClient(String str, int i, SslContext sslContext) {
            this.channelFuture = new Bootstrap().group(this.workerGroup).channel(NioSocketChannel.class).handler(new LoggingHandler(getClass(), LogLevel.INFO)).handler(newChannelInitializer(sslContext, str, i)).option(ChannelOption.SO_KEEPALIVE, true).option(ChannelOption.TCP_NODELAY, true).connect(str, i).syncUninterruptibly();
        }

        public CompletableFuture<String> writeAndFlush(ByteBuf byteBuf) {
            CompletableFuture<String> completableFuture = new CompletableFuture<>();
            this.queue.offer(completableFuture);
            this.channelFuture.channel().writeAndFlush(byteBuf);
            return completableFuture;
        }

        private ChannelInitializer<SocketChannel> newChannelInitializer(final SslContext sslContext, final String str, final int i) {
            return new ChannelInitializer<SocketChannel>() { // from class: org.apache.ratis.netty.TestTlsConfWithNetty.NettyTestClient.1
                public void initChannel(SocketChannel socketChannel) {
                    ChannelPipeline pipeline = socketChannel.pipeline();
                    if (sslContext != null) {
                        pipeline.addLast("ssl", sslContext.newHandler(socketChannel.alloc(), str, i));
                    }
                    pipeline.addLast(new ChannelHandler[]{NettyTestClient.this.getClientHandler()});
                }
            };
        }

        /* JADX INFO: Access modifiers changed from: private */
        public ChannelInboundHandler getClientHandler() {
            return new ChannelInboundHandlerAdapter() { // from class: org.apache.ratis.netty.TestTlsConfWithNetty.NettyTestClient.2
                public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) {
                    String buffer2String = TestTlsConfWithNetty.buffer2String((ByteBuf) obj);
                    NettyTestClient.LOG.info("received: " + buffer2String);
                    for (String str : buffer2String.split(" ")) {
                        ((CompletableFuture) NettyTestClient.this.queue.remove()).complete(str);
                    }
                }

                public void exceptionCaught(ChannelHandlerContext channelHandlerContext, Throwable th) {
                    NettyTestClient.LOG.error(NettyTestClient.this.getClass().getSimpleName() + ": exceptionCaught", th);
                    channelHandlerContext.close();
                }
            };
        }

        @Override // java.io.Closeable, java.lang.AutoCloseable
        public void close() {
            this.channelFuture.channel().close();
            this.workerGroup.shutdownGracefully();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/ratis/netty/TestTlsConfWithNetty$NettyTestServer.class */
    public static class NettyTestServer implements Closeable {
        private static final Logger LOG = LoggerFactory.getLogger(NettyTestServer.class);
        static final String CLASS_NAME = JavaUtils.getClassSimpleName(NettyTestServer.class);
        private final EventLoopGroup bossGroup = NettyUtils.newEventLoopGroup(CLASS_NAME + "-bossGroup", 3, true);
        private final EventLoopGroup workerGroup = NettyUtils.newEventLoopGroup(CLASS_NAME + "-workerGroup", 3, true);
        private final ChannelFuture channelFuture;

        public NettyTestServer(int i, SslContext sslContext) {
            this.channelFuture = new ServerBootstrap().group(this.bossGroup, this.workerGroup).channel(NettyUtils.getServerChannelClass(this.bossGroup)).handler(new LoggingHandler(getClass(), LogLevel.INFO)).childHandler(newChannelInitializer(sslContext)).bind(i).syncUninterruptibly();
        }

        private ChannelInitializer<SocketChannel> newChannelInitializer(final SslContext sslContext) {
            return new ChannelInitializer<SocketChannel>() { // from class: org.apache.ratis.netty.TestTlsConfWithNetty.NettyTestServer.1
                public void initChannel(SocketChannel socketChannel) {
                    ChannelPipeline pipeline = socketChannel.pipeline();
                    if (sslContext != null) {
                        pipeline.addLast("ssl", sslContext.newHandler(socketChannel.alloc()));
                    }
                    pipeline.addLast(new ChannelHandler[]{NettyTestServer.this.newServerHandler()});
                }
            };
        }

        /* JADX INFO: Access modifiers changed from: private */
        public ChannelInboundHandler newServerHandler() {
            return new ChannelInboundHandlerAdapter() { // from class: org.apache.ratis.netty.TestTlsConfWithNetty.NettyTestServer.2
                public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) {
                    if (obj instanceof ByteBuf) {
                        String buffer2String = TestTlsConfWithNetty.buffer2String((ByteBuf) obj);
                        NettyTestServer.LOG.info("channelRead: " + buffer2String);
                        for (String str : buffer2String.split(" ")) {
                            channelHandlerContext.writeAndFlush(TestTlsConfWithNetty.unpooledBuffer(NettyTestServer.toReply(str) + " "));
                        }
                    }
                }

                public void exceptionCaught(ChannelHandlerContext channelHandlerContext, Throwable th) {
                    NettyTestServer.LOG.error(NettyTestServer.this.getClass().getSimpleName() + ": exceptionCaught", th);
                    channelHandlerContext.close();
                }
            };
        }

        static String toReply(String str) {
            return "[" + str + "]";
        }

        @Override // java.io.Closeable, java.lang.AutoCloseable
        public void close() {
            this.channelFuture.channel().close();
            this.bossGroup.shutdownGracefully();
            this.workerGroup.shutdownGracefully();
        }
    }

    /**
     * Decodes {@code buf} as UTF-8 text and then releases it.
     *
     * <p>The release happens even if decoding throws, so callers must not touch the buffer
     * afterwards.
     */
    static String buffer2String(ByteBuf buf) {
        final String decoded;
        try {
            decoded = buf.toString(StandardCharsets.UTF_8);
        } finally {
            buf.release();
        }
        return decoded;
    }

    /** Returns a new unpooled buffer containing the UTF-8 encoding of {@code str}. */
    static ByteBuf unpooledBuffer(String str) {
        final byte[] utf8 = str.getBytes(StandardCharsets.UTF_8);
        return Unpooled.buffer().writeBytes(utf8);
    }

    /** Picks a pseudo-random port in {@code [50000, 59999]} and logs the choice. */
    static int randomPort() {
        final int port = ThreadLocalRandom.current().nextInt(10000) + 50000;
        LOG.info("randomPort: {}", port);
        return port;
    }

    /** Round-trips the test sentence over a plain-text connection; neither side uses TLS. */
    @Test
    public void testNoSsl() throws Exception {
        final int port = randomPort();
        runTest(port, null, null);
    }

    /** Round-trips the test sentence with TLS configured on both the server and the client. */
    @Test
    public void testSsl() throws Exception {
        final int port = randomPort();
        final TlsConf serverConf = SecurityTestUtils.newServerTlsConfig(true);
        final TlsConf clientConf = SecurityTestUtils.newClientTlsConfig(true);
        runTest(port, serverConf, clientConf);
    }

    /**
     * Starts a {@link NettyTestServer}, connects a {@link NettyTestClient} to it, sends each word of
     * a fixed sentence, and asserts that every reply equals
     * {@link NettyTestServer#toReply(String)} of the corresponding word.
     *
     * <p>The client is closed before the server, as with the original nested resource handling.
     *
     * @param port port the server binds to and the client connects to
     * @param serverTlsConf server TLS configuration, or {@code null} for plain text
     * @param clientTlsConf client TLS configuration, or {@code null} for plain text
     * @throws Exception if setup fails, a reply does not arrive within 3 seconds, or a reply
     *     does not match
     */
    static void runTest(int port, TlsConf serverTlsConf, TlsConf clientTlsConf) throws Exception {
        final SslContext serverSslContext = serverTlsConf == null ? null : NettyUtils.buildSslContextForServer(serverTlsConf);
        final SslContext clientSslContext = clientTlsConf == null ? null : NettyUtils.buildSslContextForClient(clientTlsConf);
        final String[] words = "Hey, how are you?".split(" ");

        try (NettyTestServer server = new NettyTestServer(port, serverSslContext);
             NettyTestClient client = new NettyTestClient("localhost", port, clientSslContext)) {
            // Send every word first, then collect the replies in the same order.
            final List<CompletableFuture<String>> replies = new ArrayList<>(words.length);
            for (String word : words) {
                replies.add(client.writeAndFlush(unpooledBuffer(word + " ")));
            }
            for (int i = 0; i < replies.size(); i++) {
                final String reply = replies.get(i).get(3L, TimeUnit.SECONDS);
                LOG.info(reply);
                Assert.assertEquals(NettyTestServer.toReply(words[i]), reply);
            }
        }
    }
}
