package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.IdManager;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;

/**
 * Reshaping pass that groups the vertices of each loop into their assigned {@link LoopVertex}
 * and then rolls consecutive iterations of the same loop into a single {@link LoopVertex}.
 *
 * <p>Loops are handled from the deepest loop stack depth outwards. At each depth, every vertex
 * that belongs to a loop of that depth is moved into its assigned {@link LoopVertex}. Consecutive
 * iterations are then rolled together, and the process recurses on the next shallower depth.
 */
@Requires({CommunicationPatternProperty.class})
public final class LoopExtractionPass extends ReshapingPass {
    /**
     * Default constructor.
     */
    public LoopExtractionPass() {
        super(LoopExtractionPass.class);
    }

    /**
     * Groups and rolls every loop of the given DAG, reshaping it via {@code reshapeUnsafely}.
     *
     * @param irdag the DAG to reshape.
     * @return the same {@code irdag} instance, after reshaping.
     */
    @Override
    public IRDAG apply(final IRDAG irdag) {
        irdag.reshapeUnsafely(dag -> groupLoops(dag, findMaxLoopVertexStackDepth(dag)));
        return irdag;
    }

    /**
     * Finds the deepest loop stack depth among the composite (loop-member) vertices of the DAG.
     *
     * @param dag the DAG to inspect.
     * @return the maximum loop stack depth, or 0 if no vertex is composite.
     */
    private Integer findMaxLoopVertexStackDepth(final DAG<IRVertex, IREdge> dag) {
        return dag.getVertices().stream()
            .filter(dag::isCompositeVertex)
            .mapToInt(dag::getLoopStackDepthOf)
            .max()
            .orElse(0);
    }

    /**
     * Groups every vertex that belongs to a loop of depth {@code num} into its assigned
     * {@link LoopVertex}, rolls the result, and recurses on depth {@code num - 1}.
     *
     * @param dag the DAG to reshape.
     * @param num the loop stack depth to group at this step; values {@code <= 0} stop the recursion.
     * @return the reshaped DAG.
     * @throws UnsupportedOperationException if a vertex has an unknown type, or a {@link LoopVertex}
     *                                       is not composite at this depth.
     */
    private DAG<IRVertex, IREdge> groupLoops(final DAG<IRVertex, IREdge> dag, final Integer num) {
        if (num <= 0) {
            return dag;
        }
        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
        for (final IRVertex irVertex : dag.getTopologicalSort()) {
            if (irVertex instanceof SourceVertex) {
                if (belongsToLoopOfDepth(dag, irVertex, num)) {
                    final LoopVertex assignedLoopVertex = dag.getAssignedLoopVertexOf(irVertex);
                    builder.addVertex(assignedLoopVertex, dag);
                    connectElementToLoop(dag, builder, irVertex, assignedLoopVertex);
                } else {
                    builder.addVertex(irVertex, dag);
                }
            } else if (irVertex instanceof OperatorVertex) {
                final OperatorVertex operatorVertex = (OperatorVertex) irVertex;
                if (belongsToLoopOfDepth(dag, operatorVertex, num)) {
                    final LoopVertex assignedLoopVertex = dag.getAssignedLoopVertexOf(operatorVertex);
                    builder.addVertex(assignedLoopVertex, dag);
                    connectElementToLoop(dag, builder, operatorVertex, assignedLoopVertex);
                } else {
                    builder.addVertex(operatorVertex, dag);
                    connectOperatorToPredecessors(dag, builder, operatorVertex);
                }
            } else if (irVertex instanceof LoopVertex) {
                // A loop rolled at a deeper level: it must itself belong to an enclosing loop.
                final LoopVertex loopVertex = (LoopVertex) irVertex;
                if (!dag.isCompositeVertex(loopVertex)) {
                    throw new UnsupportedOperationException("This loop (" + loopVertex + ") shouldn't be of this depth");
                }
                final LoopVertex assignedLoopVertex = dag.getAssignedLoopVertexOf(loopVertex);
                // Like the other branches, make sure the enclosing loop vertex is in the builder
                // before edges to it are connected (it may not have been added yet if this nested
                // loop is the first member of the enclosing loop in topological order).
                builder.addVertex(assignedLoopVertex, dag);
                connectElementToLoop(dag, builder, loopVertex, assignedLoopVertex);
            } else {
                throw new UnsupportedOperationException("Unknown vertex type: " + irVertex);
            }
        }
        return groupLoops(loopRolling(builder.build()), num - 1);
    }

