package org.apache.flink.optimizer;

import java.util.Iterator;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.base.JoinOperatorBase;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.FilterOperator;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.optimizer.plan.BulkIterationPlanNode;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
import org.apache.flink.optimizer.plandump.PlanJSONDumpGenerator;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.optimizer.testfunctions.IdentityKeyExtractor;
import org.apache.flink.optimizer.testfunctions.IdentityMapper;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/**
 * Optimizer tests for programs that contain bulk iterations and delta (workset) iterations.
 *
 * <p>Each test builds a small program, compiles it with {@code compileNoStats}, and then checks
 * the shape of the optimized plan (ship strategies, pipeline breakers) and/or that the plan can be
 * translated into a job graph. The user functions below are placeholders: only the plan structure
 * matters, so most of them return {@code null} or emit nothing.
 */
public class IterationsCompilerTest extends CompilerTestBase {

    /** Identity map on {@code Tuple2<Long, Long>}. */
    public static final class DummyMap
            extends RichMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {
        @Override
        public Tuple2<Long, Long> map(Tuple2<Long, Long> value) throws Exception {
            return value;
        }
    }

    /**
     * Turns {@code (x)} into {@code (x, x)}. The annotation declares that field 0 of the input is
     * forwarded unchanged to field 0 of the output, which the optimizer may use when reasoning
     * about partitioning.
     */
    @FunctionAnnotation.ForwardedFields("0")
    public static final class DuplicateValue
            extends RichMapFunction<Tuple1<Long>, Tuple2<Long, Long>> {
        @Override
        public Tuple2<Long, Long> map(Tuple1<Long> value) throws Exception {
            return new Tuple2<>(value.f0, value.f0);
        }
    }

    /** Turns a scalar {@code x} into the pair {@code (x, x)}. */
    public static final class DuplicateValueScalar<T> extends RichMapFunction<T, Tuple2<T, T>> {
        @Override
        public Tuple2<T, T> map(T value) {
            return new Tuple2<>(value, value);
        }
    }

