package org.apache.flink.test.compiler.iterations;

import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.compiler.plan.DualInputPlanNode;
import org.apache.flink.compiler.plan.OptimizedPlan;
import org.apache.flink.compiler.plan.SolutionSetPlanNode;
import org.apache.flink.compiler.plantranslate.NepheleJobGraphGenerator;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.test.compiler.util.CompilerTestBase;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/test/compiler/iterations/MultipleJoinsWithSolutionSetCompilerTest.class */
/**
 * Optimizer test for a delta iteration whose step function joins with the solution set twice.
 *
 * <p>The test checks the driver and ship strategies chosen for both joins, checks that the
 * solution set is a direct input of each join, and finally runs the job graph generator on the
 * optimized plan.
 */
public class MultipleJoinsWithSolutionSetCompilerTest extends CompilerTestBase {
    /** Operator name of the first solution set join; used to look the node up in the optimized plan. */
    private static final String JOIN_1 = "join1";
    /** Operator name of the second solution set join; used to look the node up in the optimized plan. */
    private static final String JOIN_2 = "join2";

    /** Emits every incoming record twice. */
    public static final class Duplicator extends RichFlatMapFunction<Tuple2<Long, Double>, Tuple2<Long, Double>> {
        @Override
        public void flatMap(Tuple2<Long, Double> value, Collector<Tuple2<Long, Double>> out) {
            out.collect(value);
            out.collect(value);
        }
    }

    /** Turns a pair into a triple by appending twice the value of the pair's second field. */
    public static final class Expander extends RichMapFunction<Tuple2<Long, Double>, Tuple3<Long, Double, Double>> {
        @Override
        public Tuple3<Long, Double, Double> map(Tuple2<Long, Double> value) {
            return new Tuple3<>(value.f0, value.f1, value.f1 * 2.0);
        }
    }

    /** Keeps the first input's key and sums the second fields of both inputs. */
    public static final class SummingJoin extends RichJoinFunction<Tuple2<Long, Double>, Tuple2<Long, Double>, Tuple2<Long, Double>> {
        @Override
        public Tuple2<Long, Double> join(Tuple2<Long, Double> first, Tuple2<Long, Double> second) {
            return new Tuple2<>(first.f0, first.f1 + second.f1);
        }
    }

    /**
     * Keeps the first input's key and sums both value fields of the triple with the value field
     * of the pair.
     */
    public static final class SummingJoinProject extends RichJoinFunction<Tuple3<Long, Double, Double>, Tuple2<Long, Double>, Tuple2<Long, Double>> {
        @Override
        public Tuple2<Long, Double> join(Tuple3<Long, Double, Double> first, Tuple2<Long, Double> second) {
            return new Tuple2<>(first.f0, first.f1 + first.f2 + second.f1);
        }
    }

    /**
     * Compiles the plan built by {@link #constructPlan(DataSet, int)} and verifies the strategies
     * selected for the two joins with the solution set.
     *
     * @throws Exception if building, optimizing or translating the plan fails; the exception is
     *         propagated so that the test runner reports it with its complete stack trace
     */
    @Test
    public void testMultiSolutionSetJoinPlan() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // A generic varargs call creates a generic array, hence the narrowly scoped suppression.
        @SuppressWarnings("unchecked")
        DataSet<Tuple2<Long, Double>> inputData = env.fromElements(new Tuple2<Long, Double>(1L, 1.0));
        DataSet<Tuple2<Long, Double>> result = constructPlan(inputData, 10);

        // print() is deliberately called twice on the same result.
        // NOTE(review): presumably this adds two sinks so the plan branches after the iteration - confirm.
        result.print();
        result.print();

        OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan());
        CompilerTestBase.OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan);

        DualInputPlanNode join1 = resolver.getNode(JOIN_1);
        DualInputPlanNode join2 = resolver.getNode(JOIN_2);

        // JOIN_1 has the solution set as its first input, JOIN_2 as its second input,
        // so the expected strategies of the two joins mirror each other.
        Assert.assertEquals(DriverStrategy.HYBRIDHASH_BUILD_FIRST, join1.getDriverStrategy());
        Assert.assertEquals(DriverStrategy.HYBRIDHASH_BUILD_SECOND, join2.getDriverStrategy());

        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, join1.getInput2().getShipStrategy());
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, join2.getInput1().getShipStrategy());

        Assert.assertEquals(SolutionSetPlanNode.class, join1.getInput1().getSource().getClass());
        Assert.assertEquals(SolutionSetPlanNode.class, join2.getInput2().getSource().getClass());

        // The optimized plan must also be translatable into a job graph without errors.
        new NepheleJobGraphGenerator().compileJobGraph(optimizedPlan);
    }

    /**
     * Builds a delta iteration whose step function joins with the solution set twice.
     *
     * <p>The given data set serves as both the initial solution set and the initial workset.
     * Inside the iteration the solution set is first joined ({@code "join1"}) with the duplicated
     * workset, the join result is grouped on field 0 and reduced with {@code MIN} on field 1,
     * expanded to a triple, and then joined ({@code "join2"}) with the solution set again. The
     * result of the second join is passed as the first argument of {@code closeWith}; its
     * per-key {@code SUM} on field 1 is passed as the second argument.
     *
     * @param initialData data set passed to {@code iterateDelta} both as receiver and as workset
     * @param maxIterations iteration count handed to {@code iterateDelta}
     * @return the data set returned by closing the delta iteration
     */
    public static DataSet<Tuple2<Long, Double>> constructPlan(DataSet<Tuple2<Long, Double>> initialData, int maxIterations) {
        DeltaIteration<Tuple2<Long, Double>, Tuple2<Long, Double>> iteration =
                initialData.iterateDelta(initialData, maxIterations, 0);

        DataSet<Tuple2<Long, Double>> delta = iteration.getSolutionSet()
                .join(iteration.getWorkset().flatMap(new Duplicator()))
                .where(0).equalTo(0).with(new SummingJoin()).name(JOIN_1)
                .groupBy(0).aggregate(Aggregations.MIN, 1)
                .map(new Expander())
                .join(iteration.getSolutionSet())
                .where(0).equalTo(0).with(new SummingJoinProject()).name(JOIN_2);

        DataSet<Tuple2<Long, Double>> changes = delta.groupBy(0).aggregate(Aggregations.SUM, 1);

        return iteration.closeWith(delta, changes);
    }
}
