package org.apache.flink.test.iterative.aggregators;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Integration tests for iteration aggregators.
 *
 * <p>Each test runs a connected-components computation over {@link #verticesInput} and
 * {@link #edgesInput}, where every vertex starts with its own id as component id. The tests
 * cover an aggregator-driven convergence criterion in a bulk iteration and in a delta
 * iteration, plus an aggregator that carries its own parameter.
 */
@RunWith(Parameterized.class)
public class AggregatorConvergenceITCase extends MultipleProgramsTestBase {

    /** Name of the aggregator that counts vertices whose component id changed in a superstep. */
    private static final String UPDATED_ELEMENTS_AGGR = "updated.elements.aggr";

    /** Name of the aggregator that counts vertices belonging to a given component. */
    private static final String ELEMENTS_IN_COMPONENT_AGGR = "elements.in.component.aggregator";

    /** The iteration counts as converged once fewer than this many vertices changed. */
    private static final long CONVERGENCE_THRESHOLD = 3L;

    // Kept as per-instance fields (JUnit creates a fresh instance per test) rather than static:
    // NeighborWithComponentIDJoin overwrites f0 of its first input in place.
    // NOTE(review): whether the join ever sees these exact tuple objects depends on the
    // execution mode; per-instance lists are the defensive choice.
    final List<Tuple2<Long, Long>> verticesInput = Arrays.asList(
            tuple(1, 1), tuple(2, 2), tuple(3, 3), tuple(4, 4), tuple(5, 5),
            tuple(6, 6), tuple(7, 7), tuple(8, 8), tuple(9, 9));

    final List<Tuple2<Long, Long>> edgesInput = Arrays.asList(
            tuple(1, 2), tuple(1, 3), tuple(2, 3), tuple(2, 4), tuple(2, 1), tuple(3, 1),
            tuple(3, 2), tuple(4, 2), tuple(4, 6), tuple(5, 6), tuple(6, 4), tuple(6, 5),
            tuple(7, 8), tuple(7, 9), tuple(8, 7), tuple(8, 9), tuple(9, 7), tuple(9, 8));

    // Expected result for the convergence tests. Vertex 5 still carries component id 2 because
    // the iteration stops early, as soon as fewer than CONVERGENCE_THRESHOLD vertices change.
    final List<Tuple2<Long, Long>> expectedResult = Arrays.asList(
            tuple(1, 1), tuple(2, 1), tuple(3, 1), tuple(4, 1), tuple(5, 2),
            tuple(6, 1), tuple(7, 7), tuple(8, 7), tuple(9, 7));

    public AggregatorConvergenceITCase(MultipleProgramsTestBase.TestExecutionMode mode) {
        super(mode);
    }

    /**
     * Bulk-iteration connected components.
     *
     * <p>The iteration is terminated by a convergence criterion evaluated on a
     * {@link LongSumAggregator}.
     */
    @Test
    public void testConnectedComponentsWithParametrizableConvergence() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSource<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
        DataSource<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);

        IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(10);
        iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS_AGGR,
                new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(CONVERGENCE_THRESHOLD));

        FlatMapOperator<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> updatedComponentIds =
                iteration.join(edges).where(0).equalTo(0)
                        .with(new NeighborWithComponentIDJoin())
                        .groupBy(0).min(1)
                        .join(iteration).where(0).equalTo(0)
                        .flatMap(new MinimumIdFilter(UPDATED_ELEMENTS_AGGR));

        List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentIds).collect();
        Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>());

        Assert.assertEquals(expectedResult, result);
    }

    /**
     * Delta-iteration variant of
     * {@link #testConnectedComponentsWithParametrizableConvergence()}.
     *
     * <p>The filtered updates serve as both the solution-set delta and the next workset.
     */
    @Test
    public void testDeltaConnectedComponentsWithParametrizableConvergence() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSource<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
        DataSource<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);

        DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
                initialSolutionSet.iterateDelta(initialSolutionSet, 10, 0);
        iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS_AGGR,
                new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(CONVERGENCE_THRESHOLD));

        FlatMapOperator<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> updatedComponentIds =
                iteration.getWorkset().join(edges).where(0).equalTo(0)
                        .with(new NeighborWithComponentIDJoin())
                        .groupBy(0).min(1)
                        .join(iteration.getSolutionSet()).where(0).equalTo(0)
                        .flatMap(new MinimumIdFilter(UPDATED_ELEMENTS_AGGR));

        List<Tuple2<Long, Long>> result =
                iteration.closeWith(updatedComponentIds, updatedComponentIds).collect();
        Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>());

        Assert.assertEquals(expectedResult, result);
    }

    /**
     * Runs five bulk supersteps with an aggregator that counts vertices in component 1.
     *
     * <p>It then checks the aggregate seen by subtask 0 at the start of supersteps 2 through 5.
     */
    @Test
    public void testParameterizableAggregator() throws Exception {
        // aggr_value is static and survives across parameterized runs. Clear it so this run's
        // assertions cannot be satisfied by values a previous execution mode left behind.
        Arrays.fill(MinimumIdFilterCounting.aggr_value, 0L);

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSource<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
        DataSource<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);

        IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(5);
        iteration.registerAggregator(ELEMENTS_IN_COMPONENT_AGGR, new LongSumAggregatorWithParameter(1L));

        FlatMapOperator<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> updatedComponentIds =
                iteration.join(edges).where(0).equalTo(0)
                        .with(new NeighborWithComponentIDJoin())
                        .groupBy(0).min(1)
                        .join(iteration).where(0).equalTo(0)
                        .flatMap(new MinimumIdFilterCounting(ELEMENTS_IN_COMPONENT_AGGR));

        List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentIds).collect();
        Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>());

        List<Tuple2<Long, Long>> expected = Arrays.asList(
                tuple(1, 1), tuple(2, 1), tuple(3, 1), tuple(4, 1), tuple(5, 1),
                tuple(6, 1), tuple(7, 7), tuple(8, 7), tuple(9, 7));
        Assert.assertEquals(expected, result);

        // Index i holds the aggregate from superstep i + 1, as read in superstep i + 2.
        long[] previousAggregates = MinimumIdFilterCounting.aggr_value;
        Assert.assertEquals(3L, previousAggregates[0]);
        Assert.assertEquals(4L, previousAggregates[1]);
        Assert.assertEquals(5L, previousAggregates[2]);
        Assert.assertEquals(6L, previousAggregates[3]);
    }

    /** Creates a typed {@code (f0, f1)} tuple, avoiding raw {@code new Tuple2(...)} calls. */
    private static Tuple2<Long, Long> tuple(long f0, long f1) {
        return new Tuple2<Long, Long>(f0, f1);
    }

    // ------------------------------------------------------------------------------------------
    //  Test functions
    // ------------------------------------------------------------------------------------------

    /**
     * A {@link LongSumAggregator} that also carries the component id it should count.
     *
     * <p>This exercises aggregators that hold their own configuration.
     */
    public static final class LongSumAggregatorWithParameter extends LongSumAggregator {
        private static final long serialVersionUID = 1L;

        private final long componentId;

        public LongSumAggregatorWithParameter(long componentId) {
            this.componentId = componentId;
        }

        /** Returns the component id whose members this aggregator should count. */
        public long getComponentId() {
            return this.componentId;
        }
    }

    /**
     * Emits, for each vertex, the smaller of the candidate and current component ids.
     *
     * <p>Each input pairs the minimum neighbor-offered component id (f0) with the vertex's
     * current value (f1). Every time the candidate wins, i.e. the vertex changes, the
     * aggregator is incremented by one.
     */
    private static class MinimumIdFilter
            extends RichFlatMapFunction<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1L;

        private final String aggName;
        private LongSumAggregator aggr;

        public MinimumIdFilter(String aggName) {
            this.aggName = aggName;
        }

        @Override
        public void open(Configuration configuration) {
            this.aggr = getIterationRuntimeContext().getIterationAggregator(this.aggName);
        }

        @Override
        public void flatMap(Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
                            Collector<Tuple2<Long, Long>> out) {
            Tuple2<Long, Long> candidate = vertexWithNewAndOldId.f0;
            Tuple2<Long, Long> current = vertexWithNewAndOldId.f1;
            if (candidate.f1 < current.f1) {
                out.collect(candidate);
                // Count the change; the convergence criterion reads the per-superstep sum.
                this.aggr.aggregate(1L);
            } else {
                out.collect(current);
            }
        }
    }

    /**
     * Keeps the smaller component id, like {@link MinimumIdFilter}.
     *
     * <p>Instead of counting changes, it counts emitted vertices whose component id equals the
     * aggregator's {@link LongSumAggregatorWithParameter#getComponentId()}.
     *
     * <p>On subtask 0, from superstep 2 onward, it records the previous superstep's aggregate
     * into the static {@link #aggr_value}, so the test can inspect it after the job finishes.
     */
    private static final class MinimumIdFilterCounting
            extends RichFlatMapFunction<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1L;

        // Static so the test thread can read values written by the operator.
        // Index = superstep - 2. Sized for the test's 5 supersteps.
        private static final long[] aggr_value = new long[5];

        private final String aggName;
        private LongSumAggregatorWithParameter aggr;

        public MinimumIdFilterCounting(String aggName) {
            this.aggName = aggName;
        }

        @Override
        public void open(Configuration configuration) {
            int superstep = getIterationRuntimeContext().getSuperstepNumber();
            this.aggr = getIterationRuntimeContext().getIterationAggregator(this.aggName);

            // Superstep 1 has no previous aggregate. Only one subtask records it, to avoid races.
            if (superstep > 1 && getIterationRuntimeContext().getIndexOfThisSubtask() == 0) {
                // Typed local needed: the generic return would otherwise be inferred as Value.
                LongValue previous = getIterationRuntimeContext().getPreviousIterationAggregate(this.aggName);
                aggr_value[superstep - 2] = previous.getValue();
            }
        }

        @Override
        public void flatMap(Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
                            Collector<Tuple2<Long, Long>> out) {
            Tuple2<Long, Long> candidate = vertexWithNewAndOldId.f0;
            Tuple2<Long, Long> current = vertexWithNewAndOldId.f1;
            Tuple2<Long, Long> winner = candidate.f1 < current.f1 ? candidate : current;

            out.collect(winner);
            if (winner.f1 == this.aggr.getComponentId()) {
                this.aggr.aggregate(1L);
            }
        }
    }

    /**
     * Joins a vertex with one of its outgoing edges.
     *
     * <p>It re-keys the vertex's component id to the edge target (edge.f1), producing a tuple
     * of (neighbor id, offered component id).
     *
     * <p>Note: this mutates and returns the first input tuple rather than allocating a new one.
     */
    private static final class NeighborWithComponentIDJoin
            extends RichJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Long, Long> join(Tuple2<Long, Long> vertexWithCompId, Tuple2<Long, Long> edge) {
            vertexWithCompId.f0 = edge.f1;
            return vertexWithCompId;
        }
    }

    /** Reports convergence once the aggregated number of changed elements drops below a threshold. */
    private static class UpdatedElementsConvergenceCriterion implements ConvergenceCriterion<LongValue> {
        private static final long serialVersionUID = 1L;

        private final long threshold;

        public UpdatedElementsConvergenceCriterion(long threshold) {
            this.threshold = threshold;
        }

        @Override
        public boolean isConverged(int iteration, LongValue value) {
            return value.getValue() < this.threshold;
        }
    }
}
