package org.apache.flink.test.scheduling;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.SlotSharingGroup;
import org.apache.flink.api.connector.source.DynamicParallelismInference;
import org.apache.flink.api.connector.source.lib.NumberSequenceSource;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.ClusterOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.RestOptions;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.LocalStreamEnvironment;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.class */
class AdaptiveBatchSchedulerITCase {
    private static final int DEFAULT_MAX_PARALLELISM = 4;
    private static final int SOURCE_PARALLELISM_1 = 2;
    private static final int SOURCE_PARALLELISM_2 = 8;
    private static final int NUMBERS_TO_PRODUCE = 10000;
    private static ConcurrentLinkedQueue<Map<Long, Long>> numberCountResults;
    private Map<Long, Long> expectedResult;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase$NumberCounter.class */
    public static class NumberCounter extends RichMapFunction<Long, Long> {
        private final Map<Long, Long> numberCountResult;

        private NumberCounter() {
            this.numberCountResult = new HashMap();
        }

        public Long map(Long l) throws Exception {
            this.numberCountResult.put(l, Long.valueOf(this.numberCountResult.getOrDefault(l, 0L).longValue() + 1));
            return l;
        }

        public void close() throws Exception {
            AdaptiveBatchSchedulerITCase.numberCountResults.add(this.numberCountResult);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase$TestingParallelismInferenceNumberSequenceSource.class */
    public static class TestingParallelismInferenceNumberSequenceSource extends NumberSequenceSource implements DynamicParallelismInference {
        private static final long serialVersionUID = 1;
        private final int expectedParallelism;

        public TestingParallelismInferenceNumberSequenceSource(long j, long j2, int i) {
            super(j, j2);
            this.expectedParallelism = i;
        }

        public int inferParallelism(DynamicParallelismInference.Context context) {
            return this.expectedParallelism;
        }
    }

    AdaptiveBatchSchedulerITCase() {
    }

    @BeforeEach
    void setUp() {
        this.expectedResult = (Map) LongStream.range(0L, 10000L).boxed().collect(Collectors.toMap(Function.identity(), l -> {
            return 2L;
        }));
        numberCountResults = new ConcurrentLinkedQueue<>();
    }

    @Test
    void testScheduling() throws Exception {
        testSchedulingBase(false);
    }

    @Test
    void testSchedulingWithDynamicSourceParallelismInference() throws Exception {
        testSchedulingBase(true);
    }

    @Test
    void testParallelismOfForwardGroupLargerThanGlobalMaxParallelism() throws Exception {
        LocalStreamEnvironment createLocalEnvironment = StreamExecutionEnvironment.createLocalEnvironment(createConfiguration());
        createLocalEnvironment.setRuntimeMode(RuntimeExecutionMode.BATCH);
        createLocalEnvironment.setParallelism(SOURCE_PARALLELISM_2);
        createLocalEnvironment.fromSequence(0L, 9999L).setParallelism(SOURCE_PARALLELISM_2).name("source").slotSharingGroup("group1").forward().map(new NumberCounter()).name("map").slotSharingGroup("group2");
        createLocalEnvironment.execute();
    }

    @Test
    void testDifferentConsumerParallelism() throws Exception {
        LocalStreamEnvironment createLocalEnvironment = StreamExecutionEnvironment.createLocalEnvironment(createConfiguration());
        createLocalEnvironment.setRuntimeMode(RuntimeExecutionMode.BATCH);
        createLocalEnvironment.setParallelism(SOURCE_PARALLELISM_2);
        DataStream slotSharingGroup = createLocalEnvironment.fromSequence(0L, 9999L).setParallelism(SOURCE_PARALLELISM_2).name("source2").slotSharingGroup("group2");
        createLocalEnvironment.fromSequence(0L, 9999L).setParallelism(SOURCE_PARALLELISM_2).name("source1").slotSharingGroup("group1").forward().union(new DataStream[]{slotSharingGroup}).map(new NumberCounter()).name("map1").slotSharingGroup("group3");
        slotSharingGroup.map(new NumberCounter()).name("map2").slotSharingGroup("group4");
        createLocalEnvironment.execute();
    }

    private void testSchedulingBase(Boolean bool) throws Exception {
        executeJob(bool);
        Map map = (Map) numberCountResults.stream().flatMap(map2 -> {
            return map2.entrySet().stream();
        }).collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, (v0) -> {
            return v0.getValue();
        }, (l, l2) -> {
            return Long.valueOf(l.longValue() + l2.longValue());
        }));
        for (int i = 0; i < NUMBERS_TO_PRODUCE; i++) {
            if (map.get(Integer.valueOf(i)) != this.expectedResult.get(Integer.valueOf(i))) {
                System.out.println(i + ": " + map.get(Integer.valueOf(i)));
            }
        }
        Assertions.assertThat(map).isEqualTo(this.expectedResult);
    }

