package org.apache.beam.runners.core.construction.graph;

import com.google.auto.value.AutoValue;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Queue;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.OutputDeduplicator;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ComparisonChain;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.class */
/**
 * Fuses the primitive transforms of a {@link RunnerApi.Pipeline} into {@link ExecutableStage}s.
 *
 * <p>Transforms that have an environment are fusible: every group of mutually compatible
 * consumers of a PCollection becomes the root of a stage built by {@link GreedyStageFuser}.
 * Transforms without an environment are treated as runner-executed and are left unfused. Root
 * transforms must be runner-executed. After a stage is created, the PCollections it outputs
 * are searched for further consumers until no new stage roots remain.
 */
public class GreedyPipelineFuser {
    private static final Logger LOG = LoggerFactory.getLogger(GreedyPipelineFuser.class);

    private final QueryablePipeline pipeline;
    private final FusedPipeline fusedPipeline;

    /** A single (consumed PCollection, consuming transform) edge of the pipeline graph. */
    @AutoValue
    public abstract static class CollectionConsumer implements Comparable<CollectionConsumer> {
        static CollectionConsumer of(
                PipelineNode.PCollectionNode consumedCollection,
                PipelineNode.PTransformNode consumingTransform) {
            return new AutoValue_GreedyPipelineFuser_CollectionConsumer(
                    consumedCollection, consumingTransform);
        }

        public abstract PipelineNode.PCollectionNode consumedCollection();

        public abstract PipelineNode.PTransformNode consumingTransform();

        /** Orders by consumed PCollection id, then by consuming transform id. */
        @Override
        public int compareTo(CollectionConsumer that) {
            return ComparisonChain.start()
                    .compare(consumedCollection().getId(), that.consumedCollection().getId())
                    .compare(consumingTransform().getId(), that.consumingTransform().getId())
                    .result();
        }
    }

    /**
     * The result of walking downstream from a node. It holds the runner-executed transforms
     * reached along the way and the environment-executed consumers at which the walk stopped.
     */
    @AutoValue
    public abstract static class DescendantConsumers {
        static DescendantConsumers of(
                Set<PipelineNode.PTransformNode> unfusedNodes,
                NavigableSet<CollectionConsumer> fusibleConsumers) {
            return new AutoValue_GreedyPipelineFuser_DescendantConsumers(unfusedNodes, fusibleConsumers);
        }

        public abstract Set<PipelineNode.PTransformNode> getUnfusedNodes();

        public abstract NavigableSet<CollectionConsumer> getFusibleConsumers();
    }

    /**
     * Key for consumers that may share a stage: they read the same PCollection and run in the
     * same environment.
     */
    @AutoValue
    public abstract static class SiblingKey {
        static SiblingKey of(PipelineNode.PCollectionNode inputCollection, RunnerApi.Environment env) {
            return new AutoValue_GreedyPipelineFuser_SiblingKey(inputCollection, env);
        }

        public abstract PipelineNode.PCollectionNode getInputCollection();

        public abstract RunnerApi.Environment getEnv();
    }

    private GreedyPipelineFuser(RunnerApi.Pipeline p) {
        // Validate first; fusion assumes a well-formed graph.
        PipelineValidator.validate(p);
        this.pipeline = QueryablePipeline.forPrimitivesIn(p.getComponents());
        Set<PipelineNode.PTransformNode> unfusedRootNodes = new LinkedHashSet<>();
        NavigableSet<CollectionConsumer> rootConsumers = new TreeSet<>();
        for (PipelineNode.PTransformNode rootTransform : this.pipeline.getRootTransforms()) {
            DescendantConsumers descendants = getRootConsumers(rootTransform);
            unfusedRootNodes.addAll(descendants.getUnfusedNodes());
            rootConsumers.addAll(descendants.getFusibleConsumers());
        }
        this.fusedPipeline = fusePipeline(unfusedRootNodes, groupSiblings(rootConsumers));
    }

    /**
     * Fuses the given pipeline.
     *
     * @param pipeline the pipeline to fuse; it is validated with {@code PipelineValidator} first
     * @return the fused stages plus the transforms left for the runner to execute
     */
    public static FusedPipeline fuse(RunnerApi.Pipeline pipeline) {
        return new GreedyPipelineFuser(pipeline).fusedPipeline;
    }

