package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.nemo.common.coder.DecoderFactory;
import org.apache.nemo.common.coder.EncoderFactory;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;

/* loaded from: input_file:org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.class */
/**
 * Compile-time reshaping passes that optimize the {@link LoopVertex} nodes of an IR DAG.
 *
 * <ul>
 *   <li>{@link LoopFusionPass} merges mutually independent loops that share a termination condition.</li>
 *   <li>{@link LoopInvariantCodeMotionPass} hoists loop-invariant vertices out of their loops.</li>
 * </ul>
 */
public final class LoopOptimizations {

    /**
     * Fuses mutually independent {@link LoopVertex} nodes whose termination conditions are equal into a
     * single merged {@link LoopVertex}, rewiring the edges that touched the fused loops to the merged vertex.
     */
    @Requires({CommunicationPatternProperty.class})
    public static final class LoopFusionPass extends ReshapingPass {

        /**
         * Default constructor.
         */
        public LoopFusionPass() {
            super(LoopFusionPass.class);
        }

        /**
         * Reshapes {@code irdag} in place so that every fusible group of loop vertices becomes one vertex.
         *
         * @param irdag the IR DAG to reshape.
         * @return the same {@code irdag} instance, after reshaping.
         */
        @Override
        public IRDAG apply(final IRDAG irdag) {
            irdag.reshapeUnsafely(dag -> {
                final List<LoopVertex> loopVertices = new ArrayList<>();
                final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
                final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
                final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
                collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);

                // Add one vertex per fusion group and remember which vertex replaces each original loop.
                final Map<LoopVertex, LoopVertex> replacements = new HashMap<>();
                for (final Set<LoopVertex> group : groupFusibleLoops(dag, loopVertices)) {
                    if (group.size() <= 1) {
                        for (final LoopVertex loop : group) {
                            builder.addVertex(loop);
                            replacements.put(loop, loop);
                        }
                    } else {
                        final LoopVertex merged = mergeLoopVertices(group);
                        builder.addVertex(merged, dag);
                        for (final LoopVertex loop : group) {
                            replacements.put(loop, merged);
                        }
                    }
                }

                // Reconnect only after every replacement vertex exists: connecting while adding (as before)
                // silently dropped edges between a fused loop and any other loop vertex.
                for (final LoopVertex loop : loopVertices) {
                    for (final IREdge edge : inEdges.getOrDefault(loop, new ArrayList<>())) {
                        reconnect(edge, replacements, builder);
                    }
                    for (final IREdge edge : outEdges.getOrDefault(loop, new ArrayList<>())) {
                        // Loop-to-loop edges are also incoming edges of their destination, handled above.
                        if (!(edge.getDst() instanceof LoopVertex)) {
                            reconnect(edge, replacements, builder);
                        }
                    }
                }
                return builder.build();
            });
            return irdag;
        }

        /**
         * Partitions loop vertices into groups that can be fused.
         *
         * <p>Each loop vertex, visited in {@code loopVertices} order, is grouped with every candidate that is
         * independent of it (see {@link #isIndependent}) and has an equal termination condition. The resulting
         * set is merged into the first existing group it overlaps, or otherwise becomes a new group.
         * Groups are kept in a {@link List} because they are mutated after insertion (a mutated element of a
         * {@code HashSet} ends up in the wrong hash bucket), and each group is a {@link LinkedHashSet} so the
         * merge order and merged name are deterministic.
         *
         * @param dag          DAG containing the loop vertices, used for reachability checks.
         * @param loopVertices loop vertices to partition.
         * @return the fusion groups; a group of size one means the loop is left unchanged.
         */
        private static List<Set<LoopVertex>> groupFusibleLoops(final DAG<IRVertex, IREdge> dag,
                                                              final List<LoopVertex> loopVertices) {
            final List<Set<LoopVertex>> fusionGroups = new ArrayList<>();
            for (final LoopVertex loopVertex : loopVertices) {
                final Set<LoopVertex> loopsToBeFused = new LinkedHashSet<>();
                loopsToBeFused.add(loopVertex);
                for (final LoopVertex candidate : loopVertices) {
                    if (isIndependent(dag, fusionGroups, candidate, loopVertex)
                        && loopVertex.terminationConditionEquals(candidate)) {
                        loopsToBeFused.add(candidate);
                    }
                }
                final Optional<Set<LoopVertex>> overlappingGroup = fusionGroups.stream()
                    .filter(group -> group.stream().anyMatch(loopsToBeFused::contains))
                    .findFirst();
                if (overlappingGroup.isPresent()) {
                    overlappingGroup.get().addAll(loopsToBeFused);
                } else {
                    fusionGroups.add(loopsToBeFused);
                }
            }
            return fusionGroups;
        }