    private void executeJob(Boolean bool) throws Exception {
        SingleOutputStreamOperator slotSharingGroup;
        SingleOutputStreamOperator slotSharingGroup2;
        Configuration createConfiguration = createConfiguration();
        createConfiguration.set(ClusterOptions.FINE_GRAINED_SHUFFLE_MODE_ALL_BLOCKING, true);
        LocalStreamEnvironment createLocalEnvironment = StreamExecutionEnvironment.createLocalEnvironment(createConfiguration);
        createLocalEnvironment.setRuntimeMode(RuntimeExecutionMode.BATCH);
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < 3; i++) {
            arrayList.add(SlotSharingGroup.newBuilder("group" + i).setCpuCores(1.0d).setTaskHeapMemory(MemorySize.parse("100m")).build());
        }
        if (bool.booleanValue()) {
            slotSharingGroup = createLocalEnvironment.fromSource(new TestingParallelismInferenceNumberSequenceSource(0L, 9999L, SOURCE_PARALLELISM_1), WatermarkStrategy.noWatermarks(), "source1").slotSharingGroup((SlotSharingGroup) arrayList.get(0));
            slotSharingGroup2 = createLocalEnvironment.fromSource(new TestingParallelismInferenceNumberSequenceSource(0L, 9999L, SOURCE_PARALLELISM_2), WatermarkStrategy.noWatermarks(), "source2").slotSharingGroup((SlotSharingGroup) arrayList.get(1));
        } else {
            slotSharingGroup = createLocalEnvironment.fromSequence(0L, 9999L).setParallelism(SOURCE_PARALLELISM_1).name("source1").slotSharingGroup((SlotSharingGroup) arrayList.get(0));
            slotSharingGroup2 = createLocalEnvironment.fromSequence(0L, 9999L).setParallelism(SOURCE_PARALLELISM_2).name("source2").slotSharingGroup((SlotSharingGroup) arrayList.get(1));
        }
        slotSharingGroup.union(new DataStream[]{slotSharingGroup2}).rescale().map(new NumberCounter()).name("map").slotSharingGroup((SlotSharingGroup) arrayList.get(SOURCE_PARALLELISM_1));
        createLocalEnvironment.execute();
    }

    private static Configuration createConfiguration() {
        Configuration configuration = new Configuration();
        configuration.set(RestOptions.BIND_PORT, "0");
        configuration.set(JobManagerOptions.SLOT_REQUEST_TIMEOUT, Duration.ofMillis(5000L));
        configuration.set(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM, Integer.valueOf(DEFAULT_MAX_PARALLELISM));
        configuration.set(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK, MemorySize.parse("150kb"));
        configuration.set(TaskManagerOptions.MEMORY_SEGMENT_SIZE, MemorySize.parse("4kb"));
        configuration.set(TaskManagerOptions.NUM_TASK_SLOTS, 1);
        return configuration;
    }
}