    /**
     * Creates stages breadth-first, starting from the initial sibling sets. Each PCollection
     * output by a new stage is searched for further consumers, which are enqueued as new sibling
     * sets.
     */
    private FusedPipeline fusePipeline(
            Collection<PipelineNode.PTransformNode> initialUnfusedTransforms,
            NavigableSet<NavigableSet<CollectionConsumer>> initialConsumers) {
        Map<CollectionConsumer, ExecutableStage> consumedCollectionsAndTransforms = new HashMap<>();
        Set<ExecutableStage> stages = new LinkedHashSet<>();
        Set<PipelineNode.PTransformNode> unfusedTransforms =
                new LinkedHashSet<>(initialUnfusedTransforms);
        Queue<Set<CollectionConsumer>> pendingSiblingSets = new ArrayDeque<>(initialConsumers);
        while (!pendingSiblingSets.isEmpty()) {
            Set<CollectionConsumer> candidateSiblings = pendingSiblingSets.poll();
            // Keep only consumers that are not yet in a stage, so no path through the pipeline is
            // fused twice. The result is copied because Sets.difference is a live view over the
            // map's key set, and the map is mutated below.
            Set<CollectionConsumer> siblingSet =
                    new LinkedHashSet<>(
                            Sets.difference(candidateSiblings, consumedCollectionsAndTransforms.keySet()));
            Preconditions.checkState(
                    siblingSet.equals(candidateSiblings) || siblingSet.isEmpty(),
                    "Inconsistent collection of siblings reported for a %s. Initial attempt missed %s",
                    PipelineNode.PCollectionNode.class.getSimpleName(),
                    siblingSet);
            if (siblingSet.isEmpty()) {
                LOG.debug("Filtered out duplicate stage root {}", candidateSiblings);
                continue;
            }
            ExecutableStage stage = fuseSiblings(siblingSet);
            for (CollectionConsumer sibling : siblingSet) {
                consumedCollectionsAndTransforms.put(sibling, stage);
            }
            stages.add(stage);
            for (PipelineNode.PCollectionNode materializedOutput : stage.getOutputPCollections()) {
                DescendantConsumers descendants = getDescendantConsumers(materializedOutput);
                unfusedTransforms.addAll(descendants.getUnfusedNodes());
                pendingSiblingSets.addAll(groupSiblings(descendants.getFusibleConsumers()));
            }
        }

        OutputDeduplicator.DeduplicationResult deduplicated =
                OutputDeduplicator.ensureSingleProducer(this.pipeline, stages, unfusedTransforms);
        Set<ExecutableStage> fusedStages =
                stages.stream()
                        .map(s -> deduplicated.getDeduplicatedStages().getOrDefault(s, s))
                        .map(GreedyPipelineFuser::sanitizeDanglingPTransformInputs)
                        .collect(Collectors.toSet());
        Set<PipelineNode.PTransformNode> runnerTransforms =
                unfusedTransforms.stream()
                        .map(t -> deduplicated.getDeduplicatedTransforms().getOrDefault(t.getId(), t))
                        .collect(Collectors.toSet());
        return FusedPipeline.of(
                deduplicated.getDeduplicatedComponents(),
                fusedStages,
                Sets.union(deduplicated.getIntroducedTransforms(), runnerTransforms));
    }

    /**
     * Collects everything downstream of a root transform. The root has no inputs and no
     * environment; both conditions are checked.
     *
     * @throws IllegalArgumentException if either condition does not hold
     */
    private DescendantConsumers getRootConsumers(PipelineNode.PTransformNode rootNode) {
        Preconditions.checkArgument(
                rootNode.getTransform().getInputsCount() == 0,
                "Transform %s is not at the root of the graph (consumes %s)",
                rootNode.getId(),
                rootNode.getTransform().getInputsMap());
        Preconditions.checkArgument(
                !this.pipeline.getEnvironment(rootNode).isPresent(),
                "%s requires all root nodes to be runner-implemented %s or %s primitives, but transform %s executes in environment %s",
                GreedyPipelineFuser.class.getSimpleName(),
                PTransformTranslation.IMPULSE_TRANSFORM_URN,
                PTransformTranslation.READ_TRANSFORM_URN,
                rootNode.getId(),
                this.pipeline.getEnvironment(rootNode));
        Set<PipelineNode.PTransformNode> unfused = new HashSet<>();
        unfused.add(rootNode);
        NavigableSet<CollectionConsumer> rootConsumers = new TreeSet<>();
        for (PipelineNode.PCollectionNode output : this.pipeline.getOutputPCollections(rootNode)) {
            DescendantConsumers descendants = getDescendantConsumers(output);
            unfused.addAll(descendants.getUnfusedNodes());
            rootConsumers.addAll(descendants.getFusibleConsumers());
        }
        return DescendantConsumers.of(unfused, rootConsumers);
    }

