package org.apache.flink.runtime.io.network.partition;

import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.class */
public class InputGateFairnessTest {

    /* loaded from: input_file:org/apache/flink/runtime/io/network/partition/InputGateFairnessTest$FairnessVerifyingInputGate.class */
    private static class FairnessVerifyingInputGate extends SingleInputGate {
        private final ArrayDeque<InputChannel> channelsWithData;
        private final HashSet<InputChannel> uniquenessChecker;

        public FairnessVerifyingInputGate(String str, JobID jobID, IntermediateDataSetID intermediateDataSetID, int i, int i2, TaskActions taskActions, TaskIOMetricGroup taskIOMetricGroup) {
            super(str, jobID, intermediateDataSetID, ResultPartitionType.PIPELINED, i, i2, taskActions, taskIOMetricGroup);
            try {
                Field declaredField = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
                declaredField.setAccessible(true);
                this.channelsWithData = (ArrayDeque) declaredField.get(this);
                this.uniquenessChecker = new HashSet<>();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException {
            synchronized (this.channelsWithData) {
                Assert.assertTrue("too many input channels", this.channelsWithData.size() <= getNumberOfInputChannels());
                ensureUnique(this.channelsWithData);
            }
            return super.getNextBufferOrEvent();
        }

        private void ensureUnique(Collection<InputChannel> collection) {
            HashSet<InputChannel> hashSet = this.uniquenessChecker;
            for (InputChannel inputChannel : collection) {
                if (!hashSet.add(inputChannel)) {
                    Assert.fail("Duplicate channel in input gate: " + inputChannel);
                }
            }
            Assert.assertTrue("found duplicate input channels", hashSet.size() == collection.size());
            hashSet.clear();
        }
    }

    @Test
    public void testFairConsumptionLocalChannelsPreFilled() throws Exception {
        ResultPartition resultPartition = (ResultPartition) Mockito.mock(ResultPartition.class);
        Buffer createMockBuffer = InputChannelTestUtils.createMockBuffer(42);
        PipelinedSubpartition[] pipelinedSubpartitionArr = new PipelinedSubpartition[37];
        for (int i = 0; i < 37; i++) {
            PipelinedSubpartition pipelinedSubpartition = new PipelinedSubpartition(0, resultPartition);
            for (int i2 = 0; i2 < 27; i2++) {
                pipelinedSubpartition.add(createMockBuffer);
            }
            pipelinedSubpartition.finish();
            pipelinedSubpartitionArr[i] = pipelinedSubpartition;
        }
        ResultPartitionManager createResultPartitionManager = InputChannelTestUtils.createResultPartitionManager(pipelinedSubpartitionArr);
        FairnessVerifyingInputGate fairnessVerifyingInputGate = new FairnessVerifyingInputGate("Test Task Name", new JobID(), new IntermediateDataSetID(), 0, 37, (TaskActions) Mockito.mock(TaskActions.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
        for (int i3 = 0; i3 < 37; i3++) {
            fairnessVerifyingInputGate.setInputChannel(new IntermediateResultPartitionID(), new LocalInputChannel(fairnessVerifyingInputGate, i3, new ResultPartitionID(), createResultPartitionManager, (TaskEventDispatcher) Mockito.mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup()));
        }
        for (int i4 = 1036; i4 > 0; i4--) {
            Assert.assertNotNull(fairnessVerifyingInputGate.getNextBufferOrEvent());
            int i5 = Integer.MAX_VALUE;
            int i6 = 0;
            for (PipelinedSubpartition pipelinedSubpartition2 : pipelinedSubpartitionArr) {
                int currentNumberOfBuffers = pipelinedSubpartition2.getCurrentNumberOfBuffers();
                i5 = Math.min(i5, currentNumberOfBuffers);
                i6 = Math.max(i6, currentNumberOfBuffers);
            }
            Assert.assertTrue(i6 == i5 || i6 == i5 + 1);
        }
        Assert.assertNull(fairnessVerifyingInputGate.getNextBufferOrEvent());
    }

    @Test
    public void testFairConsumptionLocalChannels() throws Exception {
        ResultPartition resultPartition = (ResultPartition) Mockito.mock(ResultPartition.class);
        Buffer createMockBuffer = InputChannelTestUtils.createMockBuffer(42);
        PipelinedSubpartition[] pipelinedSubpartitionArr = new PipelinedSubpartition[37];
        for (int i = 0; i < 37; i++) {
            pipelinedSubpartitionArr[i] = new PipelinedSubpartition(0, resultPartition);
        }
        ResultPartitionManager createResultPartitionManager = InputChannelTestUtils.createResultPartitionManager(pipelinedSubpartitionArr);
        FairnessVerifyingInputGate fairnessVerifyingInputGate = new FairnessVerifyingInputGate("Test Task Name", new JobID(), new IntermediateDataSetID(), 0, 37, (TaskActions) Mockito.mock(TaskActions.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
        for (int i2 = 0; i2 < 37; i2++) {
            fairnessVerifyingInputGate.setInputChannel(new IntermediateResultPartitionID(), new LocalInputChannel(fairnessVerifyingInputGate, i2, new ResultPartitionID(), createResultPartitionManager, (TaskEventDispatcher) Mockito.mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup()));
        }
        pipelinedSubpartitionArr[12].add(createMockBuffer);
        for (int i3 = 0; i3 < 999; i3++) {
            Assert.assertNotNull(fairnessVerifyingInputGate.getNextBufferOrEvent());
            int i4 = Integer.MAX_VALUE;
            int i5 = 0;
            for (PipelinedSubpartition pipelinedSubpartition : pipelinedSubpartitionArr) {
                int currentNumberOfBuffers = pipelinedSubpartition.getCurrentNumberOfBuffers();
                i4 = Math.min(i4, currentNumberOfBuffers);
                i5 = Math.max(i5, currentNumberOfBuffers);
            }
            Assert.assertTrue(i5 == i4 || i5 == i4 + 1);
            if (i3 % 74 == 0) {
                fillRandom(pipelinedSubpartitionArr, 3, createMockBuffer);
            }
        }
    }

    @Test
    public void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
        Buffer createMockBuffer = InputChannelTestUtils.createMockBuffer(42);
        FairnessVerifyingInputGate fairnessVerifyingInputGate = new FairnessVerifyingInputGate("Test Task Name", new JobID(), new IntermediateDataSetID(), 0, 37, (TaskActions) Mockito.mock(TaskActions.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
        ConnectionManager createDummyConnectionManager = InputChannelTestUtils.createDummyConnectionManager();
        RemoteInputChannel[] remoteInputChannelArr = new RemoteInputChannel[37];
        for (int i = 0; i < 37; i++) {
            RemoteInputChannel remoteInputChannel = new RemoteInputChannel(fairnessVerifyingInputGate, i, new ResultPartitionID(), (ConnectionID) Mockito.mock(ConnectionID.class), createDummyConnectionManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
            remoteInputChannelArr[i] = remoteInputChannel;
            for (int i2 = 0; i2 < 27; i2++) {
                remoteInputChannel.onBuffer(createMockBuffer, i2);
            }
            remoteInputChannel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), 27);
            fairnessVerifyingInputGate.setInputChannel(new IntermediateResultPartitionID(), remoteInputChannel);
        }
        for (int i3 = 1036; i3 > 0; i3--) {
            Assert.assertNotNull(fairnessVerifyingInputGate.getNextBufferOrEvent());
            int i4 = Integer.MAX_VALUE;
            int i5 = 0;
            for (RemoteInputChannel remoteInputChannel2 : remoteInputChannelArr) {
                int numberOfQueuedBuffers = remoteInputChannel2.getNumberOfQueuedBuffers();
                i4 = Math.min(i4, numberOfQueuedBuffers);
                i5 = Math.max(i5, numberOfQueuedBuffers);
            }
            Assert.assertTrue(i5 == i4 || i5 == i4 + 1);
        }
        Assert.assertNull(fairnessVerifyingInputGate.getNextBufferOrEvent());
    }

    @Test
    public void testFairConsumptionRemoteChannels() throws Exception {
        Buffer createMockBuffer = InputChannelTestUtils.createMockBuffer(42);
        FairnessVerifyingInputGate fairnessVerifyingInputGate = new FairnessVerifyingInputGate("Test Task Name", new JobID(), new IntermediateDataSetID(), 0, 37, (TaskActions) Mockito.mock(TaskActions.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
        ConnectionManager createDummyConnectionManager = InputChannelTestUtils.createDummyConnectionManager();
        RemoteInputChannel[] remoteInputChannelArr = new RemoteInputChannel[37];
        int[] iArr = new int[37];
        for (int i = 0; i < 37; i++) {
            RemoteInputChannel remoteInputChannel = new RemoteInputChannel(fairnessVerifyingInputGate, i, new ResultPartitionID(), (ConnectionID) Mockito.mock(ConnectionID.class), createDummyConnectionManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
            remoteInputChannelArr[i] = remoteInputChannel;
            fairnessVerifyingInputGate.setInputChannel(new IntermediateResultPartitionID(), remoteInputChannel);
        }
        remoteInputChannelArr[11].onBuffer(createMockBuffer, 0);
        iArr[11] = iArr[11] + 1;
        for (int i2 = 0; i2 < 999; i2++) {
            Assert.assertNotNull(fairnessVerifyingInputGate.getNextBufferOrEvent());
            int i3 = Integer.MAX_VALUE;
            int i4 = 0;
            for (RemoteInputChannel remoteInputChannel2 : remoteInputChannelArr) {
                int numberOfQueuedBuffers = remoteInputChannel2.getNumberOfQueuedBuffers();
                i3 = Math.min(i3, numberOfQueuedBuffers);
                i4 = Math.max(i4, numberOfQueuedBuffers);
            }
            Assert.assertTrue(i4 == i3 || i4 == i3 + 1);
            if (i2 % 74 == 0) {
                fillRandom(remoteInputChannelArr, iArr, 3, createMockBuffer);
            }
        }
    }

    private void fillRandom(PipelinedSubpartition[] pipelinedSubpartitionArr, int i, Buffer buffer) throws Exception {
        ArrayList arrayList = new ArrayList(pipelinedSubpartitionArr.length * i);
        for (int i2 = 0; i2 < pipelinedSubpartitionArr.length; i2++) {
            for (int i3 = 0; i3 < i; i3++) {
                arrayList.add(Integer.valueOf(i2));
            }
        }
        Collections.shuffle(arrayList);
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            pipelinedSubpartitionArr[((Integer) it.next()).intValue()].add(buffer);
        }
    }

    private void fillRandom(RemoteInputChannel[] remoteInputChannelArr, int[] iArr, int i, Buffer buffer) throws Exception {
        ArrayList arrayList = new ArrayList(remoteInputChannelArr.length * i);
        for (int i2 = 0; i2 < remoteInputChannelArr.length; i2++) {
            for (int i3 = 0; i3 < i; i3++) {
                arrayList.add(Integer.valueOf(i2));
            }
        }
        Collections.shuffle(arrayList);
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            RemoteInputChannel remoteInputChannel = remoteInputChannelArr[intValue];
            int i4 = iArr[intValue];
            iArr[intValue] = i4 + 1;
            remoteInputChannel.onBuffer(buffer, i4);
        }
    }
}
