package org.apache.flink.optimizer.java;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.DualInputPlanNode;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/optimizer/java/WorksetIterationsJavaApiCompilerTest.class */
/**
 * Optimizer tests for delta (workset) iterations built through the Java DataSet API.
 *
 * <p>The test program joins the workset with an invariant input, joins that result with the
 * solution set, and closes the iteration with a group-reduce as the next workset and either the
 * join result itself or a mapper over it as the solution-set delta. The tests check the ship
 * strategies and keys chosen by the optimizer and that the optimized plan can be translated by the
 * {@link JobGraphGenerator}. A further test checks that a solution-set join on keys that differ
 * from the solution-set keys is rejected.
 */
public class WorksetIterationsJavaApiCompilerTest extends CompilerTestBase {
    private static final String JOIN_WITH_INVARIANT_NAME = "Test Join Invariant";
    private static final String JOIN_WITH_SOLUTION_SET = "Test Join SolutionSet";
    private static final String NEXT_WORKSET_REDUCER_NAME = "Test Reduce Workset";
    private static final String SOLUTION_DELTA_MAPPER_NAME = "Test Map Delta";

    /** Parallelism used for every test program. */
    private static final int PARALLELISM = 8;

    /**
     * The solution-set join does not declare forwarded fields, and its result passes through a
     * field-preserving mapper before becoming the solution-set delta. Exactly one side of the
     * mapper (input or output) is expected to hash-partition while the other forwards.
     */
    @Test
    public void testJavaApiWithDeferredSoltionSetUpdateWithMapper() throws Exception {
        OptimizedPlan plan = compileNoStats(getJavaTestPlan(false, true));
        CompilerTestBase.OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(plan);

        assertJoinStrategiesAndKeys(resolver);

        SingleInputPlanNode worksetReducer = resolver.getNode(NEXT_WORKSET_REDUCER_NAME);
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());

        SingleInputPlanNode deltaMapper = resolver.getNode(SOLUTION_DELTA_MAPPER_NAME);
        ShipStrategyType mapperIn = deltaMapper.getInput().getShipStrategy();
        Channel mapperOutput = deltaMapper.getOutgoingChannels().get(0);
        ShipStrategyType mapperOut = mapperOutput.getShipStrategy();
        Assert.assertTrue(
                "expected one of mapper input/output to forward and the other to hash-partition, got input="
                        + mapperIn + ", output=" + mapperOut,
                (mapperIn == ShipStrategyType.FORWARD && mapperOut == ShipStrategyType.PARTITION_HASH)
                        || (mapperOut == ShipStrategyType.FORWARD && mapperIn == ShipStrategyType.PARTITION_HASH));

