package org.apache.kafka.common.network;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.metrics.KafkaMetric;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.ChannelState;
import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.TestSecurityConfig;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.authenticator.LoginManager;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.server.interceptor.BrokerInterceptor;
import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/kafka/common/network/ReverseConnectionTest.class */
/**
 * Tests for reversing an established connection between a client {@link Selector} and an
 * {@link NioEchoServer}.
 *
 * <p>After a connection is established (and authenticated, where the protocol requires it), both
 * endpoints are detached from their selectors without closing the underlying sockets. The
 * server-side channel is then reversed into the client selector, and the client-side channel is
 * reversed into the server selector. The tests check that traffic, reverse-connection metrics,
 * close callbacks and idle expiry all work on the reversed pair.
 *
 * <p>The echo server's poll loop is gated by {@link #pollSemaphore}, so the test thread can move
 * channels between selectors while the server is not polling.
 */
public class ReverseConnectionTest {
    private static final int BUFFER_SIZE = 4096;
    /** Iteration count passed to {@link ScramFormatter#generateCredential} for the test user. */
    private static final int SCRAM_ITERATIONS = 4096;

    private final Time clientTime = new MockTime();
    private final Time serverTime = new MockTime();
    /** Held by the server thread for each poll; tests acquire it to pause the server's poll loop. */
    private final Semaphore pollSemaphore = new Semaphore(1);
    private final Metrics clientMetrics = new Metrics();
    private NioEchoServer server;
    private Selector selector;
    private Map<String, Object> saslClientConfigs;
    private Map<String, Object> saslServerConfigs;
    private CredentialCache credentialCache;

