package org.apache.nemo.compiler.frontend.beam;

import java.util.HashMap;
import java.util.Map;
import java.util.Stack;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.MapCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PCollectionViews;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyDecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyEncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.transform.Transform;
import org.apache.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
import org.apache.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
import org.apache.nemo.compiler.frontend.beam.coder.SideInputCoder;
import org.apache.nemo.compiler.frontend.beam.transform.CreateViewTransform;
import org.apache.nemo.compiler.frontend.beam.transform.DoFnTransform;
import org.apache.nemo.compiler.frontend.beam.transform.FlattenTransform;
import org.apache.nemo.compiler.frontend.beam.transform.GBKTransform;
import org.apache.nemo.compiler.frontend.beam.transform.GroupByKeyTransform;
import org.apache.nemo.compiler.frontend.beam.transform.LoopCompositeTransform;
import org.apache.nemo.compiler.frontend.beam.transform.SideInputTransform;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.class */
/**
 * Translation state shared while converting a Beam {@link Pipeline} into a Nemo IR DAG.
 *
 * <p>For every {@link PValue} seen so far it remembers the IR vertex and the Beam hierarchy node
 * that produced it, plus the {@link TupleTag} of values emitted as additional (non-main) outputs.
 * It also keeps the stack of {@link LoopVertex} instances enclosing the composite transform that
 * is currently being visited. All vertices and edges are accumulated in a {@link DAGBuilder}.
 *
 * <p>The maps are plain, unsynchronized {@link HashMap}s. Presumably one instance is used by a
 * single translation pass on one thread (NOTE(review): confirm with the caller).
 */
public final class PipelineTranslationContext {
    /** Nested DoFn of Beam's CoGroupByKey whose output must be shuffled (looked up reflectively). */
    private static final String CO_GBK_CONSTRUCT_UNION_TABLE_FN =
        "org.apache.beam.sdk.transforms.join.CoGroupByKey$ConstructUnionTableFn";

    private final PipelineOptions pipelineOptions;
    private final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
    private final Map<PValue, TransformHierarchy.Node> pValueToProducerBeamNode = new HashMap<>();
    private final Map<PValue, IRVertex> pValueToProducerVertex = new HashMap<>();
    private final Map<PValue, TupleTag<?>> pValueToTag = new HashMap<>();
    private final Stack<LoopVertex> loopVertexStack = new Stack<>();
    private final Pipeline pipeline;

    /**
     * Creates an empty translation context.
     *
     * @param pipeline        the Beam pipeline being translated
     * @param pipelineOptions the options of that pipeline
     */
    public PipelineTranslationContext(Pipeline pipeline, PipelineOptions pipelineOptions) {
        this.pipeline = pipeline;
        this.pipelineOptions = pipelineOptions;
    }

    /**
     * Called when the traversal enters a composite Beam transform. For a
     * {@link LoopCompositeTransform}, a new {@link LoopVertex} named after the node is pushed onto
     * the loop stack.
     *
     * @param node the composite Beam node being entered
     */
    public void enterCompositeTransform(TransformHierarchy.Node node) {
        if (node.getTransform() instanceof LoopCompositeTransform) {
            // NOTE(review): a LoopVertex is added to the builder (with the current stack) and
            // removed again right away, and then a *different* LoopVertex instance with the same
            // name is pushed. This looks like it relies on DAGBuilder's loop bookkeeping side
            // effects. Behavior is kept as-is; confirm before simplifying.
            LoopVertex loopVertex = new LoopVertex(node.getFullName());
            this.builder.addVertex(loopVertex, this.loopVertexStack);
            this.builder.removeVertex(loopVertex);
            this.loopVertexStack.push(new LoopVertex(node.getFullName()));
        }
    }

    /**
     * Called when the traversal leaves a composite Beam transform. It pops the loop stack when
     * leaving a {@link LoopCompositeTransform}.
     *
     * @param node the composite Beam node being left
     */
    public void leaveCompositeTransform(TransformHierarchy.Node node) {
        if (node.getTransform() instanceof LoopCompositeTransform) {
            this.loopVertexStack.pop();
        }
    }

    /**
     * Adds a vertex to the DAG being built, passing along the current stack of enclosing loop
     * vertices.
     *
     * @param iRVertex the vertex to add
     */
    public void addVertex(IRVertex iRVertex) {
        this.builder.addVertex(iRVertex, this.loopVertexStack);
    }

