package co.cask.cdap.etl.planner;

import co.cask.cdap.etl.api.action.Action;
import co.cask.cdap.etl.common.Constants;
import co.cask.cdap.etl.common.PipelinePhase;
import co.cask.cdap.etl.proto.Connection;
import co.cask.cdap.etl.spec.PipelineSpec;
import co.cask.cdap.etl.spec.StageSpec;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/* loaded from: input_file:lib/cdap-etl-core-3.5.0.jar:co/cask/cdap/etl/planner/PipelinePlanner.class */
/**
 * Splits a logical ETL pipeline into the phases that will be executed.
 *
 * <p>Non-action stages are built into a {@link ConnectorDag}, with stages of the configured reduce
 * and isolation plugin types marked as such. Connector nodes are inserted and the DAG is split
 * into sub-DAGs, each of which becomes one {@link PipelinePhase}. Every action stage becomes its
 * own single-stage phase named after the stage. Connections between phases are derived from the
 * original stage connections.
 */
public class PipelinePlanner {
    /** Plugin types whose stages are registered as reduce nodes of the connector DAG. */
    private final Set<String> reduceTypes;
    /** Plugin types whose stages are registered as isolation nodes of the connector DAG. */
    private final Set<String> isolationTypes;
    /** Plugin types handed to every {@link PipelinePhase} builder. */
    private final Set<String> supportedPluginTypes;

    /**
     * Creates a planner. All sets are defensively copied.
     *
     * @param supportedPluginTypes plugin types passed to each phase builder
     * @param reduceTypes plugin types whose stages are treated as reduce nodes
     * @param isolationTypes plugin types whose stages are treated as isolation nodes
     */
    public PipelinePlanner(Set<String> supportedPluginTypes, Set<String> reduceTypes,
                           Set<String> isolationTypes) {
        this.reduceTypes = ImmutableSet.copyOf(reduceTypes);
        this.isolationTypes = ImmutableSet.copyOf(isolationTypes);
        this.supportedPluginTypes = ImmutableSet.copyOf(supportedPluginTypes);
    }

    /**
     * Plans the phases of a pipeline and the connections between those phases.
     *
     * @param pipelineSpec the pipeline to plan
     * @return the phases keyed by phase name, plus the phase-to-phase connections
     */
    public PipelinePlan plan(PipelineSpec pipelineSpec) {
        // Classify stages by plugin type and index every stage spec by its name.
        Set<String> reduceNodes = new HashSet<>();
        Set<String> isolationNodes = new HashSet<>();
        Set<String> actionNodes = new HashSet<>();
        Map<String, StageSpec> specs = new HashMap<>();
        for (StageSpec stageSpec : pipelineSpec.getStages()) {
            String stageName = stageSpec.getName();
            String pluginType = stageSpec.getPlugin().getType();
            if (reduceTypes.contains(pluginType)) {
                reduceNodes.add(stageName);
            }
            if (isolationTypes.contains(pluginType)) {
                isolationNodes.add(stageName);
            }
            if (Action.PLUGIN_TYPE.equals(pluginType)) {
                actionNodes.add(stageName);
            }
            specs.put(stageName, stageSpec);
        }

        // Connections touching an action are kept aside (indexed from the action's point of view) and
        // used to wire the action phases; all other connections form the data DAG.
        SetMultimap<String, String> actionOutputs = HashMultimap.create();
        SetMultimap<String, String> actionInputs = HashMultimap.create();
        Set<Connection> dataConnections = new HashSet<>();
        for (Connection connection : pipelineSpec.getConnections()) {
            String from = connection.getFrom();
            String to = connection.getTo();
            boolean fromAction = actionNodes.contains(from);
            boolean toAction = actionNodes.contains(to);
            if (fromAction) {
                actionOutputs.put(from, to);
            }
            if (toAction) {
                actionInputs.put(to, from);
            }
            if (!fromAction && !toAction) {
                dataConnections.add(connection);
            }
        }

        if (dataConnections.isEmpty()) {
            // No data DAG to build: the plan consists solely of action phases.
            // NOTE(review): in this branch a non-action stage connected only to actions gets no phase;
            // confirm such pipelines are rejected during validation.
            Map<String, PipelinePhase> actionPhases = new HashMap<>();
            Set<Connection> actionConnections = new HashSet<>();
            populateActionPhases(specs, actionNodes, actionPhases, actionConnections,
                                 actionOutputs, actionInputs, new HashMap<String, Dag>());
            return new PipelinePlan(actionPhases, actionConnections);
        }

        ConnectorDag connectorDag = ConnectorDag.builder()
            .addConnections(dataConnections)
            .addReduceNodes(reduceNodes)
            .addIsolationNodes(isolationNodes)
            .build();
        connectorDag.insertConnectors();
        Set<String> connectors = connectorDag.getConnectors();

        Map<String, Dag> subdags = new HashMap<>();
        for (Dag subdag : connectorDag.split()) {
            subdags.put(getPhaseName(subdag.getSources(), subdag.getSinks()), subdag);
        }

        // Phase A feeds phase B when some sink of A is also a source of B.
        Set<Connection> phaseConnections = new HashSet<>();
        for (Map.Entry<String, Dag> upstream : subdags.entrySet()) {
            Set<String> upstreamSinks = upstream.getValue().getSinks();
            for (Map.Entry<String, Dag> downstream : subdags.entrySet()) {
                boolean samePhase = upstream.getKey().equals(downstream.getKey());
                if (!samePhase
                    && !Sets.intersection(upstreamSinks, downstream.getValue().getSources()).isEmpty()) {
                    phaseConnections.add(new Connection(upstream.getKey(), downstream.getKey()));
                }
            }
        }

        Map<String, PipelinePhase> phases = new HashMap<>();
        for (Map.Entry<String, Dag> entry : subdags.entrySet()) {
            phases.put(entry.getKey(), dagToPipeline(entry.getValue(), connectors, specs));
        }
        populateActionPhases(specs, actionNodes, phases, phaseConnections,
                             actionOutputs, actionInputs, subdags);
        return new PipelinePlan(phases, phaseConnections);
    }

