package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.nemo.common.KeyExtractor;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DataStoreProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reshaping pass that prepares SHUFFLE edges for sampling-based skew handling.
 *
 * <p>Vertices are visited in topological order. For every incoming SHUFFLE edge of a vertex, the pass:
 * <ol>
 *   <li>collects the "partition" of the shuffle writer: the writer itself plus the ancestors reachable
 *       backwards over ONE_TO_ONE, MEMORY_STORE edges through single-input vertices;</li>
 *   <li>inserts one {@link SamplingVertex} (rate {@code SAMPLE_RATE}) per partition vertex, passing the
 *       partition's source vertices to {@code IRDAG#insert};</li>
 *   <li>inserts a {@link MessageGeneratorVertex} / {@link MessageAggregatorVertex} pair. The pair is fed by
 *       the sampled clone of the shuffle edge, and the original shuffle edge is passed as the edge to
 *       optimize (presumably so a later runtime pass can rebalance skewed keys — confirm against
 *       {@code SkewHandlingUtil}).</li>
 * </ol>
 */
@Requires({CommunicationPatternProperty.class})
public final class SamplingSkewReshapingPass extends ReshapingPass {
    private static final Logger LOG = LoggerFactory.getLogger(SamplingSkewReshapingPass.class.getName());

    /** Rate handed to every inserted {@link SamplingVertex} (presumably the fraction of data sampled). */
    private static final float SAMPLE_RATE = 0.1f;

    /**
     * Default constructor; registers this class with the {@link ReshapingPass} base constructor.
     */
    public SamplingSkewReshapingPass() {
        super(SamplingSkewReshapingPass.class);
    }

    /**
     * Inserts sampling vertices and skew-statistics message vertices in front of every SHUFFLE edge.
     *
     * @param irdag the DAG to reshape; it is modified in place.
     * @return the same {@code irdag} instance after reshaping.
     * @throws IllegalStateException if an edge lacks its CommunicationPatternProperty, or if a SHUFFLE edge
     *         lacks a KeyExtractorProperty. The KeyExtractorProperty check happens before the DAG is
     *         modified for that edge.
     */
    @Override
    public IRDAG apply(final IRDAG irdag) {
        irdag.topologicalDo(vertex -> {
            for (final IREdge edge : irdag.getIncomingEdgesOf(vertex)) {
                if (!CommunicationPatternProperty.Value.SHUFFLE.equals(
                        edge.getPropertyValue(CommunicationPatternProperty.class)
                            .orElseThrow(IllegalStateException::new))) {
                    continue;
                }

                // Compute the partition that ends at the shuffle writer, and the partition's entry points.
                final IRVertex shuffleWriter = edge.getSrc();
                final Set<IRVertex> partition = recursivelyBuildPartition(shuffleWriter, irdag);
                final Set<IRVertex> partitionSources = findPartitionSources(partition, irdag);

                if (allOutgoingEdgesStayInside(partition, irdag)) {
                    // Skips the remaining incoming edges of this vertex.
                    // NOTE(review): the shuffle edge itself leaves the partition (its destination is a
                    // descendant of the shuffle writer), so this appears never to hold — confirm the
                    // intended condition.
                    return;
                }

                // Resolve the key extractor before touching the DAG. A missing property then cannot
                // leave sampling vertices inserted without their message vertices.
                final KeyExtractor keyExtractor = edge.getPropertyValue(KeyExtractorProperty.class)
                    .orElseThrow(() -> new IllegalStateException(
                        "Shuffle edge into vertex " + vertex.getId() + " has no KeyExtractorProperty"));

                // Insert one sampling vertex per partition vertex.
                final Set<SamplingVertex> samplingVertices = partition.stream()
                    .map(member -> new SamplingVertex(member, SAMPLE_RATE))
                    .collect(Collectors.toSet());
                irdag.insert(samplingVertices, partitionSources);

                // The sampled clone of the shuffle edge feeds the message vertices.
                // The partition always contains the shuffle writer, so its sampler must exist.
                final SamplingVertex shuffleWriterSampler = samplingVertices.stream()
                    .filter(sampler -> sampler.getOriginalVertexId().equals(shuffleWriter.getId()))
                    .findFirst()
                    .orElseThrow(IllegalStateException::new);
                final IREdge sampledShuffleEdge = shuffleWriterSampler.getCloneOfOriginalEdge(edge);

                irdag.insert(
                    new MessageGeneratorVertex(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
                    new MessageAggregatorVertex(HashMap::new, SkewHandlingUtil.getMessageAggregator()),
                    SkewHandlingUtil.getEncoder(edge),
                    SkewHandlingUtil.getDecoder(edge),
                    new HashSet<>(Arrays.asList(sampledShuffleEdge)),
                    new HashSet<>(Arrays.asList(edge)));
            }
        });
        return irdag;
    }

    /**
     * Returns the vertices of {@code partition} that have no incoming edge from another partition vertex.
     *
     * @param partition the partition to inspect.
     * @param irdag the DAG the partition belongs to.
     * @return a new set holding the partition's source vertices.
     */
    private static Set<IRVertex> findPartitionSources(final Set<IRVertex> partition, final IRDAG irdag) {
        return partition.stream()
            .filter(member -> irdag.getIncomingEdgesOf(member).stream()
                .map(IREdge::getSrc)
                .noneMatch(partition::contains))
            .collect(Collectors.toSet());
    }

    /**
     * Tells whether every outgoing edge of every partition vertex points back into the partition.
     *
     * @param partition the partition to inspect.
     * @param irdag the DAG the partition belongs to.
     * @return {@code true} if no outgoing edge leaves the partition.
     */
    private static boolean allOutgoingEdgesStayInside(final Set<IRVertex> partition, final IRDAG irdag) {
        return partition.stream()
            .flatMap(member -> irdag.getOutgoingEdgesOf(member).stream())
            .map(IREdge::getDst)
            .allMatch(partition::contains);
    }

    /**
     * Collects {@code vertex} plus every ancestor reachable by walking backwards over ONE_TO_ONE,
     * MEMORY_STORE edges. The walk only continues out of vertices that have exactly one incoming edge.
     *
     * @param vertex the vertex to start from; it is always part of the result.
     * @param irdag the DAG to walk.
     * @return a new mutable set containing the partition.
     * @throws IllegalStateException if a visited incoming edge lacks its CommunicationPatternProperty,
     *         or if a ONE_TO_ONE edge lacks its DataStoreProperty.
     */
    private static Set<IRVertex> recursivelyBuildPartition(final IRVertex vertex, final IRDAG irdag) {
        final Set<IRVertex> partition = new HashSet<>();
        partition.add(vertex);
        // Computed once per vertex. It stays the last operand of the condition below, so the property
        // lookups (and the exceptions they may throw) are evaluated exactly as before.
        final boolean hasSingleInput = irdag.getIncomingEdgesOf(vertex).size() == 1;
        for (final IREdge inEdge : irdag.getIncomingEdgesOf(vertex)) {
            if (CommunicationPatternProperty.Value.ONE_TO_ONE.equals(
                    inEdge.getPropertyValue(CommunicationPatternProperty.class)
                        .orElseThrow(IllegalStateException::new))
                && DataStoreProperty.Value.MEMORY_STORE.equals(
                    inEdge.getPropertyValue(DataStoreProperty.class)
                        .orElseThrow(IllegalStateException::new))
                && hasSingleInput) {
                partition.addAll(recursivelyBuildPartition(inEdge.getSrc(), irdag));
            }
        }
        return partition;
    }
}
