package org.apache.flink.runtime.io.network;

import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemoryType;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.query.KvStateClientProxy;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.KvStateServer;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Tests for {@link NetworkEnvironment}.
 */
public class NetworkEnvironmentTest {
    private static final int numBuffers = 1024;
    private static final int memorySegmentSize = 128;

    /**
     * Per-channel buffer count handed to the {@link NetworkEnvironment} under test. The test expects
     * exactly this many exclusive segments to be assigned to a credit-based input gate.
     */
    private static final int BUFFERS_PER_CHANNEL = 2;

    /**
     * Per-gate extra buffer count handed to the {@link NetworkEnvironment} under test. The test
     * expects this to be the buffer-pool cap of a credit-based input gate.
     */
    private static final int EXTRA_BUFFERS_PER_GATE = 8;

    /**
     * Verifies that registering a task sizes each buffer pool according to its partition type:
     * <ul>
     *   <li>PIPELINED / BLOCKING: unbounded ({@link Integer#MAX_VALUE}).</li>
     *   <li>PIPELINED_BOUNDED: {@code channels * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE}.</li>
     *   <li>PIPELINED_CREDIT_BASED input gate: capped at {@code EXTRA_BUFFERS_PER_GATE}, with
     *       {@code BUFFERS_PER_CHANNEL} exclusive segments assigned via
     *       {@code assignExclusiveSegments}.</li>
     * </ul>
     */
    @Test
    public void testRegisterTaskUsesBoundedBuffers() throws Exception {
        final NetworkEnvironment network = new NetworkEnvironment(
            new NetworkBufferPool(numBuffers, memorySegmentSize, MemoryType.HEAP),
            new LocalConnectionManager(),
            new ResultPartitionManager(),
            new TaskEventDispatcher(),
            new KvStateRegistry(),
            (KvStateServer) null,
            (KvStateClientProxy) null,
            IOManager.IOMode.SYNC,
            0,
            0,
            BUFFERS_PER_CHANNEL,
            EXTRA_BUFFERS_PER_GATE);

        // Always release the network resources, even when an assertion (possibly one raised from
        // inside a mocked setBufferPool call during registerTask) fails.
        try {
            // result partitions
            final ResultPartition rp1 = createResultPartition(ResultPartitionType.PIPELINED, 2);
            final ResultPartition rp2 = createResultPartition(ResultPartitionType.BLOCKING, 2);
            final ResultPartition rp3 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 2);
            final ResultPartition rp4 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 8);
            final ResultPartition[] resultPartitions = {rp1, rp2, rp3, rp4};
            final ResultPartitionWriter[] resultPartitionWriters = {
                new ResultPartitionWriter(rp1),
                new ResultPartitionWriter(rp2),
                new ResultPartitionWriter(rp3),
                new ResultPartitionWriter(rp4)};

            // input gates; each mock checks its own buffer pool when setBufferPool is called
            final SingleInputGate[] inputGates = {
                createSingleInputGateMock(ResultPartitionType.PIPELINED, 2),
                createSingleInputGateMock(ResultPartitionType.BLOCKING, 2),
                createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 2),
                createSingleInputGateMock(ResultPartitionType.PIPELINED_CREDIT_BASED, 8)};

            // overall task to register
            final Task task = Mockito.mock(Task.class);
            Mockito.when(task.getProducedPartitions()).thenReturn(resultPartitions);
            Mockito.when(task.getAllWriters()).thenReturn(resultPartitionWriters);
            Mockito.when(task.getAllInputGates()).thenReturn(inputGates);

            network.registerTask(task);

            Assert.assertEquals(Integer.MAX_VALUE, rp1.getBufferPool().getMaxNumberOfMemorySegments());
            Assert.assertEquals(Integer.MAX_VALUE, rp2.getBufferPool().getMaxNumberOfMemorySegments());
            Assert.assertEquals(expectedBoundedPoolSize(2), rp3.getBufferPool().getMaxNumberOfMemorySegments());
            Assert.assertEquals(expectedBoundedPoolSize(8), rp4.getBufferPool().getMaxNumberOfMemorySegments());

            // The credit-based gate must receive its exclusive per-channel segments exactly once.
            Mockito.verify(inputGates[3], Mockito.times(1))
                .assignExclusiveSegments(network.getNetworkBufferPool(), BUFFERS_PER_CHANNEL);
        } finally {
            network.shutdown();
        }
    }

    /**
     * Returns the buffer-pool cap this test expects for a PIPELINED_BOUNDED partition or gate.
     *
     * @param channels number of subpartitions / input channels
     * @return {@code channels * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE}
     */
    private static int expectedBoundedPoolSize(int channels) {
        return channels * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE;
    }

    /**
     * Creates a real {@link ResultPartition} whose collaborators are all Mockito mocks.
     *
     * @param partitionType the partition type under test
     * @param channels value passed for both the subpartition count and the following int
     *     constructor argument (presumably the target key-group count; confirm against the
     *     {@link ResultPartition} constructor)
     * @return the new partition
     */
    private static ResultPartition createResultPartition(
            final ResultPartitionType partitionType, final int channels) {
        return new ResultPartition(
            "TestTask-" + partitionType + ":" + channels,
            Mockito.mock(TaskActions.class),
            new JobID(),
            new ResultPartitionID(),
            partitionType,
            channels,
            channels,
            Mockito.mock(ResultPartitionManager.class),
            Mockito.mock(ResultPartitionConsumableNotifier.class),
            Mockito.mock(IOManager.class),
            false);
    }

    /**
     * Creates a mocked {@link SingleInputGate} that reports the given partition type and channel
     * count. It asserts the maximum size of any {@link BufferPool} passed to
     * {@link SingleInputGate#setBufferPool}.
     *
     * <p>The assertion runs inside the mocked call, so a mismatch surfaces as an
     * {@link AssertionError} thrown out of {@code NetworkEnvironment#registerTask}.
     *
     * @param partitionType consumed partition type reported by the mock
     * @param channels number of input channels reported by the mock
     * @return the configured mock
     */
    private static SingleInputGate createSingleInputGateMock(
            final ResultPartitionType partitionType, final int channels) {
        final SingleInputGate inputGate = Mockito.mock(SingleInputGate.class);
        Mockito.when(inputGate.getConsumedPartitionType()).thenReturn(partitionType);
        Mockito.when(inputGate.getNumberOfInputChannels()).thenReturn(channels);
        Mockito.doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) throws Throwable {
                final BufferPool bufferPool = invocation.getArgumentAt(0, BufferPool.class);
                if (partitionType == ResultPartitionType.PIPELINED_BOUNDED) {
                    Assert.assertEquals(expectedBoundedPoolSize(channels), bufferPool.getMaxNumberOfMemorySegments());
                } else if (partitionType == ResultPartitionType.PIPELINED_CREDIT_BASED) {
                    // Exclusive per-channel segments are assigned separately, so the pool
                    // only holds the per-gate extra buffers.
                    Assert.assertEquals(EXTRA_BUFFERS_PER_GATE, bufferPool.getMaxNumberOfMemorySegments());
                } else {
                    Assert.assertEquals(Integer.MAX_VALUE, bufferPool.getMaxNumberOfMemorySegments());
                }
                return null;
            }
        }).when(inputGate).setBufferPool(Matchers.any(BufferPool.class));
        return inputGate;
    }
}
