package org.apache.nemo.runtime.common.plan;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.jcip.annotations.ThreadSafe;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.reef.annotations.audience.DriverSide;

/**
 * Partitions the vertices of an {@link IRDAG} into stages, returning a map from vertex to stage index.
 *
 * <p>Vertices are visited in the DAG's topological order. A vertex that has no stage yet when it is visited
 * starts a new stage. For every outgoing edge whose destination has no stage yet, the destination either
 * joins the source's stage (when the edge is mergeable, see {@link #testMergeability}) or gets a fresh index.
 *
 * <p>Stage indices come from a counter owned by this instance and never reset, so indices keep increasing
 * across successive {@link #apply} calls on the same partitioner. Each allocation is atomic, so concurrent
 * {@code apply} calls never hand out the same index. Under such concurrency, however, the indices produced
 * by a single call are not guaranteed to be contiguous.
 */
@DriverSide
@ThreadSafe
public final class StagePartitioner implements Function<IRDAG, Map<IRVertex, Integer>> {
    /** Execution property classes that are ignored when comparing two vertices' properties. */
    private final Set<Class<? extends VertexExecutionProperty>> ignoredPropertyKeys = ConcurrentHashMap.newKeySet();
    /** Next stage index to hand out; MutableInt is not thread-safe, so it is only touched under its own monitor. */
    private final MutableInt nextStageIndex = new MutableInt(0);

    /**
     * Registers an execution property class to ignore when deciding whether two vertices may share a stage.
     *
     * @param cls the vertex execution property class to ignore
     */
    public void addIgnoredPropertyKey(final Class<? extends VertexExecutionProperty> cls) {
        this.ignoredPropertyKeys.add(cls);
    }

    /**
     * Assigns a stage index to every vertex visited by {@code irdag.topologicalDo}.
     *
     * @param irdag the DAG to partition
     * @return a new mutable map from each visited vertex to its stage index
     * @throws IllegalStateException if an edge considered for merging has no {@link CommunicationPatternProperty}
     */
    @Override
    public Map<IRVertex, Integer> apply(final IRDAG irdag) {
        final Map<IRVertex, Integer> vertexToStage = new HashMap<>();
        irdag.topologicalDo(vertex -> {
            // A vertex not yet assigned (not the destination of any already-visited vertex) opens a new stage.
            final int stageIndex = vertexToStage.computeIfAbsent(vertex, unassigned -> allocateStageIndex());
            for (final IREdge edge : irdag.getOutgoingEdgesOf(vertex)) {
                final IRVertex dst = edge.getDst();
                if (vertexToStage.containsKey(dst)) {
                    continue; // the first visited predecessor decides the destination's stage
                }
                vertexToStage.put(dst, testMergeability(edge, irdag) ? stageIndex : allocateStageIndex());
            }
        });
        return vertexToStage;
    }

    /**
     * Atomically takes the next stage index from the shared counter.
     * The lock is held only for the read-and-increment, never while traversing the DAG.
     *
     * @return a stage index not previously returned by this partitioner
     */
    private int allocateStageIndex() {
        synchronized (nextStageIndex) {
            final int index = nextStageIndex.intValue();
            nextStageIndex.increment();
            return index;
        }
    }

    /**
     * Decides whether the destination of {@code edge} may join the stage of its source.
     * That requires all three of the following:
     * <ul>
     *   <li>the destination has at most one incoming edge;</li>
     *   <li>the edge is {@code ONE_TO_ONE};</li>
     *   <li>both endpoints have equal stage properties (see {@link #getStageProperties}).</li>
     * </ul>
     *
     * @param edge  the edge being considered
     * @param irdag the DAG containing the edge
     * @return {@code true} if the two endpoints may share a stage
     * @throws IllegalStateException if the in-degree check passes and the edge has no communication pattern
     */
    private boolean testMergeability(final IREdge edge, final IRDAG irdag) {
        if (irdag.getIncomingEdgesOf(edge.getDst()).size() > 1) {
            return false;
        }
        final CommunicationPatternProperty.Value pattern = edge.getPropertyValue(CommunicationPatternProperty.class)
            .orElseThrow(() -> new IllegalStateException("Edge has no CommunicationPatternProperty: " + edge));
        if (pattern != CommunicationPatternProperty.Value.ONE_TO_ONE) {
            return false;
        }
        return getStageProperties(edge.getSrc()).equals(getStageProperties(edge.getDst()));
    }

    /**
     * Returns the vertex's execution properties, excluding any whose class was registered through
     * {@link #addIgnoredPropertyKey}.
     *
     * @param iRVertex the vertex to inspect
     * @return a new set with the vertex's non-ignored execution properties
     */
    public Set<VertexExecutionProperty> getStageProperties(final IRVertex iRVertex) {
        return iRVertex.getExecutionProperties().stream()
            .filter(property -> !ignoredPropertyKeys.contains(property.getClass()))
            .collect(Collectors.toSet());
    }
}
