package org.apache.sshd.common.forward;

import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.future.ConnectFuture;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.future.CancelOption;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
import org.apache.sshd.util.test.BaseTestSupport;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/sshd/common/forward/ConcurrentConnectionTest.class */
/**
 * Opens many TCP connections at the same time through one SSH remote port forward and checks that every one of
 * them gets its request relayed to a plain TCP test server and the full response relayed back.
 * <p>
 * The SSH server is shared by the whole class and runs with {@value #SSHD_NIO_WORKERS} NIO workers, fewer than
 * the {@value #PORT_FORWARD_CLIENT_COUNT} concurrent clients. A fresh target TCP server and SSH client session
 * are created for every test.
 * </p>
 */
public class ConcurrentConnectionTest extends BaseTestSupport {
    private static final int SSHD_NIO_WORKERS = 8;
    private static final int PORT_FORWARD_CLIENT_COUNT = 12;
    private static final byte[] PAYLOAD_TO_SERVER = "To Server -> To Server -> To Server".getBytes(StandardCharsets.UTF_8);
    private static final byte[] PAYLOAD_TO_CLIENT = "<- To Client <- To Client <-".getBytes(StandardCharsets.UTF_8);
    private static final Logger LOG = LoggerFactory.getLogger(ConcurrentConnectionTest.class);
    /** Read/connect/barrier timeout (milliseconds) applied to every test socket. */
    private static final int SO_TIMEOUT = (int) TimeUnit.SECONDS.toMillis(10);

    private static int sshServerPort;
    private static SshServer server;

    private int testServerPort;
    private ServerSocket testServerSock;
    private Thread testServerThread;
    private SshClient sshClient;
    private ClientSession session;

    /**
     * Binds the target TCP server to an ephemeral port and starts its accept loop.
     * <p>
     * The socket is bound here, on the test thread, so {@code testServerPort} is valid (and visible) as soon as
     * this method returns, instead of sleeping and hoping the acceptor thread has bound it in time.
     * </p>
     */
    @Before
    public void startTestServer() throws Exception {
        testServerSock = new ServerSocket(0);
        testServerPort = testServerSock.getLocalPort();
        LOG.debug("Listening on {}", testServerPort);

        testServerThread = new Thread(this::serverAcceptLoop);
        testServerThread.setDaemon(true);
        testServerThread.setName("Server Acceptor");
        testServerThread.start();
    }

    /**
     * Accepts connections on the already-bound {@code testServerSock} and serves each on its own daemon thread
     * until the listening socket is closed.
     */
    protected void serverAcceptLoop() {
        AtomicInteger activeServers = new AtomicInteger(0);
        try {
            while (true) {
                Socket accepted = testServerSock.accept();
                LOG.debug("Got connection");
                Thread handler = new Thread(() -> serverSocketLoop(activeServers, accepted));
                handler.setDaemon(true);
                handler.setName("Server " + accepted.getPort());
                handler.start();
            }
        } catch (SocketException e) {
            // accept() reports a closed listening socket this way; stopTestServer() closes it at test end.
            LOG.debug("Shutting down test server");
        } catch (Throwable th) {
            LOG.error("Error", th);
        }
    }

    /**
     * Serves one connection: reads the full client payload, writes {@link #PAYLOAD_TO_CLIENT}, then closes.
     *
     * @param activeServers counter of connections currently being served (debug logging only)
     * @param socket        the accepted connection; always closed when this method returns
     */
    private void serverSocketLoop(AtomicInteger activeServers, Socket socket) {
        try (Socket connection = socket) {
            LOG.debug("Active Servers: {}", activeServers.incrementAndGet());
            connection.setSoTimeout(SO_TIMEOUT);
            int received = readFully(connection.getInputStream(), PAYLOAD_TO_SERVER.length);
            LOG.debug("Read {} payload from client", received);
            connection.getOutputStream().write(PAYLOAD_TO_CLIENT);
            LOG.debug("Wrote payload to client");
        } catch (Throwable th) {
            LOG.error("Error", th);
        }
        LOG.debug("Active Servers: {}", activeServers.decrementAndGet());
    }

    /**
     * Reads from {@code in} until {@code expected} bytes have arrived or the stream ends.
     * <p>
     * A single {@code read} may legally return fewer bytes than the peer sent, so looping is required to
     * observe the whole payload reliably.
     * </p>
     *
     * @param  in          stream to read from
     * @param  expected    number of bytes wanted
     * @return             number of bytes actually read; less than {@code expected} only on end-of-stream
     * @throws IOException if reading fails (including a socket read timeout)
     */
    private static int readFully(InputStream in, int expected) throws IOException {
        byte[] buffer = new byte[expected];
        int total = 0;
        while (total < expected) {
            int count = in.read(buffer, total, expected - total);
            if (count < 0) {
                break;
            }
            total += count;
        }
        return total;
    }

    /**
     * Closes the target TCP server. The guards keep this from masking a {@code @Before} failure with an NPE,
     * since JUnit still runs {@code @After} methods in that case.
     */
    @After
    public void stopTestServer() throws Exception {
        try {
            if (testServerSock != null) {
                testServerSock.close();
            }
        } finally {
            if (testServerThread != null) {
                testServerThread.interrupt();
            }
        }
    }

