package org.apache.beam.runners.twister2;

import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.tset.TBase;
import edu.iu.dsc.tws.api.tset.sets.TSet;
import edu.iu.dsc.tws.api.tset.sets.batch.BatchTSet;
import edu.iu.dsc.tws.tset.TBaseGraph;
import edu.iu.dsc.tws.tset.env.BatchTSetEnvironment;
import edu.iu.dsc.tws.tset.links.BaseTLink;
import edu.iu.dsc.tws.tset.sets.BaseTSet;
import edu.iu.dsc.tws.tset.sets.BuildableTSet;
import edu.iu.dsc.tws.tset.sets.batch.CachedTSet;
import edu.iu.dsc.tws.tset.sets.batch.ComputeTSet;
import edu.iu.dsc.tws.tset.sets.batch.SinkTSet;
import edu.iu.dsc.tws.tset.worker.BatchTSetIWorker;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.beam.runners.twister2.translators.functions.DoFnFunction;
import org.apache.beam.runners.twister2.translators.functions.Twister2SinkFunction;

/* loaded from: input_file:org/apache/beam/runners/twister2/BeamBatchWorker.class */
public class BeamBatchWorker implements Serializable, BatchTSetIWorker {
    private static final String SIDEINPUTS = "sideInputs";
    private static final String LEAVES = "leaves";
    private static final String GRAPH = "graph";
    private HashMap<String, BatchTSet<?>> sideInputDataSets;
    private Set<TSet> leaves;

    public void execute(BatchTSetEnvironment batchTSetEnvironment) {
        Config config = batchTSetEnvironment.getConfig();
        LinkedHashMap linkedHashMap = (LinkedHashMap) config.get(SIDEINPUTS);
        Set<String> set = (Set) config.get(LEAVES);
        TBaseGraph tBaseGraph = (TBaseGraph) config.get(GRAPH);
        batchTSetEnvironment.settBaseGraph(tBaseGraph);
        setupTSets(batchTSetEnvironment, linkedHashMap, set);
        resetEnv(batchTSetEnvironment, tBaseGraph);
        executePipeline(batchTSetEnvironment);
    }

    private void resetEnv(BatchTSetEnvironment batchTSetEnvironment, TBaseGraph tBaseGraph) {
        for (BaseTLink baseTLink : tBaseGraph.getNodes()) {
            if (baseTLink instanceof BaseTSet) {
                ((BaseTSet) baseTLink).setTSetEnv(batchTSetEnvironment);
            } else {
                if (!(baseTLink instanceof BaseTLink)) {
                    throw new IllegalStateException("node must be either of type BaseTSet or BaseTLink");
                }
                baseTLink.setTSetEnv(batchTSetEnvironment);
            }
        }
    }

    private void setupTSets(BatchTSetEnvironment batchTSetEnvironment, Map<String, String> map, Set<String> set) {
        this.sideInputDataSets = new LinkedHashMap();
        this.leaves = new HashSet();
        HashSet hashSet = new HashSet();
        Iterator it = batchTSetEnvironment.getGraph().getSources().iterator();
        while (it.hasNext()) {
            hashSet.add(batchTSetEnvironment.getGraph().getNodeById(((BuildableTSet) it.next()).getId()));
        }
        batchTSetEnvironment.getGraph().setSources(hashSet);
        for (Map.Entry<String, String> entry : map.entrySet()) {
            this.sideInputDataSets.put(entry.getKey(), batchTSetEnvironment.getGraph().getNodeById(entry.getValue()));
        }
        Iterator<String> it2 = set.iterator();
        while (it2.hasNext()) {
            this.leaves.add((TSet) batchTSetEnvironment.getGraph().getNodeById(it2.next()));
        }
    }

    public void executePipeline(BatchTSetEnvironment batchTSetEnvironment) {
        HashMap hashMap = new HashMap();
        for (Map.Entry<String, BatchTSet<?>> entry : this.sideInputDataSets.entrySet()) {
            BatchTSet<?> value = entry.getValue();
            addInputs((BaseTSet) value, hashMap);
            hashMap.put(entry.getKey(), (CachedTSet) value.cache());
        }
        Iterator<TSet> it = this.leaves.iterator();
        while (it.hasNext()) {
            SinkTSet sink = it.next().direct().sink(new Twister2SinkFunction());
            addInputs(sink, hashMap);
            eval(batchTSetEnvironment, sink);
        }
    }

    private void addInputs(BaseTSet baseTSet, Map<String, CachedTSet> map) {
        if (map.isEmpty()) {
            return;
        }
        TBaseGraph tBaseGraph = baseTSet.getTBaseGraph();
        ArrayDeque arrayDeque = new ArrayDeque();
        arrayDeque.add(baseTSet);
        while (!arrayDeque.isEmpty()) {
            ComputeTSet computeTSet = (TBase) arrayDeque.remove();
            arrayDeque.addAll(tBaseGraph.getPredecessors(computeTSet));
            if ((computeTSet instanceof ComputeTSet) && (computeTSet.getComputeFunc() instanceof DoFnFunction)) {
                for (String str : computeTSet.getComputeFunc().getSideInputKeys()) {
                    if (!map.containsKey(str)) {
                        throw new IllegalStateException("Side input not found for key " + str);
                    }
                    computeTSet.addInput(str, map.get(str));
                }
            }
        }
    }

    public void eval(BatchTSetEnvironment batchTSetEnvironment, SinkTSet<?> sinkTSet) {
        batchTSetEnvironment.run(sinkTSet);
    }
}
