package org.apache.flink.runtime.io.network.partition;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
import scala.Tuple2;

/* loaded from: input_file:org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.class */
/**
 * Concurrency tests for {@link SingleInputGate}.
 *
 * <p>In every test a {@link ProducerThread} pushes buffers into the gate's input channels
 * (local, remote, or a shuffled mix of both) while a {@link ConsumerThread} concurrently reads
 * from the gate via {@link SingleInputGate#getNextBufferOrEvent()}. A test passes when the
 * consumer obtains a non-null result for every buffer that was produced and neither thread
 * failed.
 */
public class InputGateConcurrentTest {

    /** Number of buffers produced per input channel. */
    private static final int BUFFERS_PER_CHANNEL = 1000;

    /** Maximum number of consecutive buffers the producer writes into one randomly chosen channel. */
    private static final int MAX_CHUNK = 4;

    /** The producer calls {@link Thread#yield()} roughly every this many buffers. */
    private static final int YIELD_AFTER = 10;

    // ------------------------------------------------------------------------
    //  Tests
    // ------------------------------------------------------------------------

    @Test
    public void testConsumptionWithLocalChannels() throws Exception {
        final int numChannels = 11;

        final ResultPartition resultPartition = Mockito.mock(ResultPartition.class);
        final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels];
        final Source[] sources = new Source[numChannels];

        // The array is handed to the manager before it is populated below;
        // presumably the manager reads its slots lazily (TODO confirm in InputChannelTestUtils).
        final ResultPartitionManager partitionManager =
                InputChannelTestUtils.createResultPartitionManager(partitions);
        final SingleInputGate gate = createInputGate(numChannels);

        for (int i = 0; i < numChannels; i++) {
            gate.setInputChannel(
                    new IntermediateResultPartitionID(),
                    createLocalChannel(gate, i, partitionManager));
            partitions[i] = new PipelinedSubpartition(0, resultPartition);
            sources[i] = new PipelinedSubpartitionSource(partitions[i]);
        }