    /**
     * Adds one single-stage phase per action, plus every phase connection that touches an action.
     *
     * @param specs stage specs keyed by stage name
     * @param actions names of the action stages
     * @param phases output: phase name to phase; one entry named after each action is added
     * @param phaseConnections output: connections between phases
     * @param actionOutputs for each action, the stages it connects to
     * @param actionInputs for each action, the stages that connect to it
     * @param subdags data phases keyed by phase name; may be empty
     */
    private void populateActionPhases(Map<String, StageSpec> specs, Set<String> actions,
                                      Map<String, PipelinePhase> phases,
                                      Set<Connection> phaseConnections,
                                      SetMultimap<String, String> actionOutputs,
                                      SetMultimap<String, String> actionInputs,
                                      Map<String, Dag> subdags) {
        for (String action : actions) {
            PipelinePhase actionPhase = PipelinePhase.builder(supportedPluginTypes)
                .addStage(toStageInfo(action, Action.PLUGIN_TYPE, specs.get(action)))
                .build();
            phases.put(action, actionPhase);
        }

        for (String action : actionOutputs.keySet()) {
            Set<String> outputs = actionOutputs.get(action);
            // action -> data phase, when the action feeds one of that phase's sources
            for (Map.Entry<String, Dag> entry : subdags.entrySet()) {
                if (!Sets.intersection(outputs, entry.getValue().getSources()).isEmpty()) {
                    phaseConnections.add(new Connection(action, entry.getKey()));
                }
            }
            // action -> action: each action is its own phase, so the stage link is the phase link
            for (String output : outputs) {
                if (actions.contains(output)) {
                    phaseConnections.add(new Connection(action, output));
                }
            }
        }

        // data phase -> action, when one of that phase's sinks feeds the action
        for (String action : actionInputs.keySet()) {
            Set<String> inputs = actionInputs.get(action);
            for (Map.Entry<String, Dag> entry : subdags.entrySet()) {
                if (!Sets.intersection(inputs, entry.getValue().getSinks()).isEmpty()) {
                    phaseConnections.add(new Connection(entry.getKey(), action));
                }
            }
        }
    }

    /**
     * Converts a sub-DAG into a phase. Stages are added in topological order, together with their
     * outgoing connections.
     *
     * @param dag the sub-DAG to convert
     * @param connectors names of the connector nodes inserted by the connector DAG
     * @param specs stage specs keyed by stage name
     * @return the phase for {@code dag}
     */
    private PipelinePhase dagToPipeline(Dag dag, Set<String> connectors, Map<String, StageSpec> specs) {
        PipelinePhase.Builder builder = PipelinePhase.builder(supportedPluginTypes);
        for (String stageName : dag.getTopologicalOrder()) {
            Set<String> outputs = dag.getNodeOutputs(stageName);
            if (!outputs.isEmpty()) {
                builder.addConnections(stageName, outputs);
            }
            if (connectors.contains(stageName)) {
                // Connectors come from ConnectorDag.insertConnectors() and presumably have no StageSpec,
                // so they get a bare connector-typed StageInfo.
                builder.addStage(StageInfo.builder(stageName, Constants.CONNECTOR_TYPE).build());
            } else {
                StageSpec stageSpec = specs.get(stageName);
                builder.addStage(toStageInfo(stageName, stageSpec.getPlugin().getType(), stageSpec));
            }
        }
        return builder.build();
    }

    /**
     * Builds the StageInfo for a declared pipeline stage.
     *
     * <p>The StageInfo copies the stage's inputs, input schemas, outputs, output schema and error
     * dataset name.
     */
    private static StageInfo toStageInfo(String stageName, String pluginType, StageSpec spec) {
        return StageInfo.builder(stageName, pluginType)
            .addInputs(spec.getInputs())
            .addInputSchemas(spec.getInputSchemas())
            .addOutputs(spec.getOutputs())
            .setOutputSchema(spec.getOutputSchema())
            .setErrorDatasetName(spec.getErrorDatasetName())
            .build();
    }

    /**
     * Derives a deterministic phase name from a sub-DAG's sources and sinks.
     *
     * <p>Both sets are sorted, so iteration order does not matter. For example, sources {b, a} and
     * sinks {c} yield {@code "a.b.to.c"}.
     */
    @VisibleForTesting
    static String getPhaseName(Set<String> sources, Set<String> sinks) {
        return Joiner.on('.').join(new TreeSet<String>(sources))
            + ".to." + Joiner.on('.').join(new TreeSet<String>(sinks));
    }
}
