package org.apache.spark.network.shuffle;

import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Tests for {@link OneForOneBlockFetcher}. Each test drives the fetcher against a mocked
 * {@link TransportClient} and verifies the callbacks delivered to a mocked
 * {@link BlockFetchingListener}.
 */
public class OneForOneBlockFetcherSuite {
    /** Stream id returned in the mocked {@link StreamHandle} and expected on every chunk fetch. */
    private static final long STREAM_ID = 123L;
    private static final String APP_ID = "app-id";
    private static final String EXEC_ID = "exec-id";

    /** A single successfully fetched block is reported to the listener exactly once. */
    @Test
    public void testFetchOne() {
        LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
        blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));

        BlockFetchingListener listener = fetchBlocks(blocks);

        Mockito.verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
    }

    /** Three successful blocks, backed by different buffer implementations, are each reported once. */
    @Test
    public void testFetchThree() {
        LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
        blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
        blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
        blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));

        BlockFetchingListener listener = fetchBlocks(blocks);

        for (int i = 0; i < 3; i++) {
            Mockito.verify(listener, Mockito.times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
        }
    }

    /**
     * The first block succeeds and the next two chunk fetches fail.
     *
     * <p>b2 is expected to be reported as failed twice. Presumably one report comes from the
     * fetcher failing the remaining blocks after b1 fails, and one from b2's own failed chunk.
     * That behavior lives in OneForOneBlockFetcher, which is not shown here.
     */
    @Test
    public void testFailure() {
        LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
        blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
        blocks.put("b1", null);
        blocks.put("b2", null);

        BlockFetchingListener listener = fetchBlocks(blocks);

        Mockito.verify(listener, Mockito.times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
        Mockito.verify(listener, Mockito.times(1)).onBlockFetchFailure(Matchers.eq("b1"), (Throwable) Matchers.any());
        Mockito.verify(listener, Mockito.times(2)).onBlockFetchFailure(Matchers.eq("b2"), (Throwable) Matchers.any());
    }

    /**
     * A failure in the middle of a fetch.
     *
     * <p>The block after the failed one (b2) is expected to receive both one failure and one
     * success callback. Presumably the failure comes from the fetcher failing the remaining blocks
     * after b1 fails, and the success from b2's own chunk. This needs confirming against
     * OneForOneBlockFetcher.
     */
    @Test
    public void testFailureAndSuccess() {
        LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
        blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
        blocks.put("b1", null);
        blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));

        BlockFetchingListener listener = fetchBlocks(blocks);

        Mockito.verify(listener, Mockito.times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
        Mockito.verify(listener, Mockito.times(1)).onBlockFetchFailure(Matchers.eq("b1"), (Throwable) Matchers.any());
        Mockito.verify(listener, Mockito.times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
        Mockito.verify(listener, Mockito.times(1)).onBlockFetchFailure(Matchers.eq("b2"), (Throwable) Matchers.any());
    }

    /**
     * Starting a fetch with no block ids is expected to throw IllegalArgumentException with the
     * message "Zero-sized blockIds array".
     */
    @Test
    public void testEmptyBlockFetch() {
        try {
            // Explicit type witness: pre-Java-8 inference cannot infer these from the argument position.
            fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap());
            Assert.fail();
        } catch (IllegalArgumentException e) {
            Assert.assertEquals("Zero-sized blockIds array", e.getMessage());
        }
    }

    /**
     * Runs a {@link OneForOneBlockFetcher} over {@code blocks} against a mocked client.
     *
     * <p>The mocked client behaves as follows:
     * <ul>
     *   <li>It answers the OpenBlocks RPC with a {@link StreamHandle} for {@link #STREAM_ID}
     *       covering {@code blocks.size()} chunks.</li>
     *   <li>It then serves chunk requests in order, delivering the i-th value of {@code blocks}
     *       for chunk i.</li>
     *   <li>A non-null value is delivered via {@code ChunkReceivedCallback.onSuccess}.</li>
     *   <li>A {@code null} value is delivered as a {@code RuntimeException} via
     *       {@code ChunkReceivedCallback.onFailure}.</li>
     * </ul>
     *
     * @param blocks block ids mapped to their contents, in fetch order; a {@code null} value
     *     makes that block's chunk fetch fail
     * @return the mocked listener handed to the fetcher, for verification by the caller
     */
    private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
        TransportClient client = Mockito.mock(TransportClient.class);
        BlockFetchingListener listener = Mockito.mock(BlockFetchingListener.class);
        final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
        OneForOneBlockFetcher fetcher =
            new OneForOneBlockFetcher(client, APP_ID, EXEC_ID, blockIds, listener);

        // Respond to the OpenBlocks RPC with a stream handle, then check the request's contents.
        Mockito.doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) throws Throwable {
                BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
                    (ByteBuffer) invocation.getArguments()[0]);
                RpcResponseCallback callback = (RpcResponseCallback) invocation.getArguments()[1];
                callback.onSuccess(new StreamHandle(STREAM_ID, blocks.size()).toByteBuffer());
                Assert.assertEquals(new OpenBlocks(APP_ID, EXEC_ID, blockIds), message);
                return null;
            }
        }).when(client).sendRpc(Matchers.any(ByteBuffer.class), Matchers.any(RpcResponseCallback.class));

        // Serve each chunk request, in order, with the next block; a null block simulates a failure.
        final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
        final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
        Mockito.doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) throws Throwable {
                try {
                    long streamId = (Long) invocation.getArguments()[0];
                    int chunkIndex = (Integer) invocation.getArguments()[1];
                    Assert.assertEquals(STREAM_ID, streamId);
                    Assert.assertEquals(expectedChunkIndex.getAndIncrement(), chunkIndex);

                    ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
                    ManagedBuffer result = blockIterator.next();
                    if (result != null) {
                        callback.onSuccess(chunkIndex, result);
                    } else {
                        callback.onFailure(chunkIndex, new RuntimeException("Failed " + chunkIndex));
                    }
                } catch (Exception e) {
                    // Fail the test but keep the original exception attached as the cause.
                    throw new AssertionError("Unexpected failure", e);
                }
                return null;
            }
        }).when(client).fetchChunk(Matchers.anyLong(), Matchers.anyInt(), (ChunkReceivedCallback) Matchers.any());

        fetcher.start();
        return listener;
    }
}