    /** Starts the shared SSH server: accepts any password and allows all port forwarding. */
    @BeforeClass
    public static void startSshServer() throws IOException {
        LOG.debug("Starting SSHD...");
        server = SshServer.setUpDefaultServer();
        server.setPasswordAuthenticator((username, password, serverSession) -> true);
        server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider());
        server.setNioWorkers(SSHD_NIO_WORKERS);
        server.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
        server.start();
        sshServerPort = server.getPort();
        LOG.debug("SSHD Running on port {}", server.getPort());
    }

    /** Closes the shared SSH server, only logging a warning if it does not close within {@code CLOSE_TIMEOUT}. */
    @AfterClass
    public static void stopServer() throws IOException {
        if (server == null) {
            return;
        }
        try {
            if (!server.close(true).await(CLOSE_TIMEOUT)) {
                LOG.warn("Failed to close server within {} sec.", CLOSE_TIMEOUT.toMillis() / 1000);
            }
        } finally {
            server = null;
        }
    }

    /**
     * Starts an SSH client, connects to the shared server and authenticates with a password.
     * <p>
     * The client is kept in a field before connecting so {@link #stopClient()} can stop it even if the connect
     * or authentication step fails.
     * </p>
     */
    @Before
    public void createClient() throws IOException {
        sshClient = SshClient.setUpDefaultClient();
        sshClient.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
        sshClient.start();
        LOG.debug("Connecting...");
        ConnectFuture connectFuture = sshClient.connect("user", TEST_LOCALHOST, sshServerPort).verify(CONNECT_TIMEOUT);
        session = connectFuture.getSession();
        LOG.debug("Authenticating...");
        session.addPasswordIdentity("foo");
        session.auth().verify(AUTH_TIMEOUT);
        LOG.debug("Authenticated");
    }

    /** Closes the client session, asserting it closes in time, and always stops the SSH client. */
    @After
    public void stopClient() throws Exception {
        LOG.debug("Disconnecting Client");
        try {
            if (session != null) {
                assertTrue("Failed to close session", session.close(true).await(CLOSE_TIMEOUT));
            }
        } finally {
            session = null;
            if (sshClient != null) {
                sshClient.stop();
                sshClient = null;
            }
        }
    }

    /**
     * Sets up a remote port forward to the target TCP server, then has {@value #PORT_FORWARD_CLIENT_COUNT}
     * threads connect through it at the same moment (synchronized with a barrier). The test passes only if
     * every client receives the full response payload.
     */
    @Test
    public void testConcurrentConnectionsToPortForward() throws Exception {
        SshdSocketAddress remote = new SshdSocketAddress(TEST_LOCALHOST, 0);
        SshdSocketAddress local = new SshdSocketAddress(TEST_LOCALHOST, testServerPort);
        int forwardedPort = session.startRemotePortForwarding(remote, local).getPort();

        CyclicBarrier startBarrier = new CyclicBarrier(PORT_FORWARD_CLIENT_COUNT, () -> LOG.debug("And away we go."));
        AtomicInteger successes = new AtomicInteger(0);
        AtomicInteger failures = new AtomicInteger(0);
        CountDownLatch done = new CountDownLatch(PORT_FORWARD_CLIENT_COUNT);
        long[] bytesRead = new long[PORT_FORWARD_CLIENT_COUNT];

        for (int i = 0; i < PORT_FORWARD_CLIENT_COUNT; i++) {
            int clientIndex = i;
            Thread clientThread = new Thread(() -> {
                try {
                    bytesRead[clientIndex] = makeClientRequest(forwardedPort, startBarrier);
                    LOG.debug("Complete, received full payload from server.");
                    successes.incrementAndGet();
                } catch (Throwable t) {
                    // Throwable, not Exception: makeClientRequest reports a short read via AssertionError.
                    failures.incrementAndGet();
                    LOG.error("Error in client code", t);
                } finally {
                    // Always release the latch so a failing client cannot stall the test for the full timeout.
                    done.countDown();
                }
            });
            clientThread.setName("Client " + i);
            clientThread.setDaemon(true);
            clientThread.start();
        }

        // The latch's happens-before edge makes the workers' writes to bytesRead visible here.
        assertTrue("All threads should be done after two minutes", done.await(2L, TimeUnit.MINUTES));
        for (int i = 0; i < PORT_FORWARD_CLIENT_COUNT; i++) {
            assertEquals("Mismatched data length read from server for client " + i, PAYLOAD_TO_CLIENT.length, bytesRead[i]);
        }
        assertEquals("Not all clients succeeded (" + failures.get() + " failed)", PORT_FORWARD_CLIENT_COUNT, successes.get());
    }

    /**
     * Performs one client round-trip through the forwarded port: waits at {@code barrier} so all clients connect
     * together, sends {@link #PAYLOAD_TO_SERVER}, and reads the full response.
     *
     * @param  serverPort the forwarded port on {@code TEST_LOCALHOST}
     * @param  barrier    shared start barrier; awaited with a timeout so one early failure cannot hang the rest
     * @return            number of response bytes read
     * @throws Exception  on I/O failure, barrier timeout/breakage, or (as {@link AssertionError}) a short response
     */
    private long makeClientRequest(int serverPort, CyclicBarrier barrier) throws Exception {
        outputDebugMessage("readInLoop(port=%d)", serverPort);
        try (Socket socket = new Socket()) {
            socket.setSoTimeout(SO_TIMEOUT);
            barrier.await(SO_TIMEOUT, TimeUnit.MILLISECONDS);
            socket.connect(new InetSocketAddress(TEST_LOCALHOST, serverPort), SO_TIMEOUT);
            socket.getOutputStream().write(PAYLOAD_TO_SERVER);
            long read = readFully(socket.getInputStream(), PAYLOAD_TO_CLIENT.length);
            LOG.debug("Read {} payload from server", read);
            assertEquals("Mismatched data length", PAYLOAD_TO_CLIENT.length, read);
            return read;
        }
    }
}
