package org.apache.sshd.common.forward;

import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import org.apache.sshd.util.test.BaseTestSupport;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/sshd/common/forward/AbstractServerCloseTestSupport.class */
/**
 * Base class for tests that read a fixed payload through a forwarded port from a test server which closes the
 * connection as soon as the payload has been written.
 *
 * <p>
 * Before each test an asynchronous server socket is bound to an ephemeral port on {@code TEST_LOCALHOST}; the
 * chosen port is published in {@link #testServerPort}. Every accepted connection is sent {@code PAYLOAD} and then
 * closed by the server. Subclasses supply {@link #startRemotePF()} / {@link #startLocalPF()}, each returning the
 * port the test client should connect to (presumably a port forwarded to {@link #testServerPort} - confirm against
 * the concrete subclasses).
 * </p>
 */
public abstract class AbstractServerCloseTestSupport extends BaseTestSupport {
    /** Data sent by the test server to every client; plain ASCII, so it can be decoded chunk by chunk. */
    private static final String PAYLOAD = String.join("", Collections.nCopies(200, "This is significantly longer Test Data."));

    /** Port the test server is listening on - valid once {@link #startTestServer()} has completed. */
    protected int testServerPort;

    private final Logger log = LoggerFactory.getLogger(getClass());
    private AsynchronousServerSocketChannel testServerSock;

    /**
     * Starts the test server on an ephemeral port and begins accepting connections. Each accepted connection
     * receives the full payload and is then closed by the server.
     *
     * @throws Exception if the server socket cannot be opened or bound
     */
    @Before
    public void startTestServer() throws Exception {
        // Assign before binding so that stopTestServer() can still release the channel if bind() fails
        testServerSock = AsynchronousServerSocketChannel.open();
        testServerSock.bind(new InetSocketAddress(TEST_LOCALHOST, 0));
        testServerPort = ((InetSocketAddress) testServerSock.getLocalAddress()).getPort();
        log.info("Listening on port {}", testServerPort);

        testServerSock.accept(testServerSock, new CompletionHandler<AsynchronousSocketChannel, AsynchronousServerSocketChannel>() {
            @Override
            public void completed(AsynchronousSocketChannel clientChannel, AsynchronousServerSocketChannel serverChannel) {
                // Re-arm the accept first so further clients can connect while this one is being served
                serverChannel.accept(serverChannel, this);
                log.info("Accepted new incoming connection");
                sendPayloadAndClose(clientChannel);
            }

            @Override
            public void failed(Throwable t, AsynchronousServerSocketChannel serverChannel) {
                log.error("Failed ({}) to accept incoming connection: {}", t.getClass().getSimpleName(), t.getMessage());
            }
        });
    }

    /**
     * Writes the whole payload to the given client channel and closes the channel afterwards.
     *
     * @param channel the accepted client connection
     */
    private void sendPayloadAndClose(AsynchronousSocketChannel channel) {
        ByteBuffer data = ByteBuffer.wrap(PAYLOAD.getBytes(StandardCharsets.UTF_8));
        channel.write(data, channel, new CompletionHandler<Integer, AsynchronousSocketChannel>() {
            @Override
            public void completed(Integer written, AsynchronousSocketChannel sock) {
                // A single asynchronous write is allowed to be partial - keep writing until the buffer is drained
                if (data.hasRemaining()) {
                    sock.write(data, sock, this);
                    return;
                }

                try {
                    sock.close();
                } catch (IOException e) {
                    log.warn("Failed ({}) to close channel after write complete: {}", e.getClass().getSimpleName(), e.getMessage());
                }
            }

            @Override
            public void failed(Throwable t, AsynchronousSocketChannel sock) {
                log.error("Failed ({}) to write message to client: {}", t.getClass().getSimpleName(), t.getMessage());
                // Nothing more can be sent on this connection - release it rather than leak it
                try {
                    sock.close();
                } catch (IOException e) {
                    log.warn("Failed ({}) to close channel after write failure: {}", e.getClass().getSimpleName(), e.getMessage());
                }
            }
        });
    }

    /**
     * Closes the test server socket, if it was opened.
     *
     * @throws Exception if closing the server socket fails
     */
    @After
    public void stopTestServer() throws Exception {
        // JUnit runs @After even when @Before failed, in which case the channel may never have been created
        if (testServerSock != null) {
            testServerSock.close();
        }
    }

    /**
     * Decodes the bytes obtained by one {@code read} call, failing the test with a clear message if the call
     * reported end-of-stream instead of data.
     *
     * @param  what    description of the read, used in the failure message
     * @param  buf     buffer that was passed to {@code read}
     * @param  readLen value returned by {@code read}
     * @return         the decoded text
     */
    private static String decodeChunk(String what, byte[] buf, int readLen) {
        Assert.assertTrue(what + ": connection closed before any data was read", readLen >= 0);
        return new String(buf, 0, readLen, StandardCharsets.UTF_8);
    }