        runProducerAndConsumer(gate, sources, numChannels * BUFFERS_PER_CHANNEL);
    }

    @Test
    public void testConsumptionWithRemoteChannels() throws Exception {
        final int numChannels = 11;

        final ConnectionManager connectionManager =
                InputChannelTestUtils.createDummyConnectionManager();
        final Source[] sources = new Source[numChannels];
        final SingleInputGate gate = createInputGate(numChannels);

        for (int i = 0; i < numChannels; i++) {
            final RemoteInputChannel channel = createRemoteChannel(gate, i, connectionManager);
            gate.setInputChannel(new IntermediateResultPartitionID(), channel);
            sources[i] = new RemoteChannelSource(channel);
        }

        runProducerAndConsumer(gate, sources, numChannels * BUFFERS_PER_CHANNEL);
    }

    @Test
    public void testConsumptionWithMixedChannels() throws Exception {
        final int numChannels = 61;
        final int numLocalChannels = 20;

        // Exactly numLocalChannels entries are true (local); the rest are remote.
        // Shuffling places the local channels at random gate indices.
        final List<Boolean> isLocal = new ArrayList<>(numChannels);
        for (int i = 0; i < numChannels; i++) {
            isLocal.add(i < numLocalChannels);
        }
        Collections.shuffle(isLocal);

        final ConnectionManager connectionManager =
                InputChannelTestUtils.createDummyConnectionManager();
        final ResultPartition resultPartition = Mockito.mock(ResultPartition.class);
        final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels];
        final ResultPartitionManager partitionManager =
                InputChannelTestUtils.createResultPartitionManager(localPartitions);
        final Source[] sources = new Source[numChannels];
        final SingleInputGate gate = createInputGate(numChannels);

        int nextLocal = 0;
        for (int i = 0; i < numChannels; i++) {
            if (isLocal.get(i)) {
                final PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition);
                localPartitions[nextLocal++] = partition;
                sources[i] = new PipelinedSubpartitionSource(partition);
                gate.setInputChannel(
                        new IntermediateResultPartitionID(),
                        createLocalChannel(gate, i, partitionManager));
            } else {
                final RemoteInputChannel channel = createRemoteChannel(gate, i, connectionManager);
                gate.setInputChannel(new IntermediateResultPartitionID(), channel);
                sources[i] = new RemoteChannelSource(channel);
            }
        }

        runProducerAndConsumer(gate, sources, numChannels * BUFFERS_PER_CHANNEL);
    }

    // ------------------------------------------------------------------------
    //  Setup helpers
    // ------------------------------------------------------------------------

    /** Creates the gate under test with {@code numChannels} channel slots and mocked collaborators. */
    private static SingleInputGate createInputGate(int numChannels) {
        return new SingleInputGate(
                "Test Task Name",
                new JobID(),
                new ExecutionAttemptID(),
                new IntermediateDataSetID(),
                0,
                numChannels,
                Mockito.mock(PartitionProducerStateChecker.class),
                new UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
    }

    /** Creates a local channel at {@code channelIndex} of {@code gate}, backed by {@code partitionManager}. */
    private static LocalInputChannel createLocalChannel(
            SingleInputGate gate, int channelIndex, ResultPartitionManager partitionManager) {
        return new LocalInputChannel(
                gate,
                channelIndex,
                new ResultPartitionID(),
                partitionManager,
                Mockito.mock(TaskEventDispatcher.class),
                new UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
    }

    /** Creates a remote channel at {@code channelIndex} of {@code gate} using {@code connectionManager}. */
    private static RemoteInputChannel createRemoteChannel(
            SingleInputGate gate, int channelIndex, ConnectionManager connectionManager) {
        return new RemoteInputChannel(
                gate,
                channelIndex,
                new ResultPartitionID(),
                Mockito.mock(ConnectionID.class),
                connectionManager,
                new Tuple2<Integer, Integer>(0, 0),
                new UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
    }

    /**
     * Runs a producer that writes {@code numBuffers} buffers across {@code sources} concurrently
     * with a consumer that reads {@code numBuffers} entries from {@code gate}. It then waits for
     * both threads and rethrows the first failure.
     *
     * @throws Exception the producer's or the consumer's failure, if any
     */
    private static void runProducerAndConsumer(SingleInputGate gate, Source[] sources, int numBuffers)
            throws Exception {
        final ProducerThread producer = new ProducerThread(sources, numBuffers, MAX_CHUNK, YIELD_AFTER);
        final ConsumerThread consumer = new ConsumerThread(gate, numBuffers);
        producer.start();
        consumer.start();

        // sync() joins the thread and rethrows any exception or failed assertion it hit.
        try {
            producer.sync();
        } catch (Throwable t) {
            // The consumer waits for buffers that will now never arrive. Interrupt it so it does
            // not stay blocked after the test has already failed.
            consumer.interrupt();
            throw t;
        }
        consumer.sync();
    }

    // ------------------------------------------------------------------------
    //  Buffer sources (one per input channel)
    // ------------------------------------------------------------------------

    /** A target the producer pushes buffers into; each instance feeds one input channel of the gate. */
    private abstract static class Source {
        abstract void addBuffer(Buffer buffer) throws Exception;
    }

    /** Feeds a local channel by adding buffers to the backing {@link PipelinedSubpartition}. */
    private static class PipelinedSubpartitionSource extends Source {
        private final PipelinedSubpartition partition;

        PipelinedSubpartitionSource(PipelinedSubpartition partition) {
            this.partition = partition;
        }

        @Override
        void addBuffer(Buffer buffer) throws Exception {
            partition.add(buffer);
        }
    }

    /** Feeds a remote channel via {@code onBuffer}, attaching a sequence number that starts at 0 and increments per buffer. */
    private static class RemoteChannelSource extends Source {
        private final RemoteInputChannel channel;
        private int seq = 0;

        RemoteChannelSource(RemoteInputChannel channel) {
            this.channel = channel;
        }

        @Override
        void addBuffer(Buffer buffer) throws Exception {
            channel.onBuffer(buffer, seq++);
        }
    }

    // ------------------------------------------------------------------------
    //  Testing threads
    // ------------------------------------------------------------------------

    /**
     * A thread that records any throwable raised by {@link #go()}, so that {@link #sync()} can
     * rethrow it in the thread that runs the test.
     */
    private abstract static class CheckedThread extends Thread {

        /** The throwable raised by {@link #go()}, or {@code null}. Written in {@link #run()}, read in {@link #sync()}. */
        private volatile Throwable error;

        CheckedThread() {
            // A daemon thread that is still blocked after a failed test cannot keep the JVM alive.
            setDaemon(true);
        }

        /** The thread's work; any throwable it raises is surfaced by {@link #sync()}. */
        public abstract void go() throws Exception;

        @Override
        public void run() {
            try {
                go();
            } catch (Throwable t) {
                error = t;
            }
        }

        /**
         * Waits for the thread to finish and rethrows what {@link #go()} threw.
         * Errors and exceptions are rethrown as-is; any other throwable is wrapped in an {@link Exception}.
         */
        public void sync() throws Exception {
            join();
            final Throwable t = error;
            if (t == null) {
                return;
            }
            if (t instanceof Error) {
                throw (Error) t;
            }
            if (t instanceof Exception) {
                throw (Exception) t;
            }
            throw new Exception(t.getMessage(), t);
        }
    }

    /**
     * Writes {@code numTotal} buffers into randomly chosen sources.
     * Each write is a chunk of 1..{@code maxChunk} buffers, and the thread yields roughly every
     * {@code yieldAfter} buffers to encourage interleaving with the consumer.
     */
    private static class ProducerThread extends CheckedThread {
        private final Random rnd = new Random();
        private final Source[] sources;
        private final int numTotal;
        private final int maxChunk;
        private final int yieldAfter;

        ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) {
            this.sources = sources;
            this.numTotal = numTotal;
            this.maxChunk = maxChunk;
            this.yieldAfter = yieldAfter;
        }

        @Override
        public void go() throws Exception {
            // The same mock buffer instance is added on every write.
            final Buffer buffer = InputChannelTestUtils.createMockBuffer(100);
            int nextYieldAt = numTotal - yieldAfter;

            for (int remaining = numTotal; remaining > 0; ) {
                final Source target = sources[rnd.nextInt(sources.length)];
                final int chunk = Math.min(remaining, rnd.nextInt(maxChunk) + 1);

                for (int k = 0; k < chunk; k++) {
                    target.addBuffer(buffer);
                }
                remaining -= chunk;

                if (remaining <= nextYieldAt) {
                    nextYieldAt -= yieldAfter;
                    Thread.yield();
                }
            }
        }
    }

    /** Calls {@link SingleInputGate#getNextBufferOrEvent()} {@code numBuffers} times and asserts each result is non-null. */
    private static class ConsumerThread extends CheckedThread {
        private final SingleInputGate gate;
        private final int numBuffers;

        ConsumerThread(SingleInputGate gate, int numBuffers) {
            this.gate = gate;
            this.numBuffers = numBuffers;
        }

        @Override
        public void go() throws Exception {
            for (int i = 0; i < numBuffers; i++) {
                Assert.assertNotNull(gate.getNextBufferOrEvent());
            }
        }
    }
}
