package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.nemo.common.coder.DecoderFactory;
import org.apache.nemo.common.coder.EncoderFactory;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;

/**
 * Reshaping pass that eliminates common subexpressions in an {@link IRDAG}.
 *
 * <p>Two {@link OperatorVertex}es are treated as equivalent when their transforms are equal (per the
 * transform's {@code equals}/{@code hashCode}) and the source vertices of their incoming edges form the
 * same set. Equivalent vertices are merged: the first one in topological visiting order is kept, and the
 * outgoing edges of the others are re-created with the kept vertex as their source. Non-operator vertices
 * are added to the reshaped DAG as they are.
 */
@Requires({CommunicationPatternProperty.class})
public final class CommonSubexpressionEliminationPass extends ReshapingPass {

    /**
     * Creates the pass.
     */
    public CommonSubexpressionEliminationPass() {
        super(CommonSubexpressionEliminationPass.class);
    }

    /**
     * Merges equivalent operator vertices of {@code irdag}.
     *
     * @param irdag the DAG to optimize; it is rebuilt in place through {@link IRDAG#reshapeUnsafely}.
     * @return the same {@code irdag} instance, after reshaping.
     */
    @Override
    public IRDAG apply(final IRDAG irdag) {
        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
        // Operator vertices grouped by their transform; each list is in topological visiting order.
        // Keyed by Object so this class needs no direct dependency on the Transform type; lookups use the
        // transform's own equals/hashCode.
        final Map<Object, List<OperatorVertex>> verticesByTransform = new HashMap<>();
        // Incoming edges of each operator vertex (keyed by the edge's destination).
        final Map<OperatorVertex, Set<IREdge>> inEdges = new HashMap<>();
        // Outgoing edges of each operator vertex (keyed by the edge's source).
        final Map<OperatorVertex, Set<IREdge>> outEdges = new HashMap<>();

        irdag.reshapeUnsafely(dag -> {
            // 1) Add non-operator vertices right away; index operator vertices and every edge that touches one.
            dag.topologicalDo(vertex -> {
                if (vertex instanceof OperatorVertex) {
                    final OperatorVertex operatorVertex = (OperatorVertex) vertex;
                    verticesByTransform
                        .computeIfAbsent(operatorVertex.getTransform(), k -> new ArrayList<>())
                        .add(operatorVertex);
                    for (final IREdge edge : dag.getIncomingEdgesOf(operatorVertex)) {
                        inEdges.computeIfAbsent(operatorVertex, k -> new HashSet<>()).add(edge);
                        if (edge.getSrc() instanceof OperatorVertex) {
                            outEdges.computeIfAbsent((OperatorVertex) edge.getSrc(), k -> new HashSet<>()).add(edge);
                        }
                    }
                } else {
                    builder.addVertex(vertex, dag);
                    for (final IREdge edge : dag.getIncomingEdgesOf(vertex)) {
                        if (edge.getSrc() instanceof OperatorVertex) {
                            // Connected later, once we know whether the source survives the merge.
                            outEdges.computeIfAbsent((OperatorVertex) edge.getSrc(), k -> new HashSet<>()).add(edge);
                        } else {
                            builder.connectVertices(edge);
                        }
                    }
                }
            });

            // 2) Within each transform group, bucket vertices by their exact set of source vertices and merge
            //    each bucket. Collectors.toSet() yields a HashSet, so Set.equals is the "mutually containsAll"
            //    test and the sets can serve directly as map keys.
            //    NOTE(review): only the source vertices are compared, not the incoming edges' properties or
            //    multiplicity — confirm that is sufficient for equivalence.
            for (final List<OperatorVertex> sameTransform : verticesByTransform.values()) {
                final Map<Set<IRVertex>, List<OperatorVertex>> bySources = new HashMap<>();
                for (final OperatorVertex operatorVertex : sameTransform) {
                    final Set<IRVertex> sources = dag.getIncomingEdgesOf(operatorVertex).stream()
                        .map(IREdge::getSrc)
                        .collect(Collectors.toSet());
                    bySources.computeIfAbsent(sources, k -> new ArrayList<>()).add(operatorVertex);
                }
                for (final List<OperatorVertex> group : bySources.values()) {
                    mergeAndAddToBuilder(group, builder, dag, inEdges, outEdges);
                }
            }

            // 3) Connect recorded edges whose endpoints both ended up in the builder. Edges touching a vertex
            //    that was merged away are skipped; their replacements live in the kept vertex's outEdges.
            for (final List<OperatorVertex> sameTransform : verticesByTransform.values()) {
                for (final OperatorVertex operatorVertex : sameTransform) {
                    if (!builder.contains(operatorVertex)) {
                        continue;
                    }
                    final Set<IREdge> incoming = inEdges.get(operatorVertex);
                    if (incoming != null) {
                        for (final IREdge edge : incoming) {
                            if (builder.contains(edge.getSrc())) {
                                builder.connectVertices(edge);
                            }
                        }
                    }
                    final Set<IREdge> outgoing = outEdges.get(operatorVertex);
                    if (outgoing != null) {
                        for (final IREdge edge : outgoing) {
                            if (builder.contains(edge.getDst())) {
                                builder.connectVertices(edge);
                            }
                        }
                    }
                }
            }

            return builder.build();
        });
        return irdag;
    }