    /**
     * Connects to the given port and reads in small chunks, pausing between reads, until end-of-stream or an I/O
     * error (e.g., read timeout); then verifies that the complete payload was received.
     *
     * @param  serverPort port to connect to
     * @throws Exception  if interrupted while pausing between reads
     */
    private void readInLoop(int serverPort) throws Exception {
        outputDebugMessage("readInLoop(port=%d)", serverPort);

        StringBuilder received = new StringBuilder(PAYLOAD.length());
        try (Socket s = new Socket(TEST_LOCALHOST, serverPort)) {
            s.setSoTimeout(300);

            byte[] chunk = new byte[PAYLOAD.length() / 10];
            try (InputStream in = s.getInputStream()) {
                for (int readLen = in.read(chunk); readLen != -1; readLen = in.read(chunk)) {
                    outputDebugMessage("readInLoop(port=%d) read %d bytes", serverPort, readLen);
                    received.append(new String(chunk, 0, readLen, StandardCharsets.UTF_8));
                    Thread.sleep(25L);
                }
            }
        } catch (IOException e) {
            // An I/O failure only ends the reading phase - whatever was received is validated below
            log.debug("readInLoop(port={}) stopped reading ({}): {}", serverPort, e.getClass().getSimpleName(), e.getMessage());
        }

        // Validate regardless of how reading ended - a clean end-of-stream must also have delivered everything
        String actual = received.toString();
        assertEquals("Mismatched data length", PAYLOAD.length(), actual.length());
        assertEquals("Mismatched read data", PAYLOAD, actual);
    }

    /**
     * Connects to the given port, waits briefly and then expects the complete payload from a single read.
     *
     * @param  serverPort port to connect to
     * @throws Exception  if connecting or reading fails
     */
    private void readInOneBuffer(int serverPort) throws Exception {
        outputDebugMessage("readInOneBuffer(port=%d)", serverPort);

        try (Socket s = new Socket()) {
            s.setSoTimeout(300);
            s.setReceiveBufferSize(65536);
            s.connect(new InetSocketAddress(TEST_LOCALHOST, serverPort));
            // Give the server time to send (and close) before the one and only read
            Thread.sleep(50L);

            byte[] buf = new byte[PAYLOAD.length()];
            try (InputStream in = s.getInputStream()) {
                int readLen = in.read(buf);
                outputDebugMessage("readInOneBuffer(port=%d) - Got %d bytes from the server", serverPort, readLen);
                assertEquals("Mismatched read data", PAYLOAD, decodeChunk("readInOneBuffer", buf, readLen));
            }
        }
    }

    /**
     * Connects to the given port and expects the complete payload from two reads separated by a pause: the first
     * into a buffer of half the payload size, the second into a full-size buffer.
     *
     * @param  serverPort port to connect to
     * @throws Exception  if connecting or reading fails
     */
    private void readInTwoBuffersWithPause(int serverPort) throws Exception {
        outputDebugMessage("readInTwoBuffersWithPause(port=%d)", serverPort);

        try (Socket s = new Socket()) {
            s.setSoTimeout(300);
            s.setReceiveBufferSize(65536);
            s.connect(new InetSocketAddress(TEST_LOCALHOST, serverPort));
            Thread.sleep(50L);

            byte[] firstBuf = new byte[PAYLOAD.length() / 2];
            byte[] secondBuf = new byte[PAYLOAD.length()];
            try (InputStream in = s.getInputStream()) {
                int firstLen = in.read(firstBuf);
                outputDebugMessage("readInTwoBuffersWithPause(port=%d) - 1st half is %d bytes", serverPort, firstLen);
                String firstHalf = decodeChunk("readInTwoBuffersWithPause(1st half)", firstBuf, firstLen);

                Thread.sleep(50L);
                try {
                    int secondLen = in.read(secondBuf);
                    outputDebugMessage("readInTwoBuffersWithPause(port=%d) - 2nd half is %d bytes", serverPort, secondLen);
                    String secondHalf = decodeChunk("readInTwoBuffersWithPause(2nd half)", secondBuf, secondLen);
                    assertEquals("Mismatched read data", PAYLOAD, firstHalf + secondHalf);
                } catch (IOException e) {
                    log.error("Disconnected ({}) before all data read: {}", e.getClass().getSimpleName(), e.getMessage());
                    throw e;
                }
            }
        }
    }

    /**
     * Sets up the remote port forwarding under test.
     *
     * @return           the port the test client should connect to
     * @throws Exception if the forwarding cannot be set up
     */
    protected abstract int startRemotePF() throws Exception;

    /**
     * Sets up the local port forwarding under test.
     *
     * @return           the port the test client should connect to
     * @throws Exception if the forwarding cannot be set up
     */
    protected abstract int startLocalPF() throws Exception;

    /**
     * Hook for subclasses to report whether local port forwarding is active.
     *
     * @param  port the value returned by {@link #startLocalPF()}
     * @return      {@code true} by default
     */
    protected boolean hasLocalPFStarted(int port) {
        return true;
    }

    /**
     * Hook for subclasses to report whether remote port forwarding is active.
     *
     * @param  port the value returned by {@link #startRemotePF()}
     * @return      {@code true} by default
     */
    protected boolean hasRemotePFStarted(int port) {
        return true;
    }

    @Test
    public void testRemotePortForwardOneBuffer() throws Exception {
        readInOneBuffer(startRemotePF());
    }

    @Test
    public void testRemotePortForwardTwoBuffers() throws Exception {
        readInTwoBuffersWithPause(startRemotePF());
    }

    @Test
    public void testRemotePortForwardLoop() throws Exception {
        readInLoop(startRemotePF());
    }

    @Test
    public void testLocalPortForwardOneBuffer() throws Exception {
        readInOneBuffer(startLocalPF());
    }

    @Test
    public void testLocalPortForwardTwoBuffers() throws Exception {
        readInTwoBuffersWithPause(startLocalPF());
    }

    @Test
    public void testLocalPortForwardLoop() throws Exception {
        readInLoop(startLocalPF());
    }

    @Test
    public void testHasLocalPortForwardingStarted() throws Exception {
        Assert.assertTrue(hasLocalPFStarted(startLocalPF()));
    }

    @Test
    public void testHasRemotePortForwardingStarted() throws Exception {
        Assert.assertTrue(hasRemotePFStarted(startRemotePF()));
    }
}