    @BeforeEach
    public void setup() throws Exception {
        LoginManager.closeAll();
        CertStores serverCertStores = new CertStores(true, "localhost");
        CertStores clientCertStores = new CertStores(false, "localhost");
        this.saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores);
        this.saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
        this.saslServerConfigs.put("ssl.engine.factory.class", DefaultSslEngineFactory.class);
        this.saslClientConfigs.put("ssl.engine.factory.class", DefaultSslEngineFactory.class);
        this.credentialCache = new CredentialCache();
    }

    /**
     * Closes the server, the client selector and the client metrics. Each close runs even if an
     * earlier one throws, so a failing close does not leak the remaining resources.
     */
    @AfterEach
    public void teardown() throws Exception {
        try {
            if (this.server != null) {
                this.server.close();
            }
        } finally {
            try {
                if (this.selector != null) {
                    this.selector.close();
                }
            } finally {
                // Closed after the selector, which uses these metrics for its sensors.
                this.clientMetrics.close();
            }
        }
    }

    @Test
    public void testReverseConnectionSaslPlain() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
        this.server = createEchoServer(securityProtocol);
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        reverseAndVerifyConnection(KafkaChannelTest.CHANNEL_ID, "1", true);
    }

    @Test
    public void testReverseConnectionSaslScram() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
        this.server = createEchoServer(securityProtocol);
        updateScramCredentialCache("SCRAM-SHA-256", TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        reverseAndVerifyConnection(KafkaChannelTest.CHANNEL_ID, "1", true);
    }

    @Test
    public void testReverseConnectionSsl() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        this.server = createEchoServer(securityProtocol);
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        reverseAndVerifyConnection(KafkaChannelTest.CHANNEL_ID, "1", true);
    }

    /**
     * Verifies that application data already buffered inside the client channel's SSL transport
     * layer is still processed after that channel is reversed onto the server selector. The
     * message must come back as a completed receive on the client selector.
     */
    @Test
    public void testReverseConnectionSslWithBufferedRead() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        this.server = createEchoServer(securityProtocol);
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        Selector serverSelector = this.server.selector();
        Selector clientSelector = this.selector;

        // Pause the server poll loop while the channels are moved between selectors.
        this.pollSemaphore.acquire();
        try {
            KafkaChannel serverChannel = serverSelector.channels().get(0);
            serverSelector.removeChannelWithoutClosing(serverChannel);
            KafkaChannel clientChannel = clientSelector.channel(KafkaChannelTest.CHANNEL_ID);
            clientSelector.removeChannelWithoutClosing(clientChannel);

            // Simulate application data left in the SSL layer's read buffer before the reversal:
            // write a size-prefixed message and flag the transport layer as having buffered bytes.
            SslTransportLayer transportLayer =
                (SslTransportLayer) TestUtils.fieldValue(clientChannel, KafkaChannel.class, "transportLayer");
            ByteBuffer appReadBuffer =
                (ByteBuffer) TestUtils.fieldValue(transportLayer, SslTransportLayer.class, "appReadBuffer");
            Assertions.assertEquals(0, appReadBuffer.position());
            byte[] payload = "testMessage".getBytes(StandardCharsets.UTF_8);
            appReadBuffer.putInt(payload.length);
            appReadBuffer.put(payload);
            TestUtils.setFieldValue(transportLayer, "hasBytesBuffered", true);

            KafkaPrincipal principal = serverChannel.principal();
            clientSelector.addReverseChannel(
                serverChannel.reverse("1", (BrokerInterceptor) null, (KafkaPrincipal) null, ch -> { }));
            serverSelector.addReverseChannel(
                clientChannel.reverse(serverChannel.id(), (BrokerInterceptor) null, principal, ch -> { }));
        } finally {
            // Always resume the server, even if an assertion above failed.
            this.pollSemaphore.release();
        }

        TestUtils.waitForCondition(() -> {
            clientSelector.poll(1L);
            if (clientSelector.completedReceives().isEmpty()) {
                return false;
            }
            Assertions.assertEquals(1, clientSelector.completedReceives().size());
            NetworkReceive receive = clientSelector.completedReceives().iterator().next();
            Assertions.assertEquals("testMessage",
                new String(Utils.toArray(receive.payload()), StandardCharsets.UTF_8));
            return true;
        }, "Buffered receive not processed");
    }

    /**
     * Verifies that channels which are not yet ready (authentication still in progress) cannot be
     * reversed.
     */
    @Test
    public void testReverseBeforeAuthentication() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
        this.server = createEchoServer(securityProtocol);
        updateScramCredentialCache("SCRAM-SHA-256", TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
        createSelector(securityProtocol, this.saslClientConfigs);
        this.selector.connect(KafkaChannelTest.CHANNEL_ID,
            new InetSocketAddress("localhost", this.server.port()), BUFFER_SIZE, BUFFER_SIZE);
        awaitServerChannel();

        KafkaChannel serverChannel = this.server.selector().channels().get(0);
        KafkaChannel clientChannel = this.selector.channels().get(0);
        Assertions.assertFalse(serverChannel.ready());
        Assertions.assertFalse(clientChannel.ready());
        assertReverseRejected(serverChannel, clientChannel);

        // createAndVerifyConnection replaces this.selector (closing the one that owns clientChannel)
        // and opens a new connection. The channels captured above belong to the abandoned first
        // connection, so they must still be rejected.
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        assertReverseRejected(serverChannel, clientChannel);
    }

    /**
     * Verifies that channels whose SASL authentication failed cannot be reversed. No SCRAM
     * credentials are added to the credential cache, so authentication is expected to fail.
     */
    @Test
    public void testReverseFailsIfAuthenticationFails() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
        this.server = createEchoServer(securityProtocol);
        createSelector(securityProtocol, this.saslClientConfigs);
        this.selector.connect(KafkaChannelTest.CHANNEL_ID,
            new InetSocketAddress("localhost", this.server.port()), BUFFER_SIZE, BUFFER_SIZE);
        awaitServerChannel();

        KafkaChannel serverChannel = this.server.selector().channels().get(0);
        KafkaChannel clientChannel = this.selector.channels().get(0);
        NetworkTestUtils.waitForChannelClose(this.selector, KafkaChannelTest.CHANNEL_ID,
            ChannelState.State.AUTHENTICATION_FAILED);
        assertReverseRejected(serverChannel, clientChannel);
    }

    /** Verifies that a reversed connection is expired by the client selector's idle timeout. */
    @Test
    public void testClientIdleExpiry() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        this.server = createEchoServer(securityProtocol);
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        reverseAndVerifyConnection(KafkaChannelTest.CHANNEL_ID, "10", false);
        Assertions.assertEquals(1, this.server.selector().channels().size());
        Assertions.assertEquals(1, this.selector.channels().size());

        this.clientTime.sleep(TimeUnit.MINUTES.toMillis(10L));
        this.selector.poll(1L);
        Assertions.assertEquals(Collections.emptyList(), this.selector.channels());
        Assertions.assertEquals(ChannelState.State.EXPIRED, this.selector.disconnected().get("10").state());
        TestUtils.waitForCondition(() -> this.server.selector().channels().isEmpty(),
            "Server channel not disconnected");
        SelectorTest.verifySelectorEmpty(this.selector);
        SelectorTest.verifySelectorEmpty(this.server.selector());
    }

    /** Verifies that a reversed connection is expired by the server selector's idle timeout. */
    @Test
    public void testServerIdleExpiry() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        this.server = createEchoServer(securityProtocol);
        createAndVerifyConnection(securityProtocol, KafkaChannelTest.CHANNEL_ID);
        reverseAndVerifyConnection(KafkaChannelTest.CHANNEL_ID, "10", false);
        Assertions.assertEquals(1, this.server.selector().channels().size());
        Assertions.assertEquals(1, this.selector.channels().size());

        this.serverTime.sleep(TimeUnit.MINUTES.toMillis(10L));
        TestUtils.waitForCondition(() -> this.server.selector().channels().isEmpty(),
            "Server channel not expired");
        NetworkTestUtils.waitForChannelClose(this.selector, "10", ChannelState.State.READY);
        SelectorTest.verifySelectorEmpty(this.selector);
        SelectorTest.verifySelectorEmpty(this.server.selector());
    }

    /**
     * Asserts that {@code reverse} throws {@link IllegalStateException} on both ends of a
     * connection.
     */
    private static void assertReverseRejected(KafkaChannel serverChannel, KafkaChannel clientChannel) {
        KafkaPrincipal principal = new KafkaPrincipal("User", "someuser");
        Assertions.assertThrows(IllegalStateException.class,
            () -> serverChannel.reverse("1", (BrokerInterceptor) null, (KafkaPrincipal) null,
                (Consumer<KafkaChannel>) null));
        Assertions.assertThrows(IllegalStateException.class,
            () -> clientChannel.reverse(serverChannel.id(), (BrokerInterceptor) null, principal,
                (Consumer<KafkaChannel>) null));
    }

    /**
     * Polls the client selector until the server selector has registered the accepted connection.
     * The server is checked before each poll, so no poll happens if the connection is already
     * registered. Unlike an unbounded loop, this fails with a timeout instead of hanging.
     */
    private void awaitServerChannel() throws Exception {
        TestUtils.waitForCondition(() -> {
            if (!this.server.selector().channels().isEmpty()) {
                return true;
            }
            this.selector.poll(100L);
            return !this.server.selector().channels().isEmpty();
        }, "Server did not register the client connection");
    }

    /**
     * Sets the client SASL mechanism and the server's enabled mechanisms, and installs a matching
     * JAAS test configuration.
     */
    private void configureMechanisms(String clientMechanism, List<String> serverMechanisms) {
        this.saslClientConfigs.put("sasl.mechanism", clientMechanism);
        this.saslServerConfigs.put("sasl.enabled.mechanisms", serverMechanisms);
        TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms);
    }

    /**
     * Starts an echo server whose poll loop runs only while it holds {@link #pollSemaphore}, so
     * tests can pause it by acquiring the semaphore.
     */
    private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception {
        NioEchoServer echoServer = new NioEchoServer(ListenerName.forSecurityProtocol(securityProtocol),
                securityProtocol, new TestSecurityConfig(this.saslServerConfigs), "localhost", null,
                this.credentialCache, 0, this.serverTime) {
            @Override
            protected void poll() throws IOException {
                try {
                    ReverseConnectionTest.this.pollSemaphore.acquire();
                } catch (InterruptedException e) {
                    // Preserve the interrupt status for whoever interrupted the server thread.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
                try {
                    super.poll();
                } finally {
                    // Release even if poll fails; otherwise the permit is lost and the test
                    // thread blocks forever on its next acquire().
                    ReverseConnectionTest.this.pollSemaphore.release();
                }
            }
        };
        echoServer.start();
        return echoServer;
    }

    /**
     * Adds a SCRAM credential for {@code username} to the shared credential cache used by the
     * server.
     */
    private void updateScramCredentialCache(String mechanismName, String username, String password)
            throws NoSuchAlgorithmException {
        ScramMechanism mechanism = ScramMechanism.forMechanismName(mechanismName);
        this.credentialCache.cache(mechanism.mechanismName(), ScramCredential.class)
            .put(username, new ScramFormatter(mechanism).generateCredential(password, SCRAM_ITERATIONS));
    }

    /**
     * Replaces {@link #selector} with a new client selector for {@code securityProtocol}, closing
     * any previous one first.
     */
    private void createSelector(SecurityProtocol securityProtocol, Map<String, Object> clientConfigs) {
        if (this.selector != null) {
            this.selector.close();
            this.selector = null;
        }
        ChannelBuilder channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol,
            JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfigs), (ListenerName) null,
            (String) clientConfigs.get("sasl.mechanism"), this.clientTime, true, new LogContext());
        this.selector = new Selector(5000L, this.clientMetrics, this.clientTime, "MetricGroup",
            channelBuilder, new LogContext());
    }

    /** Creates a fresh client selector, connects it to the server and waits until the channel is usable. */
    private void createAndVerifyConnection(SecurityProtocol securityProtocol, String channelId) throws Exception {
        createSelector(securityProtocol, this.saslClientConfigs);
        this.selector.connect(channelId, new InetSocketAddress("localhost", this.server.port()),
            BUFFER_SIZE, BUFFER_SIZE);
        NetworkTestUtils.checkClientConnection(this.selector, channelId, 100, 10);
    }

    /**
     * Reverses the single established connection and verifies that the reversed pair works.
     *
     * <p>Both channels are detached from their selectors without closing them. The server channel
     * is reversed into the client selector as {@code reverseChannelId}, and the client channel is
     * reversed into the server selector, keeping the server's authenticated principal.
     *
     * @param channelId id of the existing client channel
     * @param reverseChannelId id given to the reversed channel on the client selector
     * @param closeAfterVerify if true, close the reversed channel on the client side and check that
     *     the callbacks passed to {@code reverse} run once on each side
     */
    private void reverseAndVerifyConnection(String channelId, String reverseChannelId, boolean closeAfterVerify)
            throws Exception {
        verifyMetric("reverse-connection-added", 0.0d);
        verifyMetric("reverse-connection-removed", 0.0d);
        AtomicInteger clientSideCallbacks = new AtomicInteger();
        AtomicInteger serverSideCallbacks = new AtomicInteger();
        Selector clientSelector = this.selector;
        KafkaChannel reversedServerChannel;

        // Pause the server poll loop while the channels are moved between selectors.
        this.pollSemaphore.acquire();
        try {
            Selector serverSelector = this.server.selector();
            Assertions.assertEquals(1, serverSelector.channels().size());
            Assertions.assertEquals(1, clientSelector.channels().size());

            KafkaChannel serverChannel = serverSelector.channels().get(0);
            serverSelector.removeChannelWithoutClosing(serverChannel);
            KafkaChannel clientChannel = clientSelector.channel(channelId);
            clientSelector.removeChannelWithoutClosing(clientChannel);
            SelectorTest.verifySelectorEmpty(serverSelector);
            SelectorTest.verifySelectorEmpty(clientSelector);

            KafkaPrincipal principal = serverChannel.principal();
            reversedServerChannel = serverChannel.reverse(reverseChannelId, (BrokerInterceptor) null,
                (KafkaPrincipal) null, ch -> clientSideCallbacks.incrementAndGet());
            clientSelector.addReverseChannel(reversedServerChannel);
            serverSelector.addReverseChannel(clientChannel.reverse(serverChannel.id(), (BrokerInterceptor) null,
                principal, ch -> serverSideCallbacks.incrementAndGet()));

            Assertions.assertEquals(1, serverSelector.channels().size());
            Assertions.assertEquals(1, clientSelector.channels().size());
            Assertions.assertEquals(principal, serverSelector.channels().get(0).principal());
        } finally {
            // Always resume the server, even if an assertion above failed.
            this.pollSemaphore.release();
        }

        NetworkTestUtils.checkClientConnection(this.selector, reverseChannelId, 100, 10);
        verifyMetric("reverse-connection-added", 1.0d);
        verifyMetric("reverse-connection-removed", 1.0d);
        if (closeAfterVerify) {
            clientSelector.close(reversedServerChannel.id());
            Assertions.assertEquals(1, clientSideCallbacks.get());
            // The server side notices the close asynchronously on its own poll thread.
            TestUtils.waitForCondition(() -> serverSideCallbacks.get() == 1, "Close listener not invoked");
        }
    }

    /** Asserts that the client metric named {@code name + "-total"} exists and has {@code expectedValue}. */
    private void verifyMetric(String name, double expectedValue) {
        Optional<KafkaMetric> metric = this.clientMetrics.metrics().entrySet().stream()
            .filter(entry -> entry.getKey().name().equals(name + "-total"))
            .map(Map.Entry::getValue)
            .findFirst();
        Assertions.assertTrue(metric.isPresent(), "Metric not found " + name);
        Assertions.assertEquals(expectedValue, ((Double) metric.get().metricValue()).doubleValue(), 0.001d);
    }
}
