package org.apache.spark.network.sasl;

import com.google.common.collect.Lists;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/spark/network/sasl/SaslIntegrationSuite.class */
public class SaslIntegrationSuite {
    private static final long TIMEOUT_MS = 10000;
    static TransportServer server;
    static TransportConf conf;
    static TransportContext context;
    static SecretKeyHolder secretKeyHolder;
    TransportClientFactory clientFactory;

    /* loaded from: input_file:org/apache/spark/network/sasl/SaslIntegrationSuite$TestRpcHandler.class */
    public static class TestRpcHandler extends RpcHandler {
        public void receive(TransportClient transportClient, ByteBuffer byteBuffer, RpcResponseCallback rpcResponseCallback) {
            rpcResponseCallback.onSuccess(byteBuffer);
        }

        public StreamManager getStreamManager() {
            return new OneForOneStreamManager();
        }
    }

    @BeforeClass
    public static void beforeAll() throws IOException {
        conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
        context = new TransportContext(conf, new TestRpcHandler());
        secretKeyHolder = (SecretKeyHolder) Mockito.mock(SecretKeyHolder.class);
        Mockito.when(secretKeyHolder.getSaslUser((String) Mockito.eq("app-1"))).thenReturn("app-1");
        Mockito.when(secretKeyHolder.getSecretKey((String) Mockito.eq("app-1"))).thenReturn("app-1");
        Mockito.when(secretKeyHolder.getSaslUser((String) Mockito.eq("app-2"))).thenReturn("app-2");
        Mockito.when(secretKeyHolder.getSecretKey((String) Mockito.eq("app-2"))).thenReturn("app-2");
        Mockito.when(secretKeyHolder.getSaslUser(Mockito.anyString())).thenReturn("other-app");
        Mockito.when(secretKeyHolder.getSecretKey(Mockito.anyString())).thenReturn("correct-password");
        server = context.createServer(Arrays.asList(new SaslServerBootstrap(conf, secretKeyHolder)));
    }

    @AfterClass
    public static void afterAll() {
        server.close();
    }

    @After
    public void afterEach() {
        if (this.clientFactory != null) {
            this.clientFactory.close();
            this.clientFactory = null;
        }
    }

