package org.apache.flink.graph.spargel;

import java.util.Iterator;
import java.util.Map;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.CoGroupOperator;
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.TwoInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.utils.GraphUtils;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * Runs a scatter-gather iteration over a graph as a Flink delta iteration.
 *
 * <p>The vertex data set passed to {@link #setInput(DataSet)} is used both as the initial solution
 * set and as the initial workset. In every superstep:
 * <ol>
 *   <li>the {@link ScatterFunction} is invoked (coGroup with the edges) for every vertex of the
 *       workset and emits {@code (target id, message)} pairs;</li>
 *   <li>the {@link GatherFunction} is invoked (coGroup with the solution set) for every vertex id
 *       that received messages, and can emit updated vertex state through the collector it is
 *       given.</li>
 * </ol>
 * The updated vertices are used both as the solution-set delta and as the next workset. The
 * iteration runs for at most {@code maximumNumberOfIterations} supersteps.
 *
 * <p>If the {@link ScatterGatherConfiguration} enables degrees, each vertex value is temporarily
 * extended to {@code (value, in-degree, out-degree)} for the duration of the iteration and reduced
 * back to the plain value afterwards. If it enables the vertex count, the number of vertices is
 * broadcast to both UDFs.
 *
 * @param <K> the vertex key type
 * @param <VV> the vertex value type
 * @param <Message> the message type
 * @param <EV> the edge value type
 */
public class ScatterGatherIteration<K, VV, Message, EV> implements CustomUnaryOperation<Vertex<K, VV>, Vertex<K, VV>> {

    /** Name of the broadcast variable carrying the total number of vertices. */
    private static final String NUMBER_OF_VERTICES = "number of vertices";

    private final ScatterFunction<K, VV, Message, EV> scatterFunction;
    private final GatherFunction<K, VV, Message> gatherFunction;
    private final DataSet<Edge<K, EV>> edgesWithValue;
    private final int maximumNumberOfIterations;
    private final TypeInformation<Message> messageType;
    private DataSet<Vertex<K, VV>> initialVertices;
    private ScatterGatherConfiguration configuration;

    /**
     * Base class of the gather-side coGroup UDF.
     *
     * <p>The first input contains the messages addressed to one vertex id, the second input the
     * matching vertex from the solution set. It wraps the user's {@link GatherFunction} and forwards
     * the rich-function lifecycle (open/close) to it.
     */
    private static abstract class GatherUdf<K, VVWithDegrees, Message> extends RichCoGroupFunction<Tuple2<K, Message>, Vertex<K, VVWithDegrees>, Vertex<K, VVWithDegrees>> implements ResultTypeQueryable<Vertex<K, VVWithDegrees>> {
        private static final long serialVersionUID = 1L;

        final GatherFunction<K, VVWithDegrees, Message> gatherFunction;
        final MessageIterator<Message> messageIter = new MessageIterator<>();

        // Only needed on the client while the plan is built, hence transient.
        private transient TypeInformation<Vertex<K, VVWithDegrees>> resultType;

        private GatherUdf(GatherFunction<K, VVWithDegrees, Message> gatherFunction, TypeInformation<Vertex<K, VVWithDegrees>> resultType) {
            this.gatherFunction = gatherFunction;
            this.resultType = resultType;
        }

        /**
         * Passes the optional vertex count to the gather function, initializes it in the first
         * superstep and signals the start of a superstep.
         */
        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES)) {
                LongValue numberOfVertices = getRuntimeContext().<LongValue>getBroadcastVariable(NUMBER_OF_VERTICES).iterator().next();
                this.gatherFunction.setNumberOfVertices(numberOfVertices.getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.gatherFunction.init(getIterationRuntimeContext());
            }
            this.gatherFunction.preSuperstep();
        }

        /** Signals the end of a superstep to the gather function. */
        @Override
        public void close() throws Exception {
            this.gatherFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Vertex<K, VVWithDegrees>> getProducedType() {
            return this.resultType;
        }

        /** Points {@link #messageIter} at the messages of the current coGroup group. */
        @SuppressWarnings("unchecked")
        void setMessageSource(Iterable<Tuple2<K, Message>> messages) {
            // MessageIterator is keyed only by the message type; the double cast drops the key type.
            Iterator<Tuple2<?, Message>> downcastIter = (Iterator<Tuple2<?, Message>>) (Iterator<?>) messages.iterator();
            this.messageIter.setSource(downcastIter);
        }

        /**
         * Builds the exception reported when a group contains messages but no matching vertex in
         * the solution set, i.e. a message was sent to a vertex id that does not exist.
         *
         * @param messages the messages of the offending group
         * @return the exception to throw; its message names the target id when it can be rendered
         */
        static Exception missingTargetVertex(Iterable<? extends Tuple2<?, ?>> messages) {
            Iterator<? extends Tuple2<?, ?>> it = messages.iterator();
            if (!it.hasNext()) {
                return new Exception("Gather function was invoked with neither a vertex nor any messages.");
            }
            String message = "Target vertex does not exist!.";
            try {
                message = "Target vertex '" + it.next().f0 + "' does not exist!.";
            } catch (Throwable ignored) {
                // Best effort only: failing to render the key must not hide the actual error.
            }
            return new Exception(message);
        }
    }

    /** Gather UDF for vertices whose value is the plain user value. */
    public static final class GatherUdfSimpleVV<K, VV, Message> extends GatherUdf<K, VV, Message> {
        private GatherUdfSimpleVV(GatherFunction<K, VV, Message> gatherFunction, TypeInformation<Vertex<K, VV>> resultType) {
            super(gatherFunction, resultType);
        }

        @Override
        public void coGroup(Iterable<Tuple2<K, Message>> messages, Iterable<Vertex<K, VV>> vertex, Collector<Vertex<K, VV>> out) throws Exception {
            Iterator<Vertex<K, VV>> vertexIter = vertex.iterator();
            if (!vertexIter.hasNext()) {
                throw missingTargetVertex(messages);
            }
            Vertex<K, VV> vertexState = vertexIter.next();
            setMessageSource(messages);
            this.gatherFunction.setOutput(vertexState, out);
            this.gatherFunction.updateVertex(vertexState, this.messageIter);
        }
    }

    /** Gather UDF for vertices whose value is {@code (value, in-degree, out-degree)}. */
    public static final class GatherUdfVVWithDegrees<K, VV, Message> extends GatherUdf<K, Tuple3<VV, LongValue, LongValue>, Message> {
        private GatherUdfVVWithDegrees(GatherFunction<K, Tuple3<VV, LongValue, LongValue>, Message> gatherFunction, TypeInformation<Vertex<K, Tuple3<VV, LongValue, LongValue>>> resultType) {
            super(gatherFunction, resultType);
        }

        @Override
        public void coGroup(Iterable<Tuple2<K, Message>> messages, Iterable<Vertex<K, Tuple3<VV, LongValue, LongValue>>> vertex, Collector<Vertex<K, Tuple3<VV, LongValue, LongValue>>> out) throws Exception {
            Iterator<Vertex<K, Tuple3<VV, LongValue, LongValue>>> vertexIter = vertex.iterator();
            if (!vertexIter.hasNext()) {
                throw missingTargetVertex(messages);
            }
            Vertex<K, Tuple3<VV, LongValue, LongValue>> vertexWithDegrees = vertexIter.next();
            setMessageSource(messages);

            Tuple3<VV, LongValue, LongValue> valueAndDegrees = vertexWithDegrees.getValue();
            this.gatherFunction.setInDegree(valueAndDegrees.f1.getValue());
            this.gatherFunction.setOutDegree(valueAndDegrees.f2.getValue());

            this.gatherFunction.setOutputWithDegrees(vertexWithDegrees, out);
            this.gatherFunction.updateVertexFromScatterGatherIteration(vertexWithDegrees, this.messageIter);
        }
    }

    /** Scatter UDF for vertices whose value is the plain user value. */
    public static final class ScatterUdfWithEVsSimpleVV<K, VV, Message, EV> extends ScatterUdfWithEdgeValues<K, VV, VV, Message, EV> {
        private ScatterUdfWithEVsSimpleVV(ScatterFunction<K, VV, Message, EV> scatterFunction, TypeInformation<Tuple2<K, Message>> resultType) {
            super(scatterFunction, resultType);
        }

        @Override
        public void coGroup(Iterable<Edge<K, EV>> edges, Iterable<Vertex<K, VV>> state, Collector<Tuple2<K, Message>> out) throws Exception {
            Iterator<Vertex<K, VV>> stateIter = state.iterator();
            // Groups without a workset vertex are skipped: only vertices in the workset scatter.
            if (stateIter.hasNext()) {
                Vertex<K, VV> vertexState = stateIter.next();
                this.scatterFunction.set(edges.iterator(), out, vertexState.getId());
                this.scatterFunction.sendMessages(vertexState);
            }
        }
    }

    /**
     * Scatter UDF for vertices whose value is {@code (value, in-degree, out-degree)}.
     *
     * <p>It hands the degrees to the scatter function and passes it a reused vertex that carries
     * only the plain value.
     */
    public static final class ScatterUdfWithEVsVVWithDegrees<K, VV, Message, EV> extends ScatterUdfWithEdgeValues<K, Tuple3<VV, LongValue, LongValue>, VV, Message, EV> {
        private final Vertex<K, VV> nextVertex = new Vertex<>();

        private ScatterUdfWithEVsVVWithDegrees(ScatterFunction<K, VV, Message, EV> scatterFunction, TypeInformation<Tuple2<K, Message>> resultType) {
            super(scatterFunction, resultType);
        }

        @Override
        public void coGroup(Iterable<Edge<K, EV>> edges, Iterable<Vertex<K, Tuple3<VV, LongValue, LongValue>>> state, Collector<Tuple2<K, Message>> out) throws Exception {
            Iterator<Vertex<K, Tuple3<VV, LongValue, LongValue>>> stateIter = state.iterator();
            // Groups without a workset vertex are skipped: only vertices in the workset scatter.
            if (stateIter.hasNext()) {
                Vertex<K, Tuple3<VV, LongValue, LongValue>> vertexWithDegrees = stateIter.next();
                Tuple3<VV, LongValue, LongValue> valueAndDegrees = vertexWithDegrees.getValue();

                this.nextVertex.f0 = vertexWithDegrees.f0;
                this.nextVertex.f1 = valueAndDegrees.f0;

                this.scatterFunction.setInDegree(valueAndDegrees.f1.getValue());
                this.scatterFunction.setOutDegree(valueAndDegrees.f2.getValue());
                this.scatterFunction.set(edges.iterator(), out, vertexWithDegrees.getId());
                this.scatterFunction.sendMessages(this.nextVertex);
            }
        }
    }

    /**
     * Base class of the scatter-side coGroup UDF.
     *
     * <p>The first input contains the edges of one vertex, the second input the vertex from the
     * workset. {@code VVWithDegrees} is the value type stored in the workset, and {@code VV} is the
     * value type seen by the user's {@link ScatterFunction}.
     */
    private static abstract class ScatterUdfWithEdgeValues<K, VVWithDegrees, VV, Message, EV> extends RichCoGroupFunction<Edge<K, EV>, Vertex<K, VVWithDegrees>, Tuple2<K, Message>> implements ResultTypeQueryable<Tuple2<K, Message>> {
        private static final long serialVersionUID = 1L;

        final ScatterFunction<K, VV, Message, EV> scatterFunction;

        // Only needed on the client while the plan is built, hence transient.
        private transient TypeInformation<Tuple2<K, Message>> resultType;

        private ScatterUdfWithEdgeValues(ScatterFunction<K, VV, Message, EV> scatterFunction, TypeInformation<Tuple2<K, Message>> resultType) {
            this.scatterFunction = scatterFunction;
            this.resultType = resultType;
        }

        /**
         * Passes the optional vertex count to the scatter function, initializes it in the first
         * superstep and signals the start of a superstep.
         */
        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES)) {
                LongValue numberOfVertices = getRuntimeContext().<LongValue>getBroadcastVariable(NUMBER_OF_VERTICES).iterator().next();
                this.scatterFunction.setNumberOfVertices(numberOfVertices.getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.scatterFunction.init(getIterationRuntimeContext());
            }
            this.scatterFunction.preSuperstep();
        }

        /** Signals the end of a superstep to the scatter function. */
        @Override
        public void close() throws Exception {
            this.scatterFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Tuple2<K, Message>> getProducedType() {
            return this.resultType;
        }
    }

    private ScatterGatherIteration(ScatterFunction<K, VV, Message, EV> scatterFunction, GatherFunction<K, VV, Message> gatherFunction, DataSet<Edge<K, EV>> edgesWithValue, int maximumNumberOfIterations) {
        Preconditions.checkNotNull(scatterFunction, "The scatter function must not be null.");
        Preconditions.checkNotNull(gatherFunction, "The gather function must not be null.");
        Preconditions.checkNotNull(edgesWithValue, "The edge data set must not be null.");
        Preconditions.checkArgument(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one.");
        this.scatterFunction = scatterFunction;
        this.gatherFunction = gatherFunction;
        this.edgesWithValue = edgesWithValue;
        this.maximumNumberOfIterations = maximumNumberOfIterations;
        this.messageType = getMessageType(scatterFunction);
    }

    /** Extracts the {@code Message} type argument (position 2) of the scatter function's class. */
    private TypeInformation<Message> getMessageType(ScatterFunction<K, VV, Message, EV> scatterFunction) {
        return TypeExtractor.createTypeInfo(scatterFunction, ScatterFunction.class, scatterFunction.getClass(), 2);
    }

    /**
     * Sets the initial vertex data set of the iteration.
     *
     * @param dataSet the vertices to iterate over
     */
    @Override
    public void setInput(DataSet<Vertex<K, VV>> dataSet) {
        this.initialVertices = dataSet;
    }

    /**
     * Builds the iteration plan.
     *
     * @return the vertex data set after the iteration
     * @throws IllegalStateException if {@link #setInput(DataSet)} has not been called
     */
    @Override
    public DataSet<Vertex<K, VV>> createResult() {
        if (this.initialVertices == null) {
            throw new IllegalStateException("The input data set has not been set.");
        }

        // Messages travel as (target vertex id, message value).
        TypeInformation<K> keyType = ((TupleTypeInfo<?>) this.initialVertices.getType()).getTypeAt(0);
        TypeInformation<Tuple2<K, Message>> messageTypeInfo = new TupleTypeInfo<>(keyType, this.messageType);

        DataSet<LongValue> numberOfVertices = null;
        if (this.configuration != null && this.configuration.isOptNumVertices()) {
            try {
                numberOfVertices = GraphUtils.count(this.initialVertices);
            } catch (Exception e) {
                // Previously swallowed, which only deferred the failure to the null broadcast set.
                throw new RuntimeException("Could not count the number of vertices.", e);
            }
        }

        if (this.configuration != null) {
            this.scatterFunction.setDirection(this.configuration.getDirection());
        } else {
            this.scatterFunction.setDirection(EdgeDirection.OUT);
        }
        EdgeDirection direction = this.scatterFunction.getDirection();

        if (this.configuration != null && this.configuration.isOptDegrees()) {
            Graph<K, VV, EV> graph = Graph.fromDataSet(this.initialVertices, this.edgesWithValue, this.initialVertices.getExecutionEnvironment());
            return createResultVerticesWithDegrees(graph, direction, messageTypeInfo, numberOfVertices);
        }
        return createResultSimpleVertex(direction, messageTypeInfo, numberOfVertices);
    }

    /**
     * Creates a new scatter-gather iteration over the given edges.
     *
     * @param dataSet the edges of the graph
     * @param scatterFunction the function that sends messages along edges
     * @param gatherFunction the function that updates vertices from received messages
     * @param i the maximum number of supersteps; must be at least one
     * @return the configured, not yet built, iteration
     */
    public static <K, VV, Message, EV> ScatterGatherIteration<K, VV, Message, EV> withEdges(DataSet<Edge<K, EV>> dataSet, ScatterFunction<K, VV, Message, EV> scatterFunction, GatherFunction<K, VV, Message> gatherFunction, int i) {
        return new ScatterGatherIteration<>(scatterFunction, gatherFunction, dataSet, i);
    }

    /**
     * Sets the iteration configuration (name, parallelism, broadcast variables, direction, ...).
     *
     * @param scatterGatherConfiguration the configuration; may be {@code null} for defaults
     */
    public void configure(ScatterGatherConfiguration scatterGatherConfiguration) {
        this.configuration = scatterGatherConfiguration;
    }

    /** Returns the configuration set via {@link #configure}, or {@code null} if none was set. */
    public ScatterGatherConfiguration getIterationConfiguration() {
        return this.configuration;
    }

    /**
     * Builds the scatter coGroup of edges with workset vertices for plain vertex values.
     *
     * @param whereArg the edge field joined with the vertex id (0 = source, 1 = target)
     * @param equalToArg the vertex field joined with the edge (the vertex id)
     */
    private CoGroupOperator<?, ?, Tuple2<K, Message>> buildScatterFunction(DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration, TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg, DataSet<LongValue> numberOfVertices) {
        ScatterUdfWithEdgeValues<K, VV, VV, Message, EV> messenger = new ScatterUdfWithEVsSimpleVV<>(this.scatterFunction, messageTypeInfo);
        CoGroupOperator<?, ?, Tuple2<K, Message>> messages = this.edgesWithValue.coGroup(iteration.getWorkset()).where(whereArg).equalTo(equalToArg).with(messenger);
        return configureScatterFunction(messages, numberOfVertices);
    }

    /**
     * Builds the scatter coGroup of edges with workset vertices for degree-extended vertex values.
     *
     * @param whereArg the edge field joined with the vertex id (0 = source, 1 = target)
     * @param equalToArg the vertex field joined with the edge (the vertex id)
     */
    private CoGroupOperator<?, ?, Tuple2<K, Message>> buildScatterFunctionVerticesWithDegrees(DeltaIteration<Vertex<K, Tuple3<VV, LongValue, LongValue>>, Vertex<K, Tuple3<VV, LongValue, LongValue>>> iteration, TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg, DataSet<LongValue> numberOfVertices) {
        ScatterUdfWithEdgeValues<K, Tuple3<VV, LongValue, LongValue>, VV, Message, EV> messenger = new ScatterUdfWithEVsVVWithDegrees<>(this.scatterFunction, messageTypeInfo);
        CoGroupOperator<?, ?, Tuple2<K, Message>> messages = this.edgesWithValue.coGroup(iteration.getWorkset()).where(whereArg).equalTo(equalToArg).with(messenger);
        return configureScatterFunction(messages, numberOfVertices);
    }

    /**
     * Names the scatter operator and attaches the configured scatter broadcast variables and, if
     * enabled, the vertex count.
     */
    private CoGroupOperator<?, ?, Tuple2<K, Message>> configureScatterFunction(CoGroupOperator<?, ?, Tuple2<K, Message>> messages, DataSet<LongValue> numberOfVertices) {
        messages = messages.name("Messaging");
        if (this.configuration != null) {
            for (Tuple2<String, DataSet<?>> bcastVar : this.configuration.getScatterBcastVars()) {
                messages = messages.withBroadcastSet(bcastVar.f1, bcastVar.f0);
            }
            if (this.configuration.isOptNumVertices()) {
                messages = messages.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES);
            }
        }
        return messages;
    }

    /**
     * Applies the name and, when a configuration is present, its parallelism, solution-set memory
     * mode and aggregators to the delta iteration.
     */
    private void setUpIteration(DeltaIteration<?, ?> iteration) {
        String defaultName = "Scatter-gather iteration (" + this.gatherFunction + " | " + this.scatterFunction + ")";
        if (this.configuration == null) {
            iteration.name(defaultName);
            return;
        }
        iteration.name(this.configuration.getName(defaultName));
        iteration.parallelism(this.configuration.getParallelism());
        iteration.setSolutionSetUnManaged(this.configuration.isSolutionSetUnmanagedMemory());
        for (Map.Entry<String, Aggregator<?>> aggregator : this.configuration.getAggregators().entrySet()) {
            iteration.registerAggregator(aggregator.getKey(), aggregator.getValue());
        }
    }

    /** Builds the iteration for plain vertex values. */
    private DataSet<Vertex<K, VV>> createResultSimpleVertex(EdgeDirection messagingDirection, TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue> numberOfVertices) {
        TypeInformation<Vertex<K, VV>> vertexType = this.initialVertices.getType();
        DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration = this.initialVertices.iterateDelta(this.initialVertices, this.maximumNumberOfIterations, 0);
        setUpIteration(iteration);

        DataSet<Tuple2<K, Message>> messages;
        switch (messagingDirection) {
            case IN:
                messages = buildScatterFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices);
                break;
            case OUT:
                messages = buildScatterFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices);
                break;
            case ALL:
                messages = buildScatterFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices)
                        .union(buildScatterFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices));
                break;
            default:
                throw new IllegalArgumentException("Illegal edge direction");
        }

        GatherUdf<K, VV, Message> updateUdf = new GatherUdfSimpleVV<>(this.gatherFunction, vertexType);
        CoGroupOperator<?, ?, Vertex<K, VV>> updates = messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf);
        if (this.configuration != null && this.configuration.isOptNumVertices()) {
            updates = updates.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES);
        }
        configureUpdateFunction(updates);

        // The updated vertices are both the solution-set delta and the next workset.
        return iteration.closeWith(updates, updates);
    }

    /**
     * Builds the iteration for vertex values extended with in- and out-degrees and strips the
     * degrees from the result.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private DataSet<Vertex<K, VV>> createResultVerticesWithDegrees(Graph<K, VV, EV> graph, EdgeDirection messagingDirection, TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue> numberOfVertices) {
        this.gatherFunction.setOptDegrees(this.configuration.isOptDegrees());

        DataSet<Tuple3<K, LongValue, LongValue>> degrees = graph.inDegrees().join(graph.outDegrees()).where(0).equalTo(0)
                .with(new FlatJoinFunction<Tuple2<K, LongValue>, Tuple2<K, LongValue>, Tuple3<K, LongValue, LongValue>>() {
                    @Override
                    public void join(Tuple2<K, LongValue> inDegree, Tuple2<K, LongValue> outDegree, Collector<Tuple3<K, LongValue, LongValue>> out) {
                        out.collect(new Tuple3<>(inDegree.f0, inDegree.f1, outDegree.f1));
                    }
                })
                .withForwardedFieldsFirst("f0;f1")
                // The out-degree (second input, field f1) is written to output field f2.
                .withForwardedFieldsSecond("f1->f2");

        DataSet<Vertex<K, Tuple3<VV, LongValue, LongValue>>> verticesWithDegrees = this.initialVertices.join(degrees).where(0).equalTo(0)
                .with(new FlatJoinFunction<Vertex<K, VV>, Tuple3<K, LongValue, LongValue>, Vertex<K, Tuple3<VV, LongValue, LongValue>>>() {
                    @Override
                    public void join(Vertex<K, VV> vertex, Tuple3<K, LongValue, LongValue> vertexDegrees, Collector<Vertex<K, Tuple3<VV, LongValue, LongValue>>> out) throws Exception {
                        out.collect(new Vertex<>(vertex.getId(), new Tuple3<VV, LongValue, LongValue>(vertex.getValue(), vertexDegrees.f1, vertexDegrees.f2)));
                    }
                })
                .withForwardedFieldsFirst("f0");

        TypeInformation<Vertex<K, Tuple3<VV, LongValue, LongValue>>> vertexType = verticesWithDegrees.getType();
        DeltaIteration<Vertex<K, Tuple3<VV, LongValue, LongValue>>, Vertex<K, Tuple3<VV, LongValue, LongValue>>> iteration = verticesWithDegrees.iterateDelta(verticesWithDegrees, this.maximumNumberOfIterations, 0);
        setUpIteration(iteration);

        DataSet<Tuple2<K, Message>> messages;
        switch (messagingDirection) {
            case IN:
                messages = buildScatterFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0, numberOfVertices);
                break;
            case OUT:
                messages = buildScatterFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0, numberOfVertices);
                break;
            case ALL:
                messages = buildScatterFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0, numberOfVertices)
                        .union(buildScatterFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0, numberOfVertices));
                break;
            default:
                throw new IllegalArgumentException("Illegal edge direction");
        }

        // The user's gather function is typed on VV; with degrees enabled it is driven with the
        // degree-extended value type, hence the raw construction (unchecked).
        GatherUdf<K, Tuple3<VV, LongValue, LongValue>, Message> updateUdf = new GatherUdfVVWithDegrees(this.gatherFunction, vertexType);
        CoGroupOperator<?, ?, Vertex<K, Tuple3<VV, LongValue, LongValue>>> updates = messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf);
        if (this.configuration != null && this.configuration.isOptNumVertices()) {
            updates = updates.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES);
        }
        configureUpdateFunction(updates);

        return iteration.closeWith(updates, updates).map(new MapFunction<Vertex<K, Tuple3<VV, LongValue, LongValue>>, Vertex<K, VV>>() {
            @Override
            public Vertex<K, VV> map(Vertex<K, Tuple3<VV, LongValue, LongValue>> vertex) {
                // Drop the degrees again; only the user value is returned.
                return new Vertex<>(vertex.getId(), vertex.getValue().f0);
            }
        });
    }

    /**
     * Names the gather operator, attaches the configured gather broadcast variables and declares
     * that the key field (0) is forwarded from both inputs.
     */
    private <VVWithDegree> void configureUpdateFunction(CoGroupOperator<?, ?, Vertex<K, VVWithDegree>> updates) {
        updates = updates.name("Vertex State Updates");
        if (this.configuration != null) {
            for (Tuple2<String, DataSet<?>> bcastVar : this.configuration.getGatherBcastVars()) {
                updates = updates.withBroadcastSet(bcastVar.f1, bcastVar.f0);
            }
        }
        updates.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0");
    }
}