        /**
         * Checks whether {@code candidate} can be fused with {@code loopVertex} without creating a dependency
         * inside the fused loop.
         *
         * <p>If the candidate already belongs to a fusion group, no member of that group may have a path to
         * or from {@code loopVertex} (as reported by {@code DAG#pathExistsBetween}). Otherwise only the candidate
         * itself is checked.
         *
         * @param dag          DAG used for reachability checks.
         * @param fusionGroups fusion groups formed so far.
         * @param candidate    loop vertex that might join {@code loopVertex}.
         * @param loopVertex   loop vertex currently being grouped.
         * @return true if the candidate (or its group) is independent of {@code loopVertex}.
         */
        private static boolean isIndependent(final DAG<IRVertex, IREdge> dag,
                                             final List<Set<LoopVertex>> fusionGroups,
                                             final LoopVertex candidate,
                                             final LoopVertex loopVertex) {
            final Optional<Set<LoopVertex>> candidateGroup = fusionGroups.stream()
                .filter(group -> group.contains(candidate))
                .findFirst();
            if (candidateGroup.isPresent()) {
                return candidateGroup.get().stream()
                    .noneMatch(member -> dag.pathExistsBetween(member, loopVertex));
            }
            return !dag.pathExistsBetween(candidate, loopVertex);
        }

        /**
         * Connects {@code edge} in {@code builder}, redirecting any endpoint that was replaced by a fused loop.
         *
         * <p>An edge whose endpoints are both unchanged is connected as-is. Otherwise a new edge with the same
         * communication pattern is created between the replacement endpoints, and the original edge's
         * execution properties are copied onto it.
         *
         * @param edge         edge of the original DAG.
         * @param replacements mapping from each original loop vertex to the vertex that replaces it.
         * @param builder      builder of the reshaped DAG; must already contain both resolved endpoints.
         */
        private static void reconnect(final IREdge edge,
                                      final Map<LoopVertex, LoopVertex> replacements,
                                      final DAGBuilder<IRVertex, IREdge> builder) {
            final IRVertex src = resolve(edge.getSrc(), replacements);
            final IRVertex dst = resolve(edge.getDst(), replacements);
            if (src.equals(dst)) {
                // Defensive: group members are checked to be path-independent, so no edge should join them.
                return;
            }
            if (src.equals(edge.getSrc()) && dst.equals(edge.getDst())) {
                builder.connectVertices(edge);
                return;
            }
            final CommunicationPatternProperty.Value pattern =
                edge.getPropertyValue(CommunicationPatternProperty.class).get();
            final IREdge newEdge = new IREdge(pattern, src, dst);
            edge.copyExecutionPropertiesTo(newEdge);
            builder.connectVertices(newEdge);
        }

        /**
         * Resolves the vertex that stands for {@code vertex} in the reshaped DAG.
         *
         * @param vertex       vertex of the original DAG.
         * @param replacements mapping from each original loop vertex to the vertex that replaces it.
         * @return the replacement if one exists, otherwise {@code vertex} itself.
         */
        private static IRVertex resolve(final IRVertex vertex, final Map<LoopVertex, LoopVertex> replacements) {
            final LoopVertex replacement = replacements.get(vertex);
            return replacement == null ? vertex : replacement;
        }