    /**
     * Walks the per-element consumers of {@code inputPCollection}. The walk stops at each consumer
     * that has an environment and reports it as fusible. It recurses through consumers that have
     * no environment, which are runner-executed.
     */
    private DescendantConsumers getDescendantConsumers(PipelineNode.PCollectionNode inputPCollection) {
        Set<PipelineNode.PTransformNode> unfused = new HashSet<>();
        NavigableSet<CollectionConsumer> downstreamConsumers = new TreeSet<>();
        for (PipelineNode.PTransformNode consumer : this.pipeline.getPerElementConsumers(inputPCollection)) {
            if (this.pipeline.getEnvironment(consumer).isPresent()) {
                downstreamConsumers.add(CollectionConsumer.of(inputPCollection, consumer));
                continue;
            }
            LOG.debug(
                    "Adding {} {} to the set of runner-executed transforms",
                    PipelineNode.PTransformNode.class.getSimpleName(),
                    consumer.getId());
            unfused.add(consumer);
            for (PipelineNode.PCollectionNode output : this.pipeline.getOutputPCollections(consumer)) {
                DescendantConsumers descendants = getDescendantConsumers(output);
                unfused.addAll(descendants.getUnfusedNodes());
                downstreamConsumers.addAll(descendants.getFusibleConsumers());
            }
        }
        return DescendantConsumers.of(unfused, downstreamConsumers);
    }

    /**
     * Partitions consumers into sets that may share a stage. Members of a set read the same
     * PCollection, run in the same environment, and are pairwise compatible according to
     * {@link GreedyPCollectionFusers#isCompatible}.
     *
     * <p>The result is ordered by the least consumer in each set.
     */
    private NavigableSet<NavigableSet<CollectionConsumer>> groupSiblings(
            NavigableSet<CollectionConsumer> newConsumers) {
        // The groups are held in insertion-ordered lists, not a hash-based multimap, for two
        // reasons. The groups are mutable TreeSets whose hash codes change as members are added.
        // Also, the order in which existing groups are tried should follow the sorted order of
        // the input. This loop is O(N^2) in the number of consumers of one PCollection.
        Map<SiblingKey, List<NavigableSet<CollectionConsumer>>> compatibleConsumers =
                new LinkedHashMap<>();
        for (CollectionConsumer newConsumer : newConsumers) {
            // Fusible consumers are only produced for transforms that have an environment
            // (see getDescendantConsumers), so get() is safe here.
            SiblingKey key =
                    SiblingKey.of(
                            newConsumer.consumedCollection(),
                            this.pipeline.getEnvironment(newConsumer.consumingTransform()).get());
            List<NavigableSet<CollectionConsumer>> groups =
                    compatibleConsumers.computeIfAbsent(key, k -> new ArrayList<>());
            NavigableSet<CollectionConsumer> target = null;
            for (NavigableSet<CollectionConsumer> group : groups) {
                boolean allCompatible =
                        group.stream()
                                .allMatch(
                                        existing ->
                                                GreedyPCollectionFusers.isCompatible(
                                                        existing.consumingTransform(),
                                                        newConsumer.consumingTransform(),
                                                        this.pipeline));
                if (allCompatible) {
                    target = group;
                    break;
                }
            }
            if (target == null) {
                target = new TreeSet<>();
                groups.add(target);
            }
            target.add(newConsumer);
        }
        Comparator<NavigableSet<CollectionConsumer>> byLeastSibling =
                Comparator.comparing(NavigableSet::first);
        NavigableSet<NavigableSet<CollectionConsumer>> orderedSiblings = new TreeSet<>(byLeastSibling);
        for (List<NavigableSet<CollectionConsumer>> groups : compatibleConsumers.values()) {
            orderedSiblings.addAll(groups);
        }
        return orderedSiblings;
    }