    @Test
    public void testGoodClient() throws IOException {
        this.clientFactory = context.createClientFactory(Lists.newArrayList(new TransportClientBootstrap[]{new SaslClientBootstrap(conf, "app-1", secretKeyHolder)}));
        Assert.assertEquals("Hello, World!", JavaUtils.bytesToString(this.clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()).sendRpcSync(JavaUtils.stringToBytes("Hello, World!"), TIMEOUT_MS)));
    }

    @Test
    public void testBadClient() {
        SecretKeyHolder secretKeyHolder2 = (SecretKeyHolder) Mockito.mock(SecretKeyHolder.class);
        Mockito.when(secretKeyHolder2.getSaslUser(Mockito.anyString())).thenReturn("other-app");
        Mockito.when(secretKeyHolder2.getSecretKey(Mockito.anyString())).thenReturn("wrong-password");
        this.clientFactory = context.createClientFactory(Lists.newArrayList(new TransportClientBootstrap[]{new SaslClientBootstrap(conf, "unknown-app", secretKeyHolder2)}));
        try {
            this.clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
            Assert.fail("Connection should have failed.");
        } catch (Exception e) {
            Assert.assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
        }
    }

    @Test
    public void testNoSaslClient() throws IOException {
        this.clientFactory = context.createClientFactory(Lists.newArrayList());
        TransportClient createClient = this.clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
        try {
            createClient.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
            Assert.fail("Should have failed");
        } catch (Exception e) {
            Assert.assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
        }
        try {
            createClient.sendRpcSync(ByteBuffer.wrap(new byte[]{-22}), TIMEOUT_MS);
            Assert.fail("Should have failed");
        } catch (Exception e2) {
            Assert.assertTrue(e2.getMessage(), e2.getMessage().contains("java.lang.IndexOutOfBoundsException"));
        }
    }

    @Test
    public void testNoSaslServer() {
        TransportContext transportContext = new TransportContext(conf, new TestRpcHandler());
        this.clientFactory = transportContext.createClientFactory(Lists.newArrayList(new TransportClientBootstrap[]{new SaslClientBootstrap(conf, "app-1", secretKeyHolder)}));
        TransportServer createServer = transportContext.createServer();
        try {
            try {
                this.clientFactory.createClient(TestUtils.getLocalHost(), createServer.getPort());
                createServer.close();
            } catch (Exception e) {
                Assert.assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
                createServer.close();
            }
        } catch (Throwable th) {
            createServer.close();
            throw th;
        }
    }

    @Test
    public void testAppIsolation() throws Exception {
        ExternalShuffleBlockHandler externalShuffleBlockHandler = new ExternalShuffleBlockHandler(new OneForOneStreamManager(), (ExternalShuffleBlockResolver) Mockito.mock(ExternalShuffleBlockResolver.class));
        TransportServerBootstrap saslServerBootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
        TransportContext transportContext = new TransportContext(conf, externalShuffleBlockHandler);
        TransportServer createServer = transportContext.createServer(Arrays.asList(saslServerBootstrap));
        TransportClient transportClient = null;
        TransportClient transportClient2 = null;
        TransportClientFactory transportClientFactory = null;
        try {
            this.clientFactory = transportContext.createClientFactory(Lists.newArrayList(new TransportClientBootstrap[]{new SaslClientBootstrap(conf, "app-1", secretKeyHolder)}));
            transportClient = this.clientFactory.createClient(TestUtils.getLocalHost(), createServer.getPort());
            final AtomicReference atomicReference = new AtomicReference();
            BlockFetchingListener blockFetchingListener = new BlockFetchingListener() { // from class: org.apache.spark.network.sasl.SaslIntegrationSuite.1
                public synchronized void onBlockFetchSuccess(String str, ManagedBuffer managedBuffer) {
                    notifyAll();
                }

                public synchronized void onBlockFetchFailure(String str, Throwable th) {
                    atomicReference.set(th);
                    notifyAll();
                }
            };
            String[] strArr = {"shuffle_2_3_4", "shuffle_6_7_8"};
            OneForOneBlockFetcher oneForOneBlockFetcher = new OneForOneBlockFetcher(transportClient, "app-2", "0", strArr, blockFetchingListener);
            synchronized (blockFetchingListener) {
                oneForOneBlockFetcher.start();
                blockFetchingListener.wait();
            }
            checkSecurityException((Throwable) atomicReference.get());
            transportClient.sendRpcSync(new RegisterExecutor("app-1", "0", new ExecutorShuffleInfo(new String[]{System.getProperty("java.io.tmpdir")}, 1, "org.apache.spark.shuffle.sort.SortShuffleManager")).toByteBuffer(), TIMEOUT_MS);
            long j = BlockTransferMessage.Decoder.fromByteBuffer(transportClient.sendRpcSync(new OpenBlocks("app-1", "0", strArr).toByteBuffer(), TIMEOUT_MS)).streamId;
            transportClientFactory = transportContext.createClientFactory(Lists.newArrayList(new TransportClientBootstrap[]{new SaslClientBootstrap(conf, "app-2", secretKeyHolder)}));
            transportClient2 = transportClientFactory.createClient(TestUtils.getLocalHost(), createServer.getPort());
            ChunkReceivedCallback chunkReceivedCallback = new ChunkReceivedCallback() { // from class: org.apache.spark.network.sasl.SaslIntegrationSuite.2
                public synchronized void onSuccess(int i, ManagedBuffer managedBuffer) {
                    notifyAll();
                }

                public synchronized void onFailure(int i, Throwable th) {
                    atomicReference.set(th);
                    notifyAll();
                }
            };
            atomicReference.set(null);
            synchronized (chunkReceivedCallback) {
                transportClient2.fetchChunk(j, 0, chunkReceivedCallback);
                chunkReceivedCallback.wait();
            }
            checkSecurityException((Throwable) atomicReference.get());
            if (transportClient != null) {
                transportClient.close();
            }
            if (transportClient2 != null) {
                transportClient2.close();
            }
            if (transportClientFactory != null) {
                transportClientFactory.close();
            }
            createServer.close();
        } catch (Throwable th) {
            if (transportClient != null) {
                transportClient.close();
            }
            if (transportClient2 != null) {
                transportClient2.close();
            }
            if (transportClientFactory != null) {
                transportClientFactory.close();
            }
            createServer.close();
            throw th;
        }
    }

    private void checkSecurityException(Throwable th) {
        Assert.assertNotNull("No exception was caught.", th);
        Assert.assertTrue("Expected SecurityException.", th.getMessage().contains(SecurityException.class.getName()));
    }
}