    /**
     * Tells whether a vertex is composite and sits at exactly the given loop stack depth.
     */
    private static boolean belongsToLoopOfDepth(final DAG<IRVertex, IREdge> dag,
                                                final IRVertex vertex,
                                                final Integer depth) {
        return dag.isCompositeVertex(vertex) && dag.getLoopStackDepthOf(vertex).equals(depth);
    }

    /**
     * Connects a plain (non-grouped) operator to its predecessors in {@code builder}. Edges coming
     * from a loop member are redirected so that they originate from that member's loop vertex.
     */
    private static void connectOperatorToPredecessors(final DAG<IRVertex, IREdge> dag,
                                                      final DAGBuilder<IRVertex, IREdge> builder,
                                                      final OperatorVertex operatorVertex) {
        dag.getIncomingEdgesOf(operatorVertex).forEach(irEdge -> {
            final IRVertex src = irEdge.getSrc();
            if (!dag.isCompositeVertex(src)) {
                // operator -> operator: keep the edge as-is.
                builder.connectVertices(irEdge);
                return;
            }
            // loop -> operator: record the original edge on the loop and connect loop -> operator.
            final LoopVertex srcLoopVertex = dag.getAssignedLoopVertexOf(src);
            srcLoopVertex.addDagOutgoingEdge(irEdge);
            final IREdge edgeFromLoop = copyEdgeWithEndpoints(irEdge, srcLoopVertex, operatorVertex);
            builder.connectVertices(edgeFromLoop);
            srcLoopVertex.mapEdgeWithLoop(edgeFromLoop, irEdge);
        });
    }

    /**
     * Adds {@code dstVertex} to the internal DAG of {@code assignedLoopVertex} and wires up its
     * incoming edges.
     *
     * <p>Edges from the same loop stay inside the loop's DAG. Edges from outside the loop, whether
     * from a plain vertex or from another loop's vertex, are recorded as DAG-incoming edges of the
     * loop. They are then reconnected in {@code builder} so that they end at the loop vertex.
     *
     * @param dag                the DAG being reshaped.
     * @param builder            the builder of the outer (reshaped) DAG.
     * @param dstVertex          the loop member being added.
     * @param assignedLoopVertex the loop vertex {@code dstVertex} belongs to.
     */
    private static void connectElementToLoop(final DAG<IRVertex, IREdge> dag,
                                             final DAGBuilder<IRVertex, IREdge> builder,
                                             final IRVertex dstVertex,
                                             final LoopVertex assignedLoopVertex) {
        assignedLoopVertex.getBuilder().addVertex(dstVertex, dag);
        dag.getIncomingEdgesOf(dstVertex).forEach(irEdge -> {
            final IRVertex src = irEdge.getSrc();
            final IRVertex externalSrc;
            if (dag.isCompositeVertex(src)) {
                final LoopVertex srcLoopVertex = dag.getAssignedLoopVertexOf(src);
                if (srcLoopVertex.equals(assignedLoopVertex)) {
                    // Both endpoints are in the same loop: the edge lives inside the loop's DAG.
                    assignedLoopVertex.getBuilder().connectVertices(irEdge);
                    return;
                }
                // loop -> loop: the outer edge starts at the source's loop vertex.
                externalSrc = srcLoopVertex;
            } else {
                // operator -> loop.
                externalSrc = src;
            }
            assignedLoopVertex.addDagIncomingEdge(irEdge);
            final IREdge edgeToLoop = copyEdgeWithEndpoints(irEdge, externalSrc, assignedLoopVertex);
            builder.connectVertices(edgeToLoop);
            assignedLoopVertex.mapEdgeWithLoop(edgeToLoop, irEdge);
        });
    }

