package org.apache.flink.optimizer.dataexchange;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.ExecutionMode;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.UnionOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.plan.SourcePlanNode;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.io.network.DataExchangeMode;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Checks the data exchange modes chosen for a program whose dataflow branches and is closed again:
 * two sources are unioned, and the resulting union is joined with itself.
 *
 * <p>For each {@link ExecutionMode} the test verifies:
 *
 * <ul>
 *   <li>the {@link DataExchangeMode} of the channels leaving the sources,
 *   <li>the {@link DataExchangeMode} of the channels entering the join,
 *   <li>the {@link ResultPartitionType} of the intermediate results produced by the two source
 *       vertices of the generated job graph.
 * </ul>
 */
@RunWith(Parameterized.class)
public class UnionClosedBranchingTest extends CompilerTestBase {

    /** Execution mode configured on the environment for this parameterization. */
    private final ExecutionMode executionMode;

    /** Expected exchange mode of every outgoing channel of the data sources. */
    private final DataExchangeMode sourceToUnion;

    /** Expected exchange mode of every input channel of the join (the sink's predecessor). */
    private final DataExchangeMode unionToJoin;

    /**
     * Returns one triple {@code {executionMode, sourceToUnion, unionToJoin}} per {@link
     * ExecutionMode}.
     *
     * <p>Note that under {@link ExecutionMode#PIPELINED} the union-to-join exchange is expected to
     * be {@link DataExchangeMode#BATCH}; only {@link ExecutionMode#PIPELINED_FORCED} expects it to
     * stay pipelined.
     *
     * @return the parameterizations, exactly one per execution mode
     */
    @Parameterized.Parameters
    public static Collection<Object[]> params() {
        Collection<Object[]> params =
                Arrays.asList(
                        new Object[][] {
                            {
                                ExecutionMode.PIPELINED,
                                DataExchangeMode.PIPELINED,
                                DataExchangeMode.BATCH
                            },
                            {
                                ExecutionMode.PIPELINED_FORCED,
                                DataExchangeMode.PIPELINED,
                                DataExchangeMode.PIPELINED
                            },
                            {ExecutionMode.BATCH, DataExchangeMode.BATCH, DataExchangeMode.BATCH},
                            {
                                ExecutionMode.BATCH_FORCED,
                                DataExchangeMode.BATCH,
                                DataExchangeMode.BATCH
                            },
                        });

        // Fails if an execution mode is added without a matching parameterization above.
        Assert.assertEquals(ExecutionMode.values().length, params.size());
        return params;
    }

    /**
     * Creates one parameterized instance.
     *
     * @param executionMode execution mode to configure on the environment
     * @param sourceToUnion expected exchange mode between the sources and the union
     * @param unionToJoin expected exchange mode between the union and the join
     */
    public UnionClosedBranchingTest(
            ExecutionMode executionMode,
            DataExchangeMode sourceToUnion,
            DataExchangeMode unionToJoin) {
        this.executionMode = executionMode;
        this.sourceToUnion = sourceToUnion;
        this.unionToJoin = unionToJoin;
    }

    @Test
    public void testUnionClosedBranchingTest() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.getConfig().setExecutionMode(executionMode);
        // NOTE(review): a parallelism > 1 looks deliberate here; confirm before lowering it.
        env.setParallelism(4);

        UnionOperator<Tuple1<Integer>> union =
                env.fromElements(new Tuple1<>(0), new Tuple1<>(1))
                        .union(env.fromElements(new Tuple1<>(0), new Tuple1<>(1)));

        // Joining the union with itself closes the branch opened by the union's two consumers.
        union.join(union)
                .where(0)
                .equalTo(0)
                .projectFirst(0)
                .projectSecond(0)
                .output(new DiscardingOutputFormat<>());

        // ------------------------------------------------------------------------------------
        // Verify the optimized plan
        // ------------------------------------------------------------------------------------

        OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan());

        // The sink's predecessor is the join; its input channels come from the union.
        SinkPlanNode sinkNode = optimizedPlan.getDataSinks().iterator().next();
        for (Channel channel : sinkNode.getPredecessor().getInputs()) {
            Assert.assertEquals(
                    "Unexpected data exchange mode between union and join node.",
                    unionToJoin,
                    channel.getDataExchangeMode());
        }

        for (SourcePlanNode source : optimizedPlan.getDataSources()) {
            for (Channel channel : source.getOutgoingChannels()) {
                Assert.assertEquals(
                        "Unexpected data exchange mode between source and union node.",
                        sourceToUnion,
                        channel.getDataExchangeMode());
            }
        }

        // ------------------------------------------------------------------------------------
        // Verify the generated job graph
        // ------------------------------------------------------------------------------------

        List<JobVertex> vertices =
                new JobGraphGenerator()
                        .compileJobGraph(optimizedPlan)
                        .getVerticesSortedTopologicallyFromSources();

        // Sanity check of the test setup.
        Assert.assertEquals("Unexpected number of vertices created.", 4, vertices.size());

        // In topological order, the first two vertices are expected to be the two sources.
        for (JobVertex source : vertices.subList(0, 2)) {
            Assert.assertTrue(
                    "Unexpected vertex type. Test setup is broken.", source.isInputVertex());
            Assert.assertEquals(
                    "Unexpected number of created results.",
                    2,
                    source.getNumberOfProducedIntermediateDataSets());

            for (IntermediateDataSet dataSet : source.getProducedDataSets()) {
                ResultPartitionType resultType = dataSet.getResultType();

                // The expected result type is keyed on the union-to-join exchange mode,
                // not on the source-to-union one.
                if (unionToJoin == DataExchangeMode.BATCH) {
                    Assert.assertTrue(
                            "Expected batch exchange, but result type is " + resultType + ".",
                            resultType.isBlocking());
                } else {
                    Assert.assertFalse(
                            "Expected non-batch exchange, but result type is " + resultType + ".",
                            resultType.isBlocking());
                }
            }
        }
    }
}