    /** Flat-map over join pairs that intentionally emits nothing; only its place in the plan matters. */
    public static final class FlatMapJoin
            extends RichFlatMapFunction<
                    Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
        @Override
        public void flatMap(
                Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> value,
                Collector<Tuple2<Long, Long>> out) {
            // intentionally empty
        }
    }

    /** Filter that accepts every element. */
    public static final class IdFilter<T> implements FilterFunction<T> {
        @Override
        public boolean filter(T value) {
            return true;
        }
    }

    /** Placeholder join function; always returns {@code null}. */
    public static final class Join222
            extends RichJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        @Override
        public Tuple2<Long, Long> join(Tuple2<Long, Long> first, Tuple2<Long, Long> second) {
            return null;
        }
    }

    /**
     * Placeholder group-reduce that emits nothing, but declares field 0 as forwarded so the
     * optimizer can keep partitioning information across it.
     */
    @FunctionAnnotation.ForwardedFields("0")
    public static final class Reduce101 extends RichGroupReduceFunction<Tuple1<Long>, Tuple1<Long>> {
        @Override
        public void reduce(Iterable<Tuple1<Long>> values, Collector<Tuple1<Long>> out) {
            // intentionally empty
        }
    }

    /**
     * A delta iteration whose step function reads the workset as a broadcast variable must still
     * compile, dump to JSON, and translate into a job graph.
     */
    @Test
    public void testSolutionSetDeltaDependsOnBroadcastVariable() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple2<Long, Long>> source =
                env.generateSequence(1L, 1000L).map(new DuplicateValueScalar<Long>());
        DataSet<Tuple2<Long, Long>> invariantInput =
                env.generateSequence(1L, 1000L).map(new DuplicateValueScalar<Long>());

        DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
                source.iterateDelta(source, 1000, 1);

        // The loop-invariant input is mapped with the workset broadcast to it, then joined with
        // the solution set.
        DataSet<Tuple2<Long, Long>> result =
                invariantInput
                        .map(new IdentityMapper<Tuple2<Long, Long>>())
                        .withBroadcastSet(iteration.getWorkset(), "bc data")
                        .join(iteration.getSolutionSet())
                        .where(0)
                        .equalTo(1)
                        .projectFirst(1)
                        .projectSecond(1);

        // closeWith(solutionSetDelta, newWorkset)
        iteration
                .closeWith(result.map(new IdentityMapper<Tuple2<Long, Long>>()), result)
                .print();

        OptimizedPlan op = compileNoStats(env.createProgramPlan());
        new PlanJSONDumpGenerator().getOptimizerPlanAsJSON(op);
        new JobGraphGenerator().compileJobGraph(op);
    }

    /** A bulk iteration, followed by a map, feeding a delta iteration. */
    @Test
    public void testTwoIterationsWithMapperInbetween() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(8);

        DataSet<Tuple2<Long, Long>> vertices = env.fromElements(new Tuple2<Long, Long>(1L, 2L));
        DataSet<Tuple2<Long, Long>> edges = env.fromElements(new Tuple2<Long, Long>(1L, 2L));

        DataSet<Tuple2<Long, Long>> mappedBulkResult =
                doBulkIteration(vertices, edges).map(new DummyMap());
        doDeltaIteration(mappedBulkResult, edges).print();

        assertSinkFedByDeltaIteration(env, ShipStrategyType.PARTITION_HASH);
    }

    /** A bulk iteration feeding a delta iteration directly. */
    @Test
    public void testTwoIterationsDirectlyChained() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(8);

        DataSet<Tuple2<Long, Long>> vertices = env.fromElements(new Tuple2<Long, Long>(1L, 2L));
        DataSet<Tuple2<Long, Long>> edges = env.fromElements(new Tuple2<Long, Long>(1L, 2L));

        doDeltaIteration(doBulkIteration(vertices, edges), edges).print();

        assertSinkFedByDeltaIteration(env, ShipStrategyType.PARTITION_HASH);
    }

    /**
     * A delta iteration feeding another delta iteration directly. Unlike the bulk-iteration
     * variants, FORWARD is expected on the second iteration's first input (presumably because the
     * first delta iteration's result is already partitioned on field 0).
     */
    @Test
    public void testTwoWorksetIterationsDirectlyChained() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(8);

        DataSet<Tuple2<Long, Long>> vertices = env.fromElements(new Tuple2<Long, Long>(1L, 2L));
        DataSet<Tuple2<Long, Long>> edges = env.fromElements(new Tuple2<Long, Long>(1L, 2L));

        doDeltaIteration(doDeltaIteration(vertices, edges), edges).print();

        assertSinkFedByDeltaIteration(env, ShipStrategyType.FORWARD);
    }

    /** Every channel leaving the bulk iteration's partial solution must be hash-partitioned. */
    @Test
    public void testIterationPushingWorkOut() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(8);

        DataSet<Tuple2<Long, Long>> vertices =
                env.readCsvFile("/some/file/path").types(Long.class).map(new DuplicateValue());
        DataSet<Tuple2<Long, Long>> edges =
                env.readCsvFile("/some/file/path").types(Long.class, Long.class);

        doBulkIteration(vertices, edges).print();

        OptimizedPlan op = compileNoStats(env.createProgramPlan());
        SinkPlanNode sink = getSingleSink(op);
        Assert.assertTrue(sink.getInput().getSource() instanceof BulkIterationPlanNode);
        BulkIterationPlanNode bulkIteration = (BulkIterationPlanNode) sink.getInput().getSource();

        for (Channel channel : bulkIteration.getPartialSolutionPlanNode().getOutgoingChannels()) {
            Assert.assertEquals(ShipStrategyType.PARTITION_HASH, channel.getShipStrategy());
        }

        new JobGraphGenerator().compileJobGraph(op);
    }

    /**
     * The initial workset of a delta iteration is also consumed by a join outside the iteration.
     * Only checks that the optimizer can compile such a plan.
     */
    @Test
    public void testWorksetIterationPipelineBreakerPlacement() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(8);

        DataSet<Tuple2<Long, Long>> initialWorkset =
                env.readCsvFile("/some/file/path").types(Long.class).map(new DuplicateValue());
        DataSet<Tuple2<Long, Long>> initialSolutionSet =
                env.readCsvFile("/some/file/path").types(Long.class).map(new DuplicateValue());

        DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
                initialSolutionSet.iterateDelta(initialWorkset, 100, 0);

        DataSet<Tuple2<Long, Long>> next =
                iteration.getWorkset().map(new IdentityMapper<Tuple2<Long, Long>>());
        DataSet<Tuple2<Long, Long>> result = iteration.closeWith(next, next);

        initialWorkset
                .join(result, JoinOperatorBase.JoinHint.REPARTITION_HASH_FIRST)
                .where(0)
                .equalTo(0)
                .print();

        compileNoStats(env.createProgramPlan());
    }

    /**
     * A bulk iteration whose partial solution is consumed by three filters. One derived data set
     * is used as a broadcast set, and all branches are unioned back into the next partial solution.
     * The plan must compile and translate into a job graph.
     */
    @Test
    public void testResetPartialSolution() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> width = env.generateSequence(1L, 10L);
        DataSet<Long> update = env.generateSequence(1L, 10L);
        DataSet<Long> lastGradient = env.generateSequence(1L, 10L);

        IterativeDataSet<Long> iteration = width.union(update).union(lastGradient).iterate(10);

        DataSet<Long> widthInLoop = iteration.filter(new IdFilter<Long>());
        DataSet<Long> updateInLoop = iteration.filter(new IdFilter<Long>());
        DataSet<Long> lastGradientInLoop = iteration.filter(new IdFilter<Long>());

        DataSet<Long> gradient =
                widthInLoop
                        .map(new IdentityMapper<Long>())
                        .join(lastGradientInLoop)
                        .where(new IdentityKeyExtractor<Long>())
                        .equalTo(new IdentityKeyExtractor<Long>())
                        .with(
                                new JoinFunction<Long, Long, Long>() {
                                    @Override
                                    public Long join(Long first, Long second) {
                                        return null;
                                    }
                                });

        DataSet<Long> updated =
                updateInLoop
                        .map(
                                new RichMapFunction<Long, Long>() {
                                    @Override
                                    public Long map(Long value) {
                                        return null;
                                    }
                                })
                        .withBroadcastSet(gradient, "some-name");

        DataSet<Long> model = widthInLoop.union(updated).union(lastGradientInLoop);
        iteration.closeWith(model).print();

        new JobGraphGenerator().compileJobGraph(compileNoStats(env.createProgramPlan()));
    }

    /**
     * Adds a 20-superstep bulk iteration over {@code vertices}. Each step joins the partial
     * solution with {@code edges} on field 0, takes the per-key minimum of field 1, joins that back
     * with the partial solution, and flat-maps the pairs into the next partial solution.
     *
     * @param vertices initial partial solution
     * @param edges loop-invariant input joined inside the step function
     * @return the result of the bulk iteration
     */
    public static DataSet<Tuple2<Long, Long>> doBulkIteration(
            DataSet<Tuple2<Long, Long>> vertices, DataSet<Tuple2<Long, Long>> edges) {
        IterativeDataSet<Tuple2<Long, Long>> iteration = vertices.iterate(20);

        DataSet<Tuple2<Long, Long>> changes =
                iteration
                        .join(edges)
                        .where(0)
                        .equalTo(0)
                        .with(new Join222())
                        .groupBy(0)
                        .aggregate(Aggregations.MIN, 1)
                        .join(iteration)
                        .where(0)
                        .equalTo(0)
                        .flatMap(new FlatMapJoin());

        return iteration.closeWith(changes);
    }

    /**
     * Adds a delta iteration (at most 100 supersteps, key field 0) that uses {@code vertices} both
     * as the initial solution set and as the initial workset. The step's final result is used both
     * as the solution-set delta and as the next workset.
     *
     * @param vertices initial solution set and initial workset
     * @param edges loop-invariant input joined inside the step function
     * @return the result of the delta iteration
     */
    public static DataSet<Tuple2<Long, Long>> doDeltaIteration(
            DataSet<Tuple2<Long, Long>> vertices, DataSet<Tuple2<Long, Long>> edges) {
        DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
                vertices.iterateDelta(vertices, 100, 0);

        DataSet<Tuple1<Long>> candidates =
                iteration.getWorkset().join(edges).where(0).equalTo(0).projectSecond(1);

        DataSet<Tuple1<Long>> grouped = candidates.groupBy(0).reduceGroup(new Reduce101());

        DataSet<Tuple2<Long, Long>> candidateDependencies =
                grouped.join(edges).where(0).equalTo(1).projectSecond(0, 1);

        DataSet<Tuple2<Long, Long>> verticesWithNewComponents =
                candidateDependencies
                        .join(iteration.getSolutionSet())
                        .where(0)
                        .equalTo(0)
                        .with(new Join222())
                        .groupBy(0)
                        .aggregate(Aggregations.MIN, 1);

        DataSet<Tuple2<Long, Long>> updatedComponentIds =
                verticesWithNewComponents
                        .join(iteration.getSolutionSet())
                        .where(0)
                        .equalTo(0)
                        .flatMap(new FlatMapJoin());

        return iteration.closeWith(updatedComponentIds, updatedComponentIds);
    }

    /**
     * Compiles the program defined on {@code env} and checks the delta iteration at its sink.
     *
     * <p>The checks are: the single sink is fed directly by a delta iteration; that iteration's
     * first input uses {@code expectedInput1Strategy}; its second input's temp mode breaks the
     * pipeline. Finally, the plan must translate into a job graph.
     */
    private void assertSinkFedByDeltaIteration(
            ExecutionEnvironment env, ShipStrategyType expectedInput1Strategy) throws Exception {
        OptimizedPlan op = compileNoStats(env.createProgramPlan());

        SinkPlanNode sink = getSingleSink(op);
        Assert.assertTrue(sink.getInput().getSource() instanceof WorksetIterationPlanNode);
        WorksetIterationPlanNode worksetIteration =
                (WorksetIterationPlanNode) sink.getInput().getSource();

        Assert.assertEquals(expectedInput1Strategy, worksetIteration.getInput1().getShipStrategy());
        Assert.assertTrue(worksetIteration.getInput2().getTempMode().breaksPipeline());

        new JobGraphGenerator().compileJobGraph(op);
    }

    /** Returns the only data sink of {@code plan}, failing the test unless there is exactly one. */
    private static SinkPlanNode getSingleSink(OptimizedPlan plan) {
        Assert.assertEquals(1, plan.getDataSinks().size());
        return plan.getDataSinks().iterator().next();
    }
}