    /**
     * Rolls consecutive loop vertices of the same loop into the first ("root") one.
     *
     * <p>A loop vertex is treated as a further iteration of the current root when its name contains
     * the root's name. Otherwise it becomes the new root.
     *
     * @param dag the DAG whose loops have just been grouped.
     * @return a new DAG in which rolled loop vertices are replaced by their root.
     * @throws UnsupportedOperationException if a vertex has an unknown type.
     */
    private DAG<IRVertex, IREdge> loopRolling(final DAG<IRVertex, IREdge> dag) {
        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
        // Maps each loop vertex to the root loop vertex it was rolled into (a root maps to itself).
        final Map<LoopVertex, LoopVertex> loopVerticesOfSameLoop = new HashMap<>();
        // Per root: maps each vertex seen in any iteration to its equivalent vertex in the root's DAG.
        final Map<LoopVertex, Map<IRVertex, IRVertex>> equivalentVerticesOfLoops = new HashMap<>();
        LoopVertex rootLoopVertex = null;

        for (final IRVertex irVertex : dag.getTopologicalSort()) {
            if (irVertex instanceof SourceVertex) {
                builder.addVertex(irVertex, dag);
            } else if (irVertex instanceof OperatorVertex) {
                addVertexToBuilder(builder, dag, irVertex, loopVerticesOfSameLoop);
            } else if (irVertex instanceof LoopVertex) {
                final LoopVertex loopVertex = (LoopVertex) irVertex;
                if (rootLoopVertex == null || !loopVertex.getName().contains(rootLoopVertex.getName())) {
                    // A new loop starts: this vertex becomes the root other iterations roll into.
                    rootLoopVertex = loopVertex;
                    loopVerticesOfSameLoop.putIfAbsent(rootLoopVertex, rootLoopVertex);
                    final Map<IRVertex, IRVertex> equivalents =
                        equivalentVerticesOfLoops.computeIfAbsent(rootLoopVertex, k -> new HashMap<>());
                    for (final IRVertex vertex : rootLoopVertex.getDAG().getTopologicalSort()) {
                        equivalents.putIfAbsent(vertex, vertex);
                        IdManager.saveVertexId(vertex, vertex.getId());
                    }
                    addVertexToBuilder(builder, dag, rootLoopVertex, loopVerticesOfSameLoop);
                } else {
                    rollIntoRoot(rootLoopVertex, loopVertex, loopVerticesOfSameLoop,
                        equivalentVerticesOfLoops.get(rootLoopVertex));
                }
            } else {
                throw new UnsupportedOperationException("Unknown vertex type: " + irVertex);
            }
        }
        return builder.build();
    }

