/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IOMetrics;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.adaptivebatch.DefaultVertexParallelismDeciderTest;
import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismDecider;
import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for the adaptive batch scheduler.
 *
 * <p>The tests check that a job vertex created without a parallelism ({@code -1}) only gets a
 * decided parallelism after all of its producers have finished. They also check that a vertex's
 * user-configured max parallelism bounds the parallelism the scheduler ends up assigning.
 */
class AdaptiveBatchSchedulerTest {
    private static final int SOURCE_PARALLELISM_1 = 6;
    private static final int SOURCE_PARALLELISM_2 = 4;

    /** Bytes reported for every result partition produced by an execution transitioned here. */
    private static final long PARTITION_BYTES = 100L;

    /** Parallelism returned by the stub decider installed by {@link #createScheduler(JobGraph)}. */
    private static final int DECIDED_PARALLELISM = 10;

    /** Parallelism value of a vertex whose parallelism has not been set / decided yet. */
    private static final int UNDECIDED_PARALLELISM = -1;

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    private static final ComponentMainThreadExecutor mainThreadExecutor =
            ComponentMainThreadExecutorServiceAdapter.forMainThread();

    @Test
    void testAdaptiveBatchScheduler() throws Exception {
        // Without a forward edge the sink is expected to take the stub decider's parallelism.
        assertSinkParallelismDecidedAfterAllSourcesFinish(false, DECIDED_PARALLELISM);
    }

    @Test
    void testDecideParallelismForForwardTarget() throws Exception {
        // With a forward edge from source1, the sink is expected to adopt source1's parallelism
        // rather than the stub decider's answer.
        assertSinkParallelismDecidedAfterAllSourcesFinish(true, SOURCE_PARALLELISM_1);
    }

    /**
     * Schedules the two-source job from {@link #createJobGraph(boolean)} and checks the sink.
     *
     * <p>The sink's parallelism must stay undecided while only the first source has finished.
     * Once both sources have finished, it must equal {@code expectedParallelism} on both the
     * execution job vertex and the job vertex.
     *
     * @param withForwardEdge whether the edge from source1 to the sink is marked as forward
     * @param expectedParallelism the sink parallelism expected once both sources have finished
     */
    private void assertSinkParallelismDecidedAfterAllSourcesFinish(
            boolean withForwardEdge, int expectedParallelism) throws Exception {
        JobGraph jobGraph = createJobGraph(withForwardEdge);

        // createJobGraph passes the vertices as (source1, source2, sink). This assumes
        // getVertices() iterates in that insertion order.
        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
        JobVertex source1 = jobVertexIterator.next();
        JobVertex source2 = jobVertexIterator.next();
        JobVertex sink = jobVertexIterator.next();

        SchedulerBase scheduler = createScheduler(jobGraph);
        DefaultExecutionGraph graph = (DefaultExecutionGraph) scheduler.getExecutionGraph();
        ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());