    /**
     * Merges operator vertices that share a transform and an identical set of source vertices, and adds the
     * surviving vertices to {@code builder}.
     *
     * <p>The first vertex of the list is kept. Every other vertex {@code v} for which
     * {@code dag.pathExistsBetween(kept, v)} is false is folded into the kept vertex. Folding means its
     * outgoing edges are re-created with the kept vertex as source and moved to the kept vertex's entry in
     * {@code outEdges}. Vertices that fail the path check are processed again as a new group, until none
     * remain.
     *
     * @param vertices candidate vertices, expected in topological order; may be empty.
     * @param builder  builder that receives the surviving vertices.
     * @param dag      the original DAG, used for the path checks.
     * @param inEdges  incoming edges per operator vertex; not read or modified here. Merged vertices share
     *                 their sources, so their incoming edges need no rewriting.
     * @param outEdges outgoing edges per operator vertex; updated in place to reflect the merge.
     */
    public static void mergeAndAddToBuilder(final List<OperatorVertex> vertices,
                                            final DAGBuilder<IRVertex, IREdge> builder,
                                            final DAG<IRVertex, IREdge> dag,
                                            final Map<OperatorVertex, Set<IREdge>> inEdges,
                                            final Map<OperatorVertex, Set<IREdge>> outEdges) {
        List<OperatorVertex> remaining = vertices;
        while (!remaining.isEmpty()) {
            final OperatorVertex kept = remaining.get(0);
            final List<OperatorVertex> deferred = new ArrayList<>();
            builder.addVertex(kept);
            for (final OperatorVertex candidate : remaining) {
                if (candidate.equals(kept)) {
                    continue;
                }
                if (dag.pathExistsBetween(kept, candidate)) {
                    // Depends on the kept vertex; it cannot be folded into it in this round.
                    deferred.add(candidate);
                } else {
                    redirectOutgoingEdges(candidate, kept, outEdges);
                }
            }
            remaining = deferred;
        }
    }

    /**
     * Re-creates each outgoing edge of {@code merged} with {@code kept} as its source. The new edges are
     * moved from {@code merged}'s entry to {@code kept}'s entry in {@code outEdges}.
     */
    private static void redirectOutgoingEdges(final OperatorVertex merged,
                                              final OperatorVertex kept,
                                              final Map<OperatorVertex, Set<IREdge>> outEdges) {
        // Detach the old set first, so no set is modified while it is being iterated. A missing entry
        // (a vertex with no outgoing edges) is simply nothing to move.
        final Set<IREdge> oldEdges = outEdges.remove(merged);
        if (oldEdges == null || oldEdges.isEmpty()) {
            return;
        }
        final Set<IREdge> keptEdges = outEdges.computeIfAbsent(kept, k -> new HashSet<>());
        for (final IREdge edge : oldEdges) {
            keptEdges.add(copyWithSource(edge, kept));
        }
    }

    /**
     * Builds a new edge from {@code newSrc} to {@code edge}'s destination. It carries over the
     * communication pattern, plus the encoder/decoder properties when present.
     */
    private static IREdge copyWithSource(final IREdge edge, final OperatorVertex newSrc) {
        // Guaranteed by @Requires(CommunicationPatternProperty.class); fail loudly if that contract is broken.
        final CommunicationPatternProperty.Value pattern = edge.getPropertyValue(CommunicationPatternProperty.class)
            .orElseThrow(() -> new IllegalStateException(
                "CommunicationPatternProperty is required but missing on edge " + edge));
        final IREdge copy = new IREdge(pattern, newSrc, edge.getDst());
        final Optional<EncoderFactory> encoder = edge.getPropertyValue(EncoderProperty.class);
        if (encoder.isPresent()) {
            copy.setProperty(EncoderProperty.of(encoder.get()));
        }
        final Optional<DecoderFactory> decoder = edge.getPropertyValue(DecoderProperty.class);
        if (decoder.isPresent()) {
            copy.setProperty(DecoderProperty.of(decoder.get()));
        }
        return copy;
    }
}