        /**
         * Builds one loop vertex whose body is the union of the bodies of {@code loopVertices}.
         *
         * <p>The merged name is the members' names joined by {@code "+"}, in set iteration order. Each member's
         * internal vertices and edges are copied in topological order. Its dag-incoming, iterative-incoming,
         * non-iterative-incoming and dag-outgoing edge registrations are also copied.
         * NOTE(review): only the name is passed to the new vertex; no termination condition or iteration count
         * is copied from the fused loops here. Confirm that the LoopVertex defaults are acceptable.
         *
         * @param loopVertices loops to merge; expected to share an equal termination condition.
         * @return the newly created merged loop vertex.
         */
        private LoopVertex mergeLoopVertices(final Set<LoopVertex> loopVertices) {
            final String mergedName = loopVertices.stream()
                .map(LoopVertex::getName)
                .collect(Collectors.joining("+"));
            final LoopVertex mergedLoopVertex = new LoopVertex(mergedName);
            for (final LoopVertex loopVertex : loopVertices) {
                final DAG<IRVertex, IREdge> body = loopVertex.getDAG();
                body.topologicalDo(vertex -> {
                    mergedLoopVertex.getBuilder().addVertex(vertex);
                    body.getIncomingEdgesOf(vertex).forEach(mergedLoopVertex.getBuilder()::connectVertices);
                });
                loopVertex.getDagIncomingEdges().forEach((vertex, edges) ->
                    edges.forEach(mergedLoopVertex::addDagIncomingEdge));
                loopVertex.getIterativeIncomingEdges().forEach((vertex, edges) ->
                    edges.forEach(mergedLoopVertex::addIterativeIncomingEdge));
                loopVertex.getNonIterativeIncomingEdges().forEach((vertex, edges) ->
                    edges.forEach(mergedLoopVertex::addNonIterativeIncomingEdge));
                loopVertex.getDagOutgoingEdges().forEach((vertex, edges) ->
                    edges.forEach(mergedLoopVertex::addDagOutgoingEdge));
            }
            return mergedLoopVertex;
        }
    }

    /**
     * Hoists loop-invariant vertices out of {@link LoopVertex} nodes. These are vertices with no incoming edge
     * inside the loop body and no iterative incoming edge. The pass repeats until a round hoists nothing.
     */
    @Requires({CommunicationPatternProperty.class})
    public static final class LoopInvariantCodeMotionPass extends ReshapingPass {

        /**
         * Default constructor.
         */
        public LoopInvariantCodeMotionPass() {
            super(LoopInvariantCodeMotionPass.class);
        }

        /**
         * Reshapes {@code irdag} in place by repeatedly hoisting loop-invariant vertices.
         *
         * @param irdag the IR DAG to reshape.
         * @return the same {@code irdag} instance, after reshaping.
         */
        @Override
        public IRDAG apply(final IRDAG irdag) {
            irdag.reshapeUnsafely(this::recursivelyOptimize);
            return irdag;
        }

        /**
         * Performs one round of hoisting on {@code dag} and recurses while the vertex count keeps changing.
         *
         * @param dag DAG to optimize.
         * @return the optimized DAG.
         */
        DAG<IRVertex, IREdge> recursivelyOptimize(final DAG<IRVertex, IREdge> dag) {
            final List<LoopVertex> loopVertices = new ArrayList<>();
            final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
            final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
            final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
            collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);

            for (final LoopVertex loopVertex : loopVertices) {
                // Snapshot the candidates first: hoisting mutates the loop vertex's edge maps.
                final List<Map.Entry<IRVertex, Set<IREdge>>> candidates =
                    loopVertex.getNonIterativeIncomingEdges().entrySet().stream()
                        .filter(entry -> isLoopInvariant(loopVertex, entry.getKey()))
                        .collect(Collectors.toList());
                candidates.forEach(candidate -> hoistInvariantVertex(loopVertex, candidate, inEdges, builder));
            }

            // Add every loop vertex before connecting edges. Otherwise, a loop-to-loop edge would be connected
            // before its destination exists, and it would also be connected twice (once per endpoint).
            for (final LoopVertex loopVertex : loopVertices) {
                builder.addVertex(loopVertex);
            }
            for (final LoopVertex loopVertex : loopVertices) {
                inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
                // Loop-to-loop edges are also incoming edges of their destination and were connected above.
                outEdges.getOrDefault(loopVertex, new ArrayList<>()).stream()
                    .filter(edge -> !(edge.getDst() instanceof LoopVertex))
                    .forEach(builder::connectVertices);
            }

            final DAG<IRVertex, IREdge> optimized = builder.build();
            // Every hoist adds a vertex to the enclosing DAG; stop once a round adds none.
            return dag.getVertices().size() == optimized.getVertices().size()
                ? optimized
                : recursivelyOptimize(optimized);
        }

        /**
         * Tells whether {@code vertex} inside {@code loopVertex} has neither an incoming edge in the loop body
         * nor an iterative incoming edge.
         *
         * @param loopVertex loop containing {@code vertex}.
         * @param vertex     vertex of the loop body.
         * @return true if the vertex can be hoisted.
         */
        private static boolean isLoopInvariant(final LoopVertex loopVertex, final IRVertex vertex) {
            return loopVertex.getDAG().getIncomingEdgesOf(vertex).isEmpty()
                && loopVertex.getIterativeIncomingEdges().getOrDefault(vertex, new HashSet<>()).isEmpty();
        }

        /**
         * Moves one loop-invariant vertex out of {@code loopVertex} into the enclosing DAG being built.
         *
         * @param loopVertex loop that currently contains the vertex.
         * @param candidate  the vertex to hoist, paired with its external incoming edges.
         * @param inEdges    incoming edges of each loop vertex in the enclosing DAG; updated in place.
         * @param builder    builder of the enclosing DAG.
         */
        private static void hoistInvariantVertex(final LoopVertex loopVertex,
                                                 final Map.Entry<IRVertex, Set<IREdge>> candidate,
                                                 final Map<LoopVertex, List<IREdge>> inEdges,
                                                 final DAGBuilder<IRVertex, IREdge> builder) {
            final IRVertex invariant = candidate.getKey();
            final Set<IREdge> externalInputs = candidate.getValue();

            // Place the vertex in the enclosing DAG along with its external inputs.
            builder.addVertex(invariant);
            externalInputs.forEach(builder::connectVertices);

            // The vertex's consumers inside the loop now receive its output from outside the loop.
            loopVertex.getDAG().getOutgoingEdgesOf(invariant).forEach(loopVertex::addDagIncomingEdge);
            loopVertex.getDAG().getOutgoingEdgesOf(invariant).forEach(loopVertex::addNonIterativeIncomingEdge);

            // Redirect the loop's incoming edges that share a source with the hoisted vertex's inputs, so they
            // originate from the hoisted vertex instead.
            final List<IREdge> edgesToRemove = new ArrayList<>();
            final List<IREdge> edgesToAdd = new ArrayList<>();
            inEdges.getOrDefault(loopVertex, new ArrayList<>()).stream()
                .filter(edge -> externalInputs.stream()
                    .map(IREdge::getSrc)
                    .anyMatch(inputSrc -> inputSrc.equals(edge.getSrc())))
                .forEach(edge -> {
                    edgesToRemove.add(edge);
                    final CommunicationPatternProperty.Value pattern =
                        edge.getPropertyValue(CommunicationPatternProperty.class).get();
                    final IREdge newEdge = new IREdge(pattern, invariant, edge.getDst());
                    final EncoderFactory encoderFactory = edge.getPropertyValue(EncoderProperty.class).get();
                    final DecoderFactory decoderFactory = edge.getPropertyValue(DecoderProperty.class).get();
                    newEdge.setProperty(EncoderProperty.of(encoderFactory));
                    newEdge.setProperty(DecoderProperty.of(decoderFactory));
                    edgesToAdd.add(newEdge);
                });
            final List<IREdge> loopInEdges = inEdges.getOrDefault(loopVertex, new ArrayList<>());
            loopInEdges.removeAll(edgesToRemove);
            loopInEdges.addAll(edgesToAdd);

            // Remove every trace of the vertex from the loop body.
            loopVertex.getBuilder().removeVertex(invariant);
            loopVertex.getDagIncomingEdges().remove(invariant);
            loopVertex.getNonIterativeIncomingEdges().remove(invariant);
        }
    }

    /**
     * Utility class; not instantiable.
     */
    private LoopOptimizations() {
    }

    /**
     * @return a new {@link LoopFusionPass}.
     */
    public static LoopFusionPass getLoopFusionPass() {
        return new LoopFusionPass();
    }

    /**
     * @return a new {@link LoopInvariantCodeMotionPass}.
     */
    public static LoopInvariantCodeMotionPass getLoopInvariantCodeMotionPass() {
        return new LoopInvariantCodeMotionPass();
    }

    /**
     * Splits {@code dag} into its top-level loop vertices and everything else.
     *
     * <p>Every non-loop vertex is added to {@code builder} (via {@code addVertex(vertex, dag)}), together with
     * the edges it receives from other non-loop vertices. Loop vertices are not added. Each one is recorded in
     * the visited topological order, along with the edges touching it, so callers can rewrite loops before
     * adding them to {@code builder}.
     *
     * @param dag          DAG to scan.
     * @param loopVertices output list receiving every top-level {@link LoopVertex}, in topological order.
     * @param inEdges      output map from a loop vertex to all of its incoming edges.
     * @param outEdges     output map from a loop vertex to all of its outgoing edges.
     * @param builder      builder receiving the non-loop part of {@code dag}.
     */
    public static void collectLoopVertices(final DAG<IRVertex, IREdge> dag,
                                           final List<LoopVertex> loopVertices,
                                           final Map<LoopVertex, List<IREdge>> inEdges,
                                           final Map<LoopVertex, List<IREdge>> outEdges,
                                           final DAGBuilder<IRVertex, IREdge> builder) {
        dag.topologicalDo(vertex -> {
            if (vertex instanceof LoopVertex) {
                final LoopVertex loopVertex = (LoopVertex) vertex;
                loopVertices.add(loopVertex);
                for (final IREdge edge : dag.getIncomingEdgesOf(loopVertex)) {
                    inEdges.computeIfAbsent(loopVertex, key -> new ArrayList<>()).add(edge);
                    if (edge.getSrc() instanceof LoopVertex) {
                        outEdges.computeIfAbsent((LoopVertex) edge.getSrc(), key -> new ArrayList<>()).add(edge);
                    }
                }
            } else {
                builder.addVertex(vertex, dag);
                for (final IREdge edge : dag.getIncomingEdgesOf(vertex)) {
                    if (edge.getSrc() instanceof LoopVertex) {
                        outEdges.computeIfAbsent((LoopVertex) edge.getSrc(), key -> new ArrayList<>()).add(edge);
                    } else {
                        builder.connectVertices(edge);
                    }
                }
            }
        });
    }
}