        new JobGraphGenerator().compileJobGraph(plan);
    }

    /**
     * The solution-set join does not declare forwarded fields and its result is used directly as
     * the solution-set delta. The join output is expected to have two hash-partitioned outgoing
     * channels.
     */
    @Test
    public void testJavaApiWithDeferredSoltionSetUpdateWithNonPreservingJoin() throws Exception {
        OptimizedPlan plan = compileNoStats(getJavaTestPlan(false, false));
        CompilerTestBase.OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(plan);

        assertJoinStrategiesAndKeys(resolver);

        SingleInputPlanNode worksetReducer = resolver.getNode(NEXT_WORKSET_REDUCER_NAME);
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());

        // The join result is consumed both as solution-set delta and as reducer input.
        DualInputPlanNode joinWithSolutionSet = resolver.getNode(JOIN_WITH_SOLUTION_SET);
        Assert.assertEquals(2L, joinWithSolutionSet.getOutgoingChannels().size());
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH,
                joinWithSolutionSet.getOutgoingChannels().get(0).getShipStrategy());
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH,
                joinWithSolutionSet.getOutgoingChannels().get(1).getShipStrategy());

        new JobGraphGenerator().compileJobGraph(plan);
    }

    /**
     * The solution-set join declares fields 0, 1 and 2 of the solution-set input as forwarded and
     * its result is used directly as the solution-set delta. The join output is expected to have a
     * single forwarded outgoing channel, and the workset reducer input is expected to forward.
     */
    @Test
    public void testJavaApiWithDirectSoltionSetUpdate() throws Exception {
        OptimizedPlan plan = compileNoStats(getJavaTestPlan(true, false));
        CompilerTestBase.OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(plan);

        assertJoinStrategiesAndKeys(resolver);

        SingleInputPlanNode worksetReducer = resolver.getNode(NEXT_WORKSET_REDUCER_NAME);
        Assert.assertEquals(ShipStrategyType.FORWARD, worksetReducer.getInput().getShipStrategy());

        DualInputPlanNode joinWithSolutionSet = resolver.getNode(JOIN_WITH_SOLUTION_SET);
        Assert.assertEquals(1L, joinWithSolutionSet.getOutgoingChannels().size());
        Assert.assertEquals(ShipStrategyType.FORWARD,
                joinWithSolutionSet.getOutgoingChannels().get(0).getShipStrategy());

        new JobGraphGenerator().compileJobGraph(plan);
    }

    /**
     * A join with the solution set (keyed on fields 1 and 2) that uses fields 0 and 2 on the
     * solution-set side must be rejected with an {@link InvalidProgramException}.
     */
    @Test
    public void testRejectPlanIfSolutionSetKeysAndJoinKeysDontMatch() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        DataSet<Tuple3<Long, Long, Long>> solutionSetInput = singleTupleSource(env, "Solution Set");
        DataSet<Tuple3<Long, Long, Long>> worksetInput = singleTupleSource(env, "Workset");
        DataSet<Tuple3<Long, Long, Long>> invariantInput = singleTupleSource(env, "Invariant Input");

        DeltaIteration<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>> iteration =
                solutionSetInput.iterateDelta(worksetInput, 100, 1, 2);

        // Built outside the try block so only the solution-set join can satisfy the expectation.
        DataSet<Tuple3<Long, Long, Long>> joinedWithInvariant = iteration.getWorkset()
                .join(invariantInput)
                .where(1, 2)
                .equalTo(1, 2)
                .with(new JoinFunction<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>>() {
                    @Override
                    public Tuple3<Long, Long, Long> join(Tuple3<Long, Long, Long> first, Tuple3<Long, Long, Long> second) {
                        return first;
                    }
                });

        try {
            joinedWithInvariant
                    .join(iteration.getSolutionSet())
                    .where(1, 0)
                    // Solution set is keyed on (1, 2); joining it on (0, 2) must be rejected.
                    .equalTo(0, 2)
                    .with(new JoinFunction<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>>() {
                        @Override
                        public Tuple3<Long, Long, Long> join(Tuple3<Long, Long, Long> first, Tuple3<Long, Long, Long> second) {
                            return second;
                        }
                    });
            Assert.fail("The join should be rejected with key type mismatches.");
        } catch (InvalidProgramException expected) {
            // Expected: the solution-set join keys do not match the solution-set keys.
        }
    }

    /**
     * Asserts the strategies and keys that are the same in every variant of the test plan.
     *
     * <ul>
     *   <li>The workset/invariant join forwards input 1, hash-partitions input 2, and uses keys
     *       (1, 2) on both inputs.</li>
     *   <li>The solution-set join hash-partitions input 1 with keys (1, 0) and forwards input 2.</li>
     *   <li>The workset reducer groups on keys (1, 2).</li>
     * </ul>
     */
    private void assertJoinStrategiesAndKeys(CompilerTestBase.OptimizerPlanNodeResolver resolver) {
        DualInputPlanNode joinWithInvariant = resolver.getNode(JOIN_WITH_INVARIANT_NAME);
        Assert.assertEquals(ShipStrategyType.FORWARD, joinWithInvariant.getInput1().getShipStrategy());
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, joinWithInvariant.getInput2().getShipStrategy());
        Assert.assertEquals(new FieldList(1, 2), joinWithInvariant.getKeysForInput1());
        Assert.assertEquals(new FieldList(1, 2), joinWithInvariant.getKeysForInput2());

        DualInputPlanNode joinWithSolutionSet = resolver.getNode(JOIN_WITH_SOLUTION_SET);
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, joinWithSolutionSet.getInput1().getShipStrategy());
        Assert.assertEquals(ShipStrategyType.FORWARD, joinWithSolutionSet.getInput2().getShipStrategy());
        Assert.assertEquals(new FieldList(1, 0), joinWithSolutionSet.getKeysForInput1());

        SingleInputPlanNode worksetReducer = resolver.getNode(NEXT_WORKSET_REDUCER_NAME);
        Assert.assertEquals(new FieldList(1, 2), worksetReducer.getKeys(0));
    }

    /** Creates a source holding the single tuple (1, 2, 3) under the given operator name. */
    @SuppressWarnings("unchecked")
    private static DataSet<Tuple3<Long, Long, Long>> singleTupleSource(ExecutionEnvironment env, String name) {
        return env.fromElements(new Tuple3<Long, Long, Long>(1L, 2L, 3L)).name(name);
    }

    /**
     * Builds the delta-iteration program used by the plan tests.
     *
     * @param joinPreservesSolutionSet if true, the solution-set join declares fields 0, 1 and 2 of
     *     its second (solution-set) input as forwarded; otherwise no forwarded fields are declared
     * @param mapBeforeSolutionDelta if true, the join result passes through a mapper that declares
     *     fields 0, 1 and 2 as forwarded before it becomes the solution-set delta; otherwise the join
     *     result is the delta directly
     * @return the program plan, ready to be optimized
     */
    private Plan getJavaTestPlan(boolean joinPreservesSolutionSet, boolean mapBeforeSolutionDelta) {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        DataSet<Tuple3<Long, Long, Long>> solutionSetInput = singleTupleSource(env, "Solution Set");
        DataSet<Tuple3<Long, Long, Long>> worksetInput = singleTupleSource(env, "Workset");
        DataSet<Tuple3<Long, Long, Long>> invariantInput = singleTupleSource(env, "Invariant Input");

        DeltaIteration<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>> iteration =
                solutionSetInput.iterateDelta(worksetInput, 100, 1, 2);

        DataSet<Tuple3<Long, Long, Long>> joinedWithSolutionSet = iteration.getWorkset()
                .join(invariantInput)
                .where(1, 2)
                .equalTo(1, 2)
                .with(new RichJoinFunction<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>>() {
                    @Override
                    public Tuple3<Long, Long, Long> join(Tuple3<Long, Long, Long> first, Tuple3<Long, Long, Long> second) {
                        return first;
                    }
                })
                .name(JOIN_WITH_INVARIANT_NAME)
                .join(iteration.getSolutionSet())
                .where(1, 0)
                .equalTo(1, 2)
                .with(new RichJoinFunction<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>>() {
                    @Override
                    public Tuple3<Long, Long, Long> join(Tuple3<Long, Long, Long> first, Tuple3<Long, Long, Long> second) {
                        return second;
                    }
                })
                .name(JOIN_WITH_SOLUTION_SET)
                .withForwardedFieldsSecond(joinPreservesSolutionSet ? new String[] {"0->0", "1->1", "2->2"} : null);

        DataSet<Tuple3<Long, Long, Long>> nextSolutionSet;
        if (mapBeforeSolutionDelta) {
            nextSolutionSet = joinedWithSolutionSet
                    .map(new RichMapFunction<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>>() {
                        @Override
                        public Tuple3<Long, Long, Long> map(Tuple3<Long, Long, Long> value) {
                            return value;
                        }
                    })
                    .name(SOLUTION_DELTA_MAPPER_NAME)
                    .withForwardedFields("0->0", "1->1", "2->2");
        } else {
            nextSolutionSet = joinedWithSolutionSet;
        }

        DataSet<Tuple3<Long, Long, Long>> nextWorkset = joinedWithSolutionSet
                .groupBy(1, 2)
                .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Long, Long>, Tuple3<Long, Long, Long>>() {
                    @Override
                    public void reduce(Iterable<Tuple3<Long, Long, Long>> values, Collector<Tuple3<Long, Long, Long>> out) {
                        // Intentionally emits nothing: only the plan shape matters for these tests.
                    }
                })
                .name(NEXT_WORKSET_REDUCER_NAME)
                .withForwardedFields("1->1", "2->2", "0->0");

        iteration.closeWith(nextSolutionSet, nextWorkset)
                .output(new DiscardingOutputFormat<Tuple3<Long, Long, Long>>());

        return env.createProgramPlan();
    }
}