    /**
     * Wires every side input of {@code iRVertex} into the DAG.
     *
     * <p>For each (index, view) entry this method:
     * <ol>
     *   <li>adds a {@link SideInputTransform} vertex;</li>
     *   <li>connects it ONE_TO_ONE from the view's producer vertex;</li>
     *   <li>connects it BROADCAST to {@code iRVertex}, using {@link SideInputCoder} so the index
     *       travels with the value.</li>
     * </ol>
     * Both the view producer and the side-input vertex are pinned to parallelism 1.
     *
     * @param iRVertex the vertex consuming the side inputs
     * @param map      side-input index to the view read at that index
     * @throws IllegalStateException if no vertex has been registered as producing a view
     */
    public void addSideInputEdges(IRVertex iRVertex, Map<Integer, PCollectionView<?>> map) {
        for (Map.Entry<Integer, PCollectionView<?>> entry : map.entrySet()) {
            final int index = entry.getKey();
            final PCollectionView<?> view = entry.getValue();

            final IRVertex viewProducer = this.pValueToProducerVertex.get(view);
            if (viewProducer == null) {
                // Fail fast (as addEdgeTo does) instead of an NPE after adding a dangling vertex.
                throw new IllegalStateException(
                    String.format("Cannot find a vertex that emits side input view %s", view));
            }

            final OperatorVertex sideInputVertex = new OperatorVertex(new SideInputTransform(index));
            addVertex(sideInputVertex);

            final Coder<?> viewCoder = getCoderForView(view, this);
            final Coder windowCoder =
                view.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();

            // First edge: view producer -> side-input transform.
            addEdge(new IREdge(CommunicationPatternProperty.Value.ONE_TO_ONE, viewProducer, sideInputVertex),
                viewCoder, windowCoder);

            // Second edge: side-input transform -> consumer, broadcast to every consumer task.
            final IREdge broadcastEdge =
                new IREdge(CommunicationPatternProperty.Value.BROADCAST, sideInputVertex, iRVertex);
            final WindowedValue.FullWindowedValueCoder sideInputElementCoder =
                WindowedValue.getFullCoder(SideInputCoder.of(viewCoder), windowCoder);

            viewProducer.setPropertyPermanently(ParallelismProperty.of(1));
            sideInputVertex.setPropertyPermanently(ParallelismProperty.of(1));

            broadcastEdge.setProperty(EncoderProperty.of(new BeamEncoderFactory(sideInputElementCoder)));
            broadcastEdge.setProperty(DecoderProperty.of(new BeamDecoderFactory(sideInputElementCoder)));
            this.builder.connectVertices(broadcastEdge);
        }
    }

    /**
     * Connects the vertex that produced {@code pValue} to {@code iRVertex}. The communication
     * pattern comes from {@link #getCommPattern}; if the value is an additional output, its tag
     * is recorded on the edge.
     *
     * @param iRVertex the destination vertex
     * @param pValue   the value flowing along the edge; must be a {@link PCollection}
     * @throws IllegalStateException if {@code pValue} is not a PCollection or has no registered
     *                               producer
     */
    public void addEdgeTo(IRVertex iRVertex, PValue pValue) {
        if (!(pValue instanceof PCollection)) {
            throw new IllegalStateException(pValue.toString());
        }
        final PCollection<?> pCollection = (PCollection<?>) pValue;
        final Coder coder = pCollection.getCoder();
        final Coder windowCoder = pCollection.getWindowingStrategy().getWindowFn().windowCoder();

        final IRVertex producer = this.pValueToProducerVertex.get(pValue);
        if (producer == null) {
            throw new IllegalStateException(String.format("Cannot find a vertex that emits pValue %s", pValue));
        }

        final IREdge edge = new IREdge(getCommPattern(producer, iRVertex), producer, iRVertex);
        final TupleTag<?> additionalTag = this.pValueToTag.get(pValue);
        if (additionalTag != null) {
            edge.setProperty(AdditionalOutputTagProperty.of(additionalTag.getId()));
        }
        addEdge(edge, coder, windowCoder);
    }

