package org.apache.sshd.common.forward;

import io.grpc.Server;
import io.grpc.ServerBuilder;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.sshd.common.session.ConnectionService;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.channel.ChannelSessionFactory;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
import org.apache.sshd.server.forward.DirectTcpipFactory;
import org.apache.sshd.server.forward.ForwardedTcpipFactory;
import org.apache.sshd.util.test.BaseTestSupport;
import org.apache.sshd.util.test.ContainerTestCase;
import org.apache.sshd.util.test.CoreTestSupportUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.Testcontainers;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.output.Slf4jLogConsumer;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.images.builder.ImageFromDockerfile;
import org.testcontainers.utility.MountableFile;

@RunWith(Parameterized.class)
@Category({ContainerTestCase.class})
/* loaded from: input_file:org/apache/sshd/common/forward/PortForwardingWithOpenSshTest.class */
public class PortForwardingWithOpenSshTest extends BaseTestSupport {
    private static final Logger LOG = LoggerFactory.getLogger(PortForwardingWithOpenSshTest.class);
    private static final String TEST_KEYS = "org/apache/sshd/client/opensshcerts/user";

    @Rule
    public TemporaryFolder tmp = new TemporaryFolder();
    private Server gRpc;
    private int gRpcPort;
    private SshServer sshd;
    private int sshPort;
    private CountDownLatch forwardingSetup;
    private int forwardedPort;
    private final String portToForward;

    public PortForwardingWithOpenSshTest(String str) {
        this.portToForward = str;
    }

    @Parameterized.Parameters(name = "{0}")
    public static String[] portSpecifications() {
        return new String[]{"127.0.0.1:0", "0.0.0.0:0", "0", "localhost:0"};
    }

    @Before
    public void startServers() throws Exception {
        this.gRpc = ServerBuilder.forPort(0).build();
        CountDownLatch countDownLatch = new CountDownLatch(1);
        new Thread(() -> {
            try {
                this.gRpc.start();
                this.gRpcPort = this.gRpc.getPort();
                countDownLatch.countDown();
                this.gRpc.awaitTermination();
            } catch (Exception e) {
            }
        }).start();
        countDownLatch.await();
        LOG.info("gRPC running on port {}", Integer.valueOf(this.gRpcPort));
        this.forwardingSetup = new CountDownLatch(1);
        this.sshd = CoreTestSupportUtils.setupTestServer(PortForwardingWithOpenSshTest.class);
        this.sshd.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
        this.sshd.setForwarderFactory(new DefaultForwarderFactory() { // from class: org.apache.sshd.common.forward.PortForwardingWithOpenSshTest.1
            public Forwarder create(ConnectionService connectionService) {
                DefaultForwarder defaultForwarder = new DefaultForwarder(connectionService) { // from class: org.apache.sshd.common.forward.PortForwardingWithOpenSshTest.1.1
                    public SshdSocketAddress localPortForwardingRequested(SshdSocketAddress sshdSocketAddress) throws IOException {
                        SshdSocketAddress localPortForwardingRequested = super.localPortForwardingRequested(sshdSocketAddress);
                        PortForwardingWithOpenSshTest.this.forwardedPort = localPortForwardingRequested == null ? -1 : localPortForwardingRequested.getPort();
                        PortForwardingWithOpenSshTest.this.forwardingSetup.countDown();
                        return localPortForwardingRequested;
                    }
                };
                defaultForwarder.addPortForwardingEventListenerManager(this);
                return defaultForwarder;
            }
        });
        this.sshd.setChannelFactories(Arrays.asList(ChannelSessionFactory.INSTANCE, DirectTcpipFactory.INSTANCE, ForwardedTcpipFactory.INSTANCE));
        this.sshd.start();
        this.sshPort = this.sshd.getPort();
    }

    @After
    public void teardownServers() throws Exception {
        try {
            this.gRpc.shutdownNow();
        } finally {
            this.sshd.stop();
        }
    }

    @Test
    public void forwardingWithConnectionClose() throws Exception {
        File newFile = this.tmp.newFile();
        Files.write(newFile.toPath(), ("#!/bin/sh\n\nchmod 0600 /root/.ssh/*\n/usr/bin/ssh -o 'ExitOnForwardFailure yes' -o 'StrictHostKeyChecking off' -vvv -p " + this.sshPort + " -x -N -T -R " + this.portToForward + ":host.testcontainers.internal:" + this.gRpcPort + " bob@host.testcontainers.internal\n").getBytes(StandardCharsets.US_ASCII), new OpenOption[0]);
        GenericContainer withLogConsumer = new GenericContainer(new ImageFromDockerfile().withDockerfileFromBuilder(dockerfileBuilder -> {
            dockerfileBuilder.from("alpine:3.16").run("apk --update add openssh openssh-server").run("mkdir -p /root/.ssh").entryPoint("/entrypoint.sh").build();
        })).withCopyFileToContainer(MountableFile.forClasspathResource("org/apache/sshd/client/opensshcerts/user/user01_ed25519"), "/root/.ssh/id_ed25519").withCopyFileToContainer(MountableFile.forClasspathResource("org/apache/sshd/client/opensshcerts/user/user01_ed25519.pub"), "/root/.ssh/id_ed25519.pub").withCopyFileToContainer(MountableFile.forHostPath(newFile.getPath(), 511), "/entrypoint.sh").withAccessToHost(true).waitingFor(Wait.forLogMessage(".*forwarding_success.*\n", 1)).withLogConsumer(new Slf4jLogConsumer(LOG));
        try {
            Testcontainers.exposeHostPorts(new int[]{this.sshPort, this.gRpcPort});
            withLogConsumer.start();
            this.forwardingSetup.await();
            assertTrue("Server should listen on port", this.forwardedPort > 0);
            LOG.info("sshd server listening for forwarding on port {}", Integer.valueOf(this.forwardedPort));
            ArrayList arrayList = new ArrayList();
            Socket socket = new Socket("127.0.0.1", this.forwardedPort);
            try {
                OutputStream outputStream = socket.getOutputStream();
                try {
                    InputStream inputStream = socket.getInputStream();
                    try {
                        outputStream.write("GET / HTTP 1.1\r\nConnection: keep-alive\r\nHost: 127.0.0.1\r\n\r\n".getBytes(StandardCharsets.US_ASCII));
                        byte[] bArr = new byte[1024];
                        int i = 0;
                        while (i >= 0) {
                            i = inputStream.read(bArr, 0, bArr.length);
                            if (i > 0) {
                                arrayList.add(new String(bArr, 0, i, StandardCharsets.ISO_8859_1));
                            }
                        }
                        if (inputStream != null) {
                            inputStream.close();
                        }
                        if (outputStream != null) {
                            outputStream.close();
                        }
                        socket.close();
                        assertFalse("Expected data", arrayList.isEmpty());
                        String str = (String) arrayList.get(arrayList.size() - 1);
                        assertTrue("Unexpected data: " + str, str.endsWith("HTTP/2 client preface string missing or corrupt. Hex dump for received bytes: 474554202f204854545020312e310d0a436f6e6e65637469"));
                        withLogConsumer.stop();
                    } catch (Throwable th) {
                        if (inputStream != null) {
                            try {
                                inputStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        }
                        throw th;
                    }
                } catch (Throwable th3) {
                    if (outputStream != null) {
                        try {
                            outputStream.close();
                        } catch (Throwable th4) {
                            th3.addSuppressed(th4);
                        }
                    }
                    throw th3;
                }
            } finally {
            }
        } catch (Throwable th5) {
            withLogConsumer.stop();
            throw th5;
        }
    }
}