    /**
     * Builds one stage rooted at the PCollection that all {@code mutuallyCompatible} consumers
     * read.
     */
    private ExecutableStage fuseSiblings(Set<CollectionConsumer> mutuallyCompatible) {
        // Every member shares the same consumed collection (see SiblingKey), so any one will do.
        PipelineNode.PCollectionNode rootCollection =
                mutuallyCompatible.iterator().next().consumedCollection();
        Set<PipelineNode.PTransformNode> rootTransforms =
                mutuallyCompatible.stream()
                        .map(CollectionConsumer::consumingTransform)
                        .collect(Collectors.toSet());
        return GreedyStageFuser.forGrpcPortRead(this.pipeline, rootCollection, rootTransforms);
    }

    /**
     * Removes "dangling" transform inputs from a stage, then returns a rebuilt stage.
     *
     * <p>An input is dangling when it is none of the following:
     *
     * <ul>
     *   <li>the stage input;
     *   <li>a stage output;
     *   <li>a side input;
     *   <li>an output of a transform within the stage.
     * </ul>
     *
     * <p>Dangling inputs are dropped from the transforms and from the stage's PCollection
     * components.
     */
    private static ExecutableStage sanitizeDanglingPTransformInputs(ExecutableStage stage) {
        Set<String> possibleInputs = new HashSet<>();
        possibleInputs.add(stage.getInputPCollection().getId());
        possibleInputs.addAll(
                stage.getOutputPCollections().stream()
                        .map(PipelineNode.PCollectionNode::getId)
                        .collect(Collectors.toSet()));
        possibleInputs.addAll(
                stage.getSideInputs().stream()
                        .map(sideInput -> sideInput.collection().getId())
                        .collect(Collectors.toSet()));
        possibleInputs.addAll(
                stage.getTransforms().stream()
                        .flatMap(t -> t.getTransform().getOutputsMap().values().stream())
                        .collect(Collectors.toSet()));
        Set<String> danglingInputs =
                stage.getTransforms().stream()
                        .flatMap(t -> t.getTransform().getInputsMap().values().stream())
                        .filter(in -> !possibleInputs.contains(in))
                        .collect(Collectors.toSet());

        ImmutableList.Builder<PipelineNode.PTransformNode> transformsBuilder = ImmutableList.builder();
        for (PipelineNode.PTransformNode transformNode : stage.getTransforms()) {
            RunnerApi.PTransform transform = transformNode.getTransform();
            Map<String, String> validInputs =
                    transform.getInputsMap().entrySet().stream()
                            .filter(e -> !danglingInputs.contains(e.getValue()))
                            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            if (!validInputs.equals(transform.getInputsMap())) {
                // Rebuild the transform without its dangling inputs.
                transformNode =
                        PipelineNode.pTransform(
                                transformNode.getId(),
                                transform.toBuilder().clearInputs().putAllInputs(validInputs).build());
            }
            transformsBuilder.add(transformNode);
        }
        ImmutableList<PipelineNode.PTransformNode> transforms = transformsBuilder.build();

        RunnerApi.Components.Builder componentsBuilder = stage.getComponents().toBuilder();
        componentsBuilder
                .clearTransforms()
                .putAllTransforms(
                        transforms.stream()
                                .collect(
                                        Collectors.toMap(
                                                PipelineNode.PTransformNode::getId,
                                                PipelineNode.PTransformNode::getTransform)));
        Map<String, RunnerApi.PCollection> validPCollections =
                stage.getComponents().getPcollectionsMap().entrySet().stream()
                        .filter(e -> !danglingInputs.contains(e.getKey()))
                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        componentsBuilder.clearPcollections().putAllPcollections(validPCollections);

        return ImmutableExecutableStage.of(
                componentsBuilder.build(),
                stage.getEnvironment(),
                stage.getInputPCollection(),
                stage.getSideInputs(),
                stage.getUserStates(),
                stage.getTimers(),
                transforms,
                stage.getOutputPCollections());
    }
}