    /**
     * Attaches key-extractor and coder properties to {@code iREdge} and connects it in the DAG.
     *
     * <p>If {@code coder} is a {@link KvCoder}, the key encoder and decoder are also set from its
     * key coder. The element encoder and decoder wrap {@code coder} in a windowed-value coder that
     * uses {@code windowCoder}.
     *
     * @param iREdge      the edge to configure and connect
     * @param coder       coder of the elements on the edge
     * @param windowCoder coder of the elements' windows
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void addEdge(IREdge iREdge, Coder coder, Coder windowCoder) {
        iREdge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
        if (coder instanceof KvCoder) {
            final Coder keyCoder = ((KvCoder) coder).getKeyCoder();
            iREdge.setProperty(KeyEncoderProperty.of(new BeamEncoderFactory(keyCoder)));
            iREdge.setProperty(KeyDecoderProperty.of(new BeamDecoderFactory(keyCoder)));
        }
        final WindowedValue.FullWindowedValueCoder fullCoder = WindowedValue.getFullCoder(coder, windowCoder);
        iREdge.setProperty(EncoderProperty.of(new BeamEncoderFactory(fullCoder)));
        iREdge.setProperty(DecoderProperty.of(new BeamDecoderFactory(fullCoder)));
        this.builder.connectVertices(iREdge);
    }

    /**
     * Records that {@code pValue} is the main output of {@code node}, translated to {@code iRVertex}.
     *
     * @param node     the producing Beam node
     * @param iRVertex the IR vertex it was translated to
     * @param pValue   the produced value
     */
    public void registerMainOutputFrom(TransformHierarchy.Node node, IRVertex iRVertex, PValue pValue) {
        this.pValueToProducerBeamNode.put(pValue, node);
        this.pValueToProducerVertex.put(pValue, iRVertex);
    }

    /**
     * Records that {@code pValue} is an additional output of {@code node}, emitted under
     * {@code tupleTag} by {@code iRVertex}.
     *
     * @param node     the producing Beam node
     * @param iRVertex the IR vertex it was translated to
     * @param pValue   the produced value
     * @param tupleTag the tag identifying this additional output
     */
    public void registerAdditionalOutputFrom(TransformHierarchy.Node node, IRVertex iRVertex, PValue pValue,
                                             TupleTag<?> tupleTag) {
        this.pValueToProducerBeamNode.put(pValue, node);
        this.pValueToTag.put(pValue, tupleTag);
        this.pValueToProducerVertex.put(pValue, iRVertex);
    }

    /** @return the Beam pipeline being translated */
    public Pipeline getPipeline() {
        return this.pipeline;
    }

    /** @return the options of the pipeline being translated */
    public PipelineOptions getPipelineOptions() {
        return this.pipelineOptions;
    }

    /** @return the builder accumulating the translated DAG */
    public DAGBuilder<IRVertex, IREdge> getBuilder() {
        return this.builder;
    }

    /**
     * @param pValue a value registered through one of the {@code register*OutputFrom} methods
     * @return the Beam node that produced it, or {@code null} if it was never registered
     */
    TransformHierarchy.Node getProducerBeamNodeOf(PValue pValue) {
        return this.pValueToProducerBeamNode.get(pValue);
    }

    /**
     * Chooses the communication pattern for an edge from {@code src} to {@code dst}. Rules are
     * checked in order:
     * <ol>
     *   <li>source runs CoGroupByKey's {@code ConstructUnionTableFn} -> SHUFFLE;</li>
     *   <li>source is a {@link FlattenTransform} -> ONE_TO_ONE;</li>
     *   <li>destination is a non-partial-combining {@link GBKTransform}, or a
     *       {@link GroupByKeyTransform} -> SHUFFLE;</li>
     *   <li>destination is a {@link CreateViewTransform} -> BROADCAST;</li>
     *   <li>otherwise -> ONE_TO_ONE.</li>
     * </ol>
     *
     * @throws RuntimeException if the CoGroupByKey helper class cannot be loaded
     */
    private CommunicationPatternProperty.Value getCommPattern(IRVertex src, IRVertex dst) {
        final Class<?> constructUnionTableFnClass;
        try {
            constructUnionTableFnClass = Class.forName(CO_GBK_CONSTRUCT_UNION_TABLE_FN);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Cannot load " + CO_GBK_CONSTRUCT_UNION_TABLE_FN, e);
        }

        final Transform srcTransform = src instanceof OperatorVertex ? ((OperatorVertex) src).getTransform() : null;
        final Transform dstTransform = dst instanceof OperatorVertex ? ((OperatorVertex) dst).getTransform() : null;
        final DoFn<?, ?> srcDoFn =
            srcTransform instanceof DoFnTransform ? ((DoFnTransform) srcTransform).getDoFn() : null;

        if (srcDoFn != null && srcDoFn.getClass().equals(constructUnionTableFnClass)) {
            return CommunicationPatternProperty.Value.SHUFFLE;
        }
        if (srcTransform instanceof FlattenTransform) {
            return CommunicationPatternProperty.Value.ONE_TO_ONE;
        }
        final boolean dstIsFullGbk = dstTransform instanceof GBKTransform
            && !((GBKTransform) dstTransform).getIsPartialCombining();
        if (dstIsFullGbk || dstTransform instanceof GroupByKeyTransform) {
            return CommunicationPatternProperty.Value.SHUFFLE;
        }
        if (dstTransform instanceof CreateViewTransform) {
            return CommunicationPatternProperty.Value.BROADCAST;
        }
        return CommunicationPatternProperty.Value.ONE_TO_ONE;
    }

