package org.apache.flink.graph.gsa;

import java.util.Map;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.library.TriangleEnumerator;
import org.apache.flink.graph.utils.GraphUtils;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * A gather-sum-apply (GSA) iteration over a graph, built on a Flink delta iteration.
 *
 * <p>The vertex data set passed to {@link #setInput(DataSet)} serves as both the initial solution
 * set and the initial workset (keyed on field 0). Each superstep does the following:
 * <ol>
 *   <li>Joins the workset vertices with the edges to form (key, {@link Neighbor}) pairs, in the
 *       configured {@link EdgeDirection} (default {@code OUT}).</li>
 *   <li>Maps each pair with the {@link GatherFunction}.</li>
 *   <li>Groups the results by key and reduces them with the {@link SumFunction}.</li>
 *   <li>Joins the sums with the solution set and hands them to the {@link ApplyFunction}.</li>
 * </ol>
 * The vertices emitted by the apply step become both the solution-set delta and the next workset.
 *
 * @param <K> vertex key type
 * @param <VV> vertex value type
 * @param <EV> edge value type
 * @param <M> type of the gathered and summed value
 */
public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperation<Vertex<K, VV>, Vertex<K, VV>> {

    /** Broadcast-variable name under which the vertex count is shared with the UDFs. */
    private static final String NUMBER_OF_VERTICES = "number of vertices";

    private DataSet<Vertex<K, VV>> vertexDataSet;
    private final DataSet<Edge<K, EV>> edgeDataSet;

    private final GatherFunction<VV, EV, M> gather;
    private final SumFunction<VV, EV, M> sum;
    private final ApplyFunction<K, VV, M> apply;
    private final int maximumNumberOfIterations;

    /** Neighbor direction; overwritten from the configuration (if any) in {@link #createResult()}. */
    private EdgeDirection direction = EdgeDirection.OUT;
    private GSAConfiguration configuration;

    private GatherSumApplyIteration(GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum,
            ApplyFunction<K, VV, M> apply, DataSet<Edge<K, EV>> edges, int maximumNumberOfIterations) {
        Preconditions.checkNotNull(gather);
        Preconditions.checkNotNull(sum);
        Preconditions.checkNotNull(apply);
        Preconditions.checkNotNull(edges);
        Preconditions.checkArgument(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one.");
        this.gather = gather;
        this.sum = sum;
        this.apply = apply;
        this.edgeDataSet = edges;
        this.maximumNumberOfIterations = maximumNumberOfIterations;
    }

    /**
     * Sets the vertex data set that is iterated on.
     *
     * @param inputData the initial vertices (solution set and workset)
     */
    @Override
    public void setInput(DataSet<Vertex<K, VV>> inputData) {
        this.vertexDataSet = inputData;
    }

    /**
     * Builds the delta iteration and returns its result data set.
     *
     * @return the vertices after the iteration terminates
     * @throws IllegalStateException if {@link #setInput(DataSet)} has not been called
     * @throws RuntimeException if the optional vertex count cannot be computed
     */
    @Override
    public DataSet<Vertex<K, VV>> createResult() {
        if (this.vertexDataSet == null) {
            throw new IllegalStateException("The input data set has not been set.");
        }

        // The vertex type info is treated as a tuple whose field 0 is the key. The gathered value
        // type is type argument index 2 (M) of the user's GatherFunction.
        TypeInformation<K> keyType = ((TupleTypeInfo<?>) this.vertexDataSet.getType()).getTypeAt(0);
        TypeInformation<M> messageType =
                TypeExtractor.createTypeInfo(this.gather, GatherFunction.class, this.gather.getClass(), 2);
        TypeInformation<Tuple2<K, M>> innerType = new TupleTypeInfo<>(keyType, messageType);
        TypeInformation<Vertex<K, VV>> outputType = this.vertexDataSet.getType();

        DataSet<LongValue> numberOfVertices = null;
        if (this.configuration != null && this.configuration.isOptNumVertices()) {
            try {
                numberOfVertices = GraphUtils.count(this.vertexDataSet);
            } catch (Exception e) {
                // Fail fast with the cause; continuing would hand a null data set to
                // withBroadcastSet(...) below and hide the real error.
                throw new RuntimeException("Could not compute the number of vertices.", e);
            }
        }

        GatherUdf<K, VV, EV, M> gatherUdf = new GatherUdf<>(this.gather, innerType);
        SumUdf<K, VV, EV, M> sumUdf = new SumUdf<>(this.sum, innerType);
        ApplyUdf<K, VV, EV, M> applyUdf = new ApplyUdf<>(this.apply, outputType);

        DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration =
                this.vertexDataSet.iterateDelta(this.vertexDataSet, this.maximumNumberOfIterations, 0);
        configureIteration(iteration);

        if (this.configuration != null) {
            this.direction = this.configuration.getDirection();
        }
        DataSet<Tuple2<K, Neighbor<VV, EV>>> neighbors = joinWithEdges(iteration.getWorkset());

        // Gather
        MapOperator<Tuple2<K, Neighbor<VV, EV>>, Tuple2<K, M>> gatheredSet = neighbors.map(gatherUdf).name("Gather");
        if (this.configuration != null) {
            for (Tuple2<String, DataSet<?>> bcVar : this.configuration.getGatherBcastVars()) {
                gatheredSet = gatheredSet.withBroadcastSet(bcVar.f1, bcVar.f0);
            }
            if (this.configuration.isOptNumVertices()) {
                gatheredSet = gatheredSet.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES);
            }
        }

        // Sum
        ReduceOperator<Tuple2<K, M>> summedSet = gatheredSet.groupBy(0).reduce(sumUdf).name("Sum");
        if (this.configuration != null) {
            for (Tuple2<String, DataSet<?>> bcVar : this.configuration.getSumBcastVars()) {
                summedSet = summedSet.withBroadcastSet(bcVar.f1, bcVar.f0);
            }
            if (this.configuration.isOptNumVertices()) {
                summedSet = summedSet.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES);
            }
        }

        // Apply
        JoinOperator<Tuple2<K, M>, Vertex<K, VV>, Vertex<K, VV>> appliedSet = summedSet
                .join(iteration.getSolutionSet())
                .where(0)
                .equalTo(0)
                .with(applyUdf)
                .name("Apply");
        if (this.configuration != null) {
            for (Tuple2<String, DataSet<?>> bcVar : this.configuration.getApplyBcastVars()) {
                appliedSet = appliedSet.withBroadcastSet(bcVar.f1, bcVar.f0);
            }
            if (this.configuration.isOptNumVertices()) {
                appliedSet = appliedSet.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES);
            }
        }

        // Declare that field 0 (the key) is forwarded from both join inputs.
        appliedSet.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0");

        return iteration.closeWith(appliedSet, appliedSet);
    }

    /** Applies name, parallelism, solution-set memory mode and aggregators from the configuration. */
    private void configureIteration(DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration) {
        String defaultName = "Gather-sum-apply iteration (" + this.gather + " | " + this.sum + " | " + this.apply + ")";
        if (this.configuration == null) {
            iteration.name(defaultName);
            return;
        }
        iteration.name(this.configuration.getName(defaultName));
        iteration.parallelism(this.configuration.getParallelism());
        iteration.setSolutionSetUnManaged(this.configuration.isSolutionSetUnmanagedMemory());
        for (Map.Entry<String, Aggregator<?>> entry : this.configuration.getAggregators().entrySet()) {
            iteration.registerAggregator(entry.getKey(), entry.getValue());
        }
    }

    /** Produces (key, neighbor) pairs for the workset according to {@link #direction}. */
    private DataSet<Tuple2<K, Neighbor<VV, EV>>> joinWithEdges(DataSet<Vertex<K, VV>> workset) {
        switch (this.direction) {
            case IN:
                return projectIn(workset);
            case ALL:
                return projectOut(workset).union(projectIn(workset));
            case OUT:
            default:
                return projectOut(workset);
        }
    }

    /** Joins workset vertex ids with edge sources (edge field 0) and keys the result by edge target. */
    private DataSet<Tuple2<K, Neighbor<VV, EV>>> projectOut(DataSet<Vertex<K, VV>> workset) {
        return workset.join(this.edgeDataSet).where(0).equalTo(0)
                .with(new ProjectKeyWithNeighborOUT<K, VV, EV>());
    }

    /** Joins workset vertex ids with edge targets (edge field 1) and keys the result by edge source. */
    private DataSet<Tuple2<K, Neighbor<VV, EV>>> projectIn(DataSet<Vertex<K, VV>> workset) {
        return workset.join(this.edgeDataSet).where(0).equalTo(1)
                .with(new ProjectKeyWithNeighborIN<K, VV, EV>());
    }

    /**
     * Creates a new gather-sum-apply iteration.
     *
     * @param edges the edges of the graph
     * @param gather the gather function
     * @param sum the sum function
     * @param apply the apply function
     * @param maximumNumberOfIterations upper bound on supersteps; must be at least one
     * @return an iteration operator; call {@link #setInput(DataSet)} before {@link #createResult()}
     */
    public static <K, VV, EV, M> GatherSumApplyIteration<K, VV, EV, M> withEdges(DataSet<Edge<K, EV>> edges,
            GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum, ApplyFunction<K, VV, M> apply,
            int maximumNumberOfIterations) {
        return new GatherSumApplyIteration<>(gather, sum, apply, edges, maximumNumberOfIterations);
    }

    /** Sets the (optional) iteration configuration; {@code null} means defaults. */
    public void configure(GSAConfiguration parameters) {
        this.configuration = parameters;
    }

    /** Returns the iteration configuration, or {@code null} if none was set. */
    public GSAConfiguration getIterationConfiguration() {
        return this.configuration;
    }

    // ------------------------------------------------------------------------------------------
    //  Wrapping UDFs
    // ------------------------------------------------------------------------------------------

    /** Wraps the user's {@link GatherFunction}: maps a (key, neighbor) pair to (key, gathered value). */
    @FunctionAnnotation.ForwardedFields({"f0"})
    private static final class GatherUdf<K, VV, EV, M> extends RichMapFunction<Tuple2<K, Neighbor<VV, EV>>, Tuple2<K, M>>
            implements ResultTypeQueryable<Tuple2<K, M>> {

        private final GatherFunction<VV, EV, M> gatherFunction;
        private transient TypeInformation<Tuple2<K, M>> resultType;

        private GatherUdf(GatherFunction<VV, EV, M> gatherFunction, TypeInformation<Tuple2<K, M>> resultType) {
            this.gatherFunction = gatherFunction;
            this.resultType = resultType;
        }

        @Override
        public Tuple2<K, M> map(Tuple2<K, Neighbor<VV, EV>> neighborTuple) {
            M gathered = this.gatherFunction.gather(neighborTuple.f1);
            return new Tuple2<>(neighborTuple.f0, gathered);
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES)) {
                LongValue count = (LongValue) getRuntimeContext().getBroadcastVariable(NUMBER_OF_VERTICES).iterator().next();
                this.gatherFunction.setNumberOfVertices(count.getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.gatherFunction.init(getIterationRuntimeContext());
            }
            this.gatherFunction.preSuperstep();
        }

        @Override
        public void close() throws Exception {
            this.gatherFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Tuple2<K, M>> getProducedType() {
            return this.resultType;
        }
    }

    /** Wraps the user's {@link SumFunction}: combines two (key, value) tuples of the same key. */
    private static final class SumUdf<K, VV, EV, M> extends RichReduceFunction<Tuple2<K, M>>
            implements ResultTypeQueryable<Tuple2<K, M>> {

        private final SumFunction<VV, EV, M> sumFunction;
        private transient TypeInformation<Tuple2<K, M>> resultType;

        private SumUdf(SumFunction<VV, EV, M> sumFunction, TypeInformation<Tuple2<K, M>> resultType) {
            this.sumFunction = sumFunction;
            this.resultType = resultType;
        }

        @Override
        public Tuple2<K, M> reduce(Tuple2<K, M> first, Tuple2<K, M> second) throws Exception {
            M result = this.sumFunction.sum(first.f1, second.f1);
            if (result == second.f1) {
                // The user returned the right-hand value object. Swap the values so that the
                // returned tuple holds it and the two tuples do not share one value instance
                // (presumably to stay safe under object reuse).
                M previousFirst = first.f1;
                first.f1 = second.f1;
                second.f1 = previousFirst;
            } else {
                first.f1 = result;
            }
            return first;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES)) {
                LongValue count = (LongValue) getRuntimeContext().getBroadcastVariable(NUMBER_OF_VERTICES).iterator().next();
                this.sumFunction.setNumberOfVertices(count.getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.sumFunction.init(getIterationRuntimeContext());
            }
            this.sumFunction.preSuperstep();
        }

        @Override
        public void close() throws Exception {
            this.sumFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Tuple2<K, M>> getProducedType() {
            return this.resultType;
        }
    }

    /** Wraps the user's {@link ApplyFunction}: receives (key, summed value) and the current vertex. */
    private static final class ApplyUdf<K, VV, EV, M> extends RichFlatJoinFunction<Tuple2<K, M>, Vertex<K, VV>, Vertex<K, VV>>
            implements ResultTypeQueryable<Vertex<K, VV>> {

        private final ApplyFunction<K, VV, M> applyFunction;
        private transient TypeInformation<Vertex<K, VV>> resultType;

        private ApplyUdf(ApplyFunction<K, VV, M> applyFunction, TypeInformation<Vertex<K, VV>> resultType) {
            this.applyFunction = applyFunction;
            this.resultType = resultType;
        }

        @Override
        public void join(Tuple2<K, M> newValue, Vertex<K, VV> currentVertex, Collector<Vertex<K, VV>> out) throws Exception {
            this.applyFunction.setOutput(currentVertex, out);
            this.applyFunction.apply(newValue.f1, currentVertex.getValue());
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES)) {
                LongValue count = (LongValue) getRuntimeContext().getBroadcastVariable(NUMBER_OF_VERTICES).iterator().next();
                this.applyFunction.setNumberOfVertices(count.getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.applyFunction.init(getIterationRuntimeContext());
            }
            this.applyFunction.preSuperstep();
        }

        @Override
        public void close() throws Exception {
            this.applyFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Vertex<K, VV>> getProducedType() {
            return this.resultType;
        }
    }

    /** Emits (edge target, Neighbor(vertex value, edge value)) for a vertex joined on edge source. */
    @FunctionAnnotation.ForwardedFieldsSecond({"f1->f0"})
    private static final class ProjectKeyWithNeighborOUT<K, VV, EV>
            implements FlatJoinFunction<Vertex<K, VV>, Edge<K, EV>, Tuple2<K, Neighbor<VV, EV>>> {

        @Override
        public void join(Vertex<K, VV> vertex, Edge<K, EV> edge, Collector<Tuple2<K, Neighbor<VV, EV>>> out) {
            out.collect(new Tuple2<K, Neighbor<VV, EV>>(edge.getTarget(),
                    new Neighbor<VV, EV>(vertex.getValue(), edge.getValue())));
        }
    }

    /** Emits (edge source, Neighbor(vertex value, edge value)) for a vertex joined on edge target. */
    @FunctionAnnotation.ForwardedFieldsSecond({"f0"})
    private static final class ProjectKeyWithNeighborIN<K, VV, EV>
            implements FlatJoinFunction<Vertex<K, VV>, Edge<K, EV>, Tuple2<K, Neighbor<VV, EV>>> {

        @Override
        public void join(Vertex<K, VV> vertex, Edge<K, EV> edge, Collector<Tuple2<K, Neighbor<VV, EV>>> out) {
            out.collect(new Tuple2<K, Neighbor<VV, EV>>(edge.getSource(),
                    new Neighbor<VV, EV>(vertex.getValue(), edge.getValue())));
        }
    }
}