        scheduler.startScheduling();
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism())
                .isEqualTo(UNDECIDED_PARALLELISM);

        // One finished producer is not enough; the sink must still be undecided.
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism())
                .isEqualTo(UNDECIDED_PARALLELISM);

        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism())
                .isEqualTo(expectedParallelism);
        Assertions.assertThat(sink.getParallelism()).isEqualTo(expectedParallelism);
    }

    @Test
    void testUserConfiguredMaxParallelismIsLargerThanGlobalMaxParallelism() throws Exception {
        testUserConfiguredMaxParallelism(1, 32, 128, 1L, 32);
    }

    @Test
    void testUserConfiguredMaxParallelismIsSmallerThanGlobalMaxParallelism() throws Exception {
        testUserConfiguredMaxParallelism(1, 128, 32, 1L, 32);
    }

    @Test
    void testUserConfiguredMaxParallelismIsSmallerThanGlobalMinParallelism() throws Exception {
        testUserConfiguredMaxParallelism(16, 128, 8, 400L, 8);
    }

    @Test
    void testUserConfiguredMaxParallelismIsSmallerThanGlobalDefaultSourceParallelism()
            throws Exception {
        JobVertex source = createJobVertex("source", UNDECIDED_PARALLELISM);
        source.setMaxParallelism(8);

        // NOTE(review): the last createDecider argument is presumably the global default source
        // parallelism (per the test name); confirm against DefaultVertexParallelismDeciderTest.
        SchedulerBase scheduler =
                createScheduler(
                        new JobGraph(new JobID(), "test job", new JobVertex[] {source}),
                        DefaultVertexParallelismDeciderTest.createDecider(1, 128, 1L, 32),
                        128);
        scheduler.startScheduling();

        Assertions.assertThat(source.getParallelism()).isEqualTo(8);
    }

    /**
     * Builds a {@code source -> sink} job and checks the sink's decided parallelism.
     *
     * <p>The source has parallelism 8. The sink starts undecided, has its max parallelism set to
     * {@code userConfiguredMaxParallelism}, and consumes the source over a POINTWISE, BLOCKING
     * edge. The method finishes all source executions, then asserts the sink's parallelism.
     *
     * @param globalMinParallelism first argument passed to the decider factory
     * @param globalMaxParallelism second decider argument; also the scheduler's default max
     *     parallelism
     * @param userConfiguredMaxParallelism max parallelism set on the sink vertex
     * @param dataVolumePerTask third argument passed to the decider factory
     * @param expectedParallelism parallelism the sink is expected to end up with
     */
    void testUserConfiguredMaxParallelism(
            int globalMinParallelism,
            int globalMaxParallelism,
            int userConfiguredMaxParallelism,
            long dataVolumePerTask,
            int expectedParallelism)
            throws Exception {
        JobVertex source = createJobVertex("source", 8);
        JobVertex sink = createJobVertex("sink", UNDECIDED_PARALLELISM);
        sink.setMaxParallelism(userConfiguredMaxParallelism);
        sink.connectNewDataSetAsInput(
                source, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);

        SchedulerBase scheduler =
                createScheduler(
                        new JobGraph(new JobID(), "test job", new JobVertex[] {source, sink}),
                        DefaultVertexParallelismDeciderTest.createDecider(
                                globalMinParallelism,
                                globalMaxParallelism,
                                dataVolumePerTask,
                                10),
                        globalMaxParallelism);
        scheduler.startScheduling();
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source);

        Assertions.assertThat(sink.getParallelism()).isEqualTo(expectedParallelism);
    }

    /**
     * Reports each of the given executions to the scheduler as having reached {@code state}.
     *
     * <p>Each update carries IO metrics where every result partition produced by that execution
     * is reported as {@link #PARTITION_BYTES} bytes.
     */
    public static void transitionExecutionsState(
            SchedulerBase scheduler, ExecutionState state, List<Execution> executions) {
        for (Execution execution : executions) {
            IOMetrics ioMetrics = new IOMetrics(0L, 0L, 0L, 0L, 0L, 0L, 0L);
            ioMetrics
                    .getNumBytesProducedOfPartitions()
                    .putAll(createResultPartitionBytesForExecution(execution));
            scheduler.updateTaskExecutionState(
                    new TaskExecutionState(execution.getAttemptId(), state, null, null, ioMetrics));
        }
    }

    /**
     * Reports the current execution attempt of every subtask of {@code jobVertex} as having
     * reached {@code state}.
     */
    public static void transitionExecutionsState(
            SchedulerBase scheduler, ExecutionState state, JobVertex jobVertex) {
        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
        List<Execution> executions =
                Arrays.stream(executionGraph.getJobVertex(jobVertex.getID()).getTaskVertices())
                        .map(ExecutionVertex::getCurrentExecutionAttempt)
                        .collect(Collectors.toList());
        transitionExecutionsState(scheduler, state, executions);
    }

    /**
     * Maps every result partition produced by {@code execution}'s vertex to {@link
     * #PARTITION_BYTES}.
     */
    static Map<IntermediateResultPartitionID, Long> createResultPartitionBytesForExecution(
            Execution execution) {
        Map<IntermediateResultPartitionID, Long> partitionBytes = new HashMap<>();
        for (IntermediateResultPartitionID partitionId :
                execution.getVertex().getProducedPartitions().keySet()) {
            partitionBytes.put(partitionId, PARTITION_BYTES);
        }
        return partitionBytes;
    }

    /**
     * Creates a vertex that runs {@link NoOpInvokable}.
     *
     * <p>The parallelism is applied only when it is positive. Otherwise the vertex keeps its
     * default; the tests expect that default to read as {@code -1}.
     */
    public JobVertex createJobVertex(String jobVertexName, int parallelism) {
        JobVertex jobVertex = new JobVertex(jobVertexName);
        jobVertex.setInvokableClass(NoOpInvokable.class);
        if (parallelism > 0) {
            jobVertex.setParallelism(parallelism);
        }
        return jobVertex;
    }

    /**
     * Builds a job with two sources and one sink. The vertices are passed to the graph in the
     * order {@code source1, source2, sink}.
     *
     * <p>{@code source1} has parallelism {@link #SOURCE_PARALLELISM_1} and {@code source2} has
     * {@link #SOURCE_PARALLELISM_2}. The sink starts undecided and consumes both sources over
     * POINTWISE, BLOCKING edges.
     *
     * @param withForwardEdge if {@code true}, the edge from {@code source1} to the sink is marked
     *     as forward
     */
    public JobGraph createJobGraph(boolean withForwardEdge) {
        JobVertex source1 = createJobVertex("source1", SOURCE_PARALLELISM_1);
        JobVertex source2 = createJobVertex("source2", SOURCE_PARALLELISM_2);
        JobVertex sink = createJobVertex("sink", UNDECIDED_PARALLELISM);
        sink.connectNewDataSetAsInput(
                source1, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
        sink.connectNewDataSetAsInput(
                source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);

        if (withForwardEdge) {
            IntermediateDataSet source1Output = source1.getProducedDataSets().get(0);
            JobEdge source1ToSink = source1Output.getConsumers().get(0);
            source1ToSink.setForward(true);
        }
        return new JobGraph(new JobID(), "test job", new JobVertex[] {source1, source2, sink});
    }

    /**
     * Creates an adaptive batch scheduler for {@code jobGraph}.
     *
     * <p>The scheduler uses a stub decider that always answers {@link #DECIDED_PARALLELISM}, and
     * the default value of {@code ADAPTIVE_BATCH_SCHEDULER_MAX_PARALLELISM} as its default max
     * parallelism.
     */
    public SchedulerBase createScheduler(JobGraph jobGraph) throws Exception {
        return createScheduler(
                jobGraph,
                (ignoredA, ignoredB, ignoredC) -> DECIDED_PARALLELISM,
                JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MAX_PARALLELISM.defaultValue());
    }

    /**
     * Builds an adaptive batch scheduler for {@code jobGraph}.
     *
     * <p>The scheduler uses the given parallelism decider and default max parallelism, runs on
     * the main-thread executor and uses the shared test executor.
     */
    private SchedulerBase createScheduler(
            JobGraph jobGraph,
            VertexParallelismDecider vertexParallelismDecider,
            int defaultMaxParallelism)
            throws Exception {
        Configuration configuration = new Configuration();
        // Typed set: passing the enum as Object would not satisfy Configuration#set's <T>.
        configuration.set(JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.AdaptiveBatch);
        return new DefaultSchedulerBuilder(
                        jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
                .setVertexParallelismDecider(vertexParallelismDecider)
                .setDefaultMaxParallelism(defaultMaxParallelism)
                .buildAdaptiveBatchJobScheduler();
    }
}