    /**
     * Rolls {@code loopVertex} into {@code rootLoopVertex} as one more iteration.
     *
     * <p>It bumps the root's iteration count and records vertex equivalences between the two loop
     * DAGs. It then rebuilds the root's incoming edges (iterative vs. non-iterative) and outgoing
     * edges from those of {@code loopVertex}.
     *
     * @param rootLoopVertex         the root loop vertex being rolled into.
     * @param loopVertex             the loop vertex representing the next iteration.
     * @param loopVerticesOfSameLoop loop vertex to root loop vertex mapping; updated in place.
     * @param equivalentVertices     vertex to root-DAG-vertex mapping for this root; updated in place.
     */
    private static void rollIntoRoot(final LoopVertex rootLoopVertex,
                                     final LoopVertex loopVertex,
                                     final Map<LoopVertex, LoopVertex> loopVerticesOfSameLoop,
                                     final Map<IRVertex, IRVertex> equivalentVertices) {
        loopVerticesOfSameLoop.putIfAbsent(loopVertex, rootLoopVertex);
        rootLoopVertex.increaseMaxNumberOfIterations();

        // Zip the two loop DAGs together position by position. NOTE(review): this assumes
        // getTopologicalSort yields corresponding vertices in the same order for both iterations.
        final Iterator<IRVertex> rootVertices = rootLoopVertex.getDAG().getTopologicalSort().iterator();
        final Iterator<IRVertex> currentVertices = loopVertex.getDAG().getTopologicalSort().iterator();
        while (rootVertices.hasNext() && currentVertices.hasNext()) {
            final IRVertex currentVertex = currentVertices.next();
            final IRVertex rootVertex = rootVertices.next();
            equivalentVertices.put(currentVertex, rootVertex);
            IdManager.saveVertexId(rootVertex, currentVertex.getId());
        }

        // Rebuild the root's incoming edges from this iteration's DAG-incoming edges.
        rootLoopVertex.getNonIterativeIncomingEdges().clear();
        rootLoopVertex.getIterativeIncomingEdges().clear();
        loopVertex.getDagIncomingEdges().forEach((dstVertex, edges) -> edges.forEach(edge -> {
            final IRVertex srcVertex = edge.getSrc();
            final IRVertex equivalentDst = equivalentVertices.get(dstVertex);
            if (equivalentVertices.containsKey(srcVertex)) {
                // The source is a vertex already known from an earlier iteration: iterative edge.
                rootLoopVertex.addIterativeIncomingEdge(
                    copyEdgeWithEndpoints(edge, equivalentVertices.get(srcVertex), equivalentDst));
            } else {
                // The source lies outside the rolled loop: non-iterative edge.
                rootLoopVertex.addNonIterativeIncomingEdge(copyEdgeWithEndpoints(edge, srcVertex, equivalentDst));
            }
        }));

        // Overwrite the root's outgoing edges with those of this (later) iteration.
        rootLoopVertex.getDagOutgoingEdges().clear();
        loopVertex.getDagOutgoingEdges().forEach((srcVertex, edges) -> edges.forEach(edge -> {
            final IREdge newEdge = copyEdgeWithEndpoints(edge, equivalentVertices.get(srcVertex), edge.getDst());
            rootLoopVertex.addDagOutgoingEdge(newEdge);
            rootLoopVertex.mapEdgeWithLoop(loopVertex.getEdgeWithLoop(edge), newEdge);
        }));
    }

    /**
     * Adds {@code irVertex} to {@code builder} and connects its incoming edges. An edge coming from
     * a loop vertex that was rolled into a root is redirected to originate from that root instead.
     *
     * @param builder                the builder of the rolled DAG.
     * @param dag                    the DAG being rolled.
     * @param irVertex               the vertex to add.
     * @param loopVerticesOfSameLoop loop vertex to root loop vertex mapping built so far.
     */
    private static void addVertexToBuilder(final DAGBuilder<IRVertex, IREdge> builder,
                                           final DAG<IRVertex, IREdge> dag,
                                           final IRVertex irVertex,
                                           final Map<LoopVertex, LoopVertex> loopVerticesOfSameLoop) {
        builder.addVertex(irVertex, dag);
        dag.getIncomingEdgesOf(irVertex).forEach(edge -> {
            final IRVertex src = edge.getSrc();
            IRVertex rootSrc = src;
            if (src instanceof LoopVertex) {
                final LoopVertex root = loopVerticesOfSameLoop.get(src);
                if (root != null) {
                    rootSrc = root;
                }
            }
            if (src.equals(rootSrc)) {
                builder.connectVertices(edge);
            } else {
                builder.connectVertices(copyEdgeWithEndpoints(edge, rootSrc, irVertex));
            }
        });
    }

    /**
     * Creates a new edge {@code src -> dst}. The edge carries the communication pattern and all
     * execution properties of {@code template}.
     *
     * @throws IllegalStateException if {@code template} has no {@link CommunicationPatternProperty},
     *                               which this pass requires.
     */
    private static IREdge copyEdgeWithEndpoints(final IREdge template, final IRVertex src, final IRVertex dst) {
        final CommunicationPatternProperty.Value pattern = template
            .getPropertyValue(CommunicationPatternProperty.class)
            .orElseThrow(() -> new IllegalStateException(
                "Missing CommunicationPatternProperty on edge " + template));
        final IREdge copy = new IREdge(pattern, src, dst);
        template.copyExecutionPropertiesTo(copy);
        return copy;
    }
}