    /**
     * Derives the coder of the materialized view value from the {@link KvCoder} of the
     * PCollection output by the view's producing Beam node. The view function decides the result:
     * <ul>
     *   <li>Iterable -> {@code IterableCoder(valueCoder)};</li>
     *   <li>List -> {@code ListCoder(valueCoder)};</li>
     *   <li>Map / Multimap -> {@code MapCoder(key, value)} of the nested KV value coder;</li>
     *   <li>Singleton -> the KvCoder itself.</li>
     * </ul>
     *
     * @throws IllegalStateException         if the view has no registered producer node, or the
     *                                       producer's coder (or a Map/Multimap value coder) is
     *                                       not a KvCoder
     * @throws RuntimeException              if the producer node outputs no PCollection
     * @throws UnsupportedOperationException for any other view function
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Coder<?> getCoderForView(PCollectionView pCollectionView, PipelineTranslationContext context) {
        final TransformHierarchy.Node producer = context.getProducerBeamNodeOf(pCollectionView);
        if (producer == null) {
            throw new IllegalStateException(
                String.format("Cannot find a Beam node that produces view %s", pCollectionView));
        }
        final Coder<?> producedCoder = producer.getOutputs().values().stream()
            .filter(pValue -> pValue instanceof PCollection)
            .map(pValue -> (PCollection<?>) pValue)
            .findFirst()
            .orElseThrow(() -> new RuntimeException(String.format("No incoming PCollection to %s", producer)))
            .getCoder();
        final KvCoder<?, ?> inputKvCoder = asKvCoder(producedCoder, pCollectionView);

        final ViewFn viewFn = pCollectionView.getViewFn();
        if (viewFn instanceof PCollectionViews.IterableViewFn) {
            return IterableCoder.of(inputKvCoder.getValueCoder());
        }
        if (viewFn instanceof PCollectionViews.ListViewFn) {
            return ListCoder.of(inputKvCoder.getValueCoder());
        }
        if (viewFn instanceof PCollectionViews.MapViewFn) {
            final KvCoder<?, ?> entryCoder = asKvCoder(inputKvCoder.getValueCoder(), pCollectionView);
            return MapCoder.of(entryCoder.getKeyCoder(), entryCoder.getValueCoder());
        }
        if (viewFn instanceof PCollectionViews.MultimapViewFn) {
            // NOTE(review): a multimap view presumably materializes as Map<K, Iterable<V>>, which
            // would need IterableCoder for the map values. The original coder choice is kept;
            // confirm against CreateViewTransform.
            final KvCoder<?, ?> entryCoder = asKvCoder(inputKvCoder.getValueCoder(), pCollectionView);
            return MapCoder.of(entryCoder.getKeyCoder(), entryCoder.getValueCoder());
        }
        if (viewFn instanceof PCollectionViews.SingletonViewFn) {
            return inputKvCoder;
        }
        throw new UnsupportedOperationException(String.format("Unsupported viewFn %s", viewFn.getClass()));
    }

    /** Casts {@code coder} to a KvCoder, failing with a descriptive message instead of a bare CCE. */
    private static KvCoder<?, ?> asKvCoder(Coder<?> coder, PCollectionView<?> view) {
        if (!(coder instanceof KvCoder)) {
            throw new IllegalStateException(
                String.format("Expected a KvCoder while deriving the coder of view %s, but got %s", view, coder));
        }
        return (KvCoder<?, ?>) coder;
    }
}
