package co.cask.cdap.etl.spark;

import co.cask.cdap.api.plugin.PluginContext;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.etl.api.Alert;
import co.cask.cdap.etl.api.AlertPublisher;
import co.cask.cdap.etl.api.ErrorRecord;
import co.cask.cdap.etl.api.ErrorTransform;
import co.cask.cdap.etl.api.JoinElement;
import co.cask.cdap.etl.api.SplitterTransform;
import co.cask.cdap.etl.api.Transform;
import co.cask.cdap.etl.api.batch.BatchAggregator;
import co.cask.cdap.etl.api.batch.BatchJoiner;
import co.cask.cdap.etl.api.batch.BatchJoinerRuntimeContext;
import co.cask.cdap.etl.api.batch.BatchSink;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.api.streaming.Windower;
import co.cask.cdap.etl.common.BasicArguments;
import co.cask.cdap.etl.common.Constants;
import co.cask.cdap.etl.common.DefaultMacroEvaluator;
import co.cask.cdap.etl.common.NoopStageStatisticsCollector;
import co.cask.cdap.etl.common.PipelinePhase;
import co.cask.cdap.etl.common.RecordInfo;
import co.cask.cdap.etl.common.StageStatisticsCollector;
import co.cask.cdap.etl.spark.function.AlertPassFilter;
import co.cask.cdap.etl.spark.function.BatchSinkFunction;
import co.cask.cdap.etl.spark.function.ErrorPassFilter;
import co.cask.cdap.etl.spark.function.ErrorTransformFunction;
import co.cask.cdap.etl.spark.function.InitialJoinFunction;
import co.cask.cdap.etl.spark.function.JoinFlattenFunction;
import co.cask.cdap.etl.spark.function.LeftJoinFlattenFunction;
import co.cask.cdap.etl.spark.function.OuterJoinFlattenFunction;
import co.cask.cdap.etl.spark.function.OutputPassFilter;
import co.cask.cdap.etl.spark.function.PluginFunctionContext;
import co.cask.cdap.etl.spec.StageSpec;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:lib/hydrator-spark-core2_2.11-4.3.5.jar:co/cask/cdap/etl/spark/SparkPipelineRunner.class */
public abstract class SparkPipelineRunner {
    private static final Logger LOG = LoggerFactory.getLogger(SparkPipelineRunner.class);

    /**
     * Immutable holder for everything a single stage emits: its normal output, the output destined for
     * specific ports, its error records and its alerts. Any of the collections may be {@code null} when the
     * stage did not emit that kind of record.
     *
     * <p>Declared {@code private} in the compiled class; the decompiler widened the access modifier.
     */
    public static class EmittedRecords {
        private final Map<String, SparkCollection<Object>> outputPortRecords;
        private final SparkCollection<Object> outputRecords;
        private final SparkCollection<ErrorRecord<Object>> errorRecords;
        private final SparkCollection<Alert> alertRecords;

        /**
         * Fluent builder for {@link EmittedRecords}. Every setter returns the builder itself so calls can be
         * chained; values that are never set stay {@code null}, and the port map starts out empty.
         *
         * <p>Declared {@code private} in the compiled class; the decompiler widened the access modifier.
         */
        public static class Builder {
            private Map<String, SparkCollection<Object>> outputPortRecords;
            private SparkCollection<Object> outputRecords;
            private SparkCollection<ErrorRecord<Object>> errorRecords;
            private SparkCollection<Alert> alertRecords;

            private Builder() {
                this.outputPortRecords = new HashMap<String, SparkCollection<Object>>();
            }

            /** Registers the records emitted to the given output port, replacing any earlier entry for it. */
            public Builder addPort(String port, SparkCollection<Object> portRecords) {
                this.outputPortRecords.put(port, portRecords);
                return this;
            }

            /** Sets the normal (non-port) output records of the stage. */
            public Builder setOutput(SparkCollection<Object> output) {
                this.outputRecords = output;
                return this;
            }

            /** Sets the error records emitted by the stage. */
            public Builder setErrors(SparkCollection<ErrorRecord<Object>> errors) {
                this.errorRecords = errors;
                return this;
            }

            /** Sets the alerts emitted by the stage. */
            public Builder setAlerts(SparkCollection<Alert> alerts) {
                this.alertRecords = alerts;
                return this;
            }

            /**
             * Creates the {@link EmittedRecords} from the current builder state. The port map is handed over
             * as-is (not copied).
             */
            public EmittedRecords build() {
                return new EmittedRecords(this.outputPortRecords, this.outputRecords, this.errorRecords, this.alertRecords);
            }
        }

        private EmittedRecords(Map<String, SparkCollection<Object>> portRecords, SparkCollection<Object> output, SparkCollection<ErrorRecord<Object>> errors, SparkCollection<Alert> alerts) {
            this.outputPortRecords = portRecords;
            this.outputRecords = output;
            this.errorRecords = errors;
            this.alertRecords = alerts;
        }

        /** Returns a fresh, empty builder. */
        private static Builder builder() {
            return new Builder();
        }

        /** Compiler-generated accessor that exposes {@link #builder()} to the enclosing class. */
        static /* synthetic */ Builder access$000() {
            return builder();
        }
    }

    /**
     * Supplies the records for a stage that has no input collection.
     *
     * <p>{@code runPipeline} calls this for a stage whose plugin type equals the source plugin type it was
     * given, or for a connector stage listed among the phase's sources, and then splits the returned
     * records into output, errors and alerts.
     *
     * @param stageSpec specification of the stage to read from
     * @param stageStatisticsCollector statistics collector for the stage; {@code runPipeline} passes a
     *        {@code NoopStageStatisticsCollector} when none was registered for the stage
     * @return the records emitted by the stage, wrapped as {@link RecordInfo}
     * @throws Exception if the source collection cannot be created
     */
    protected abstract SparkCollection<RecordInfo<Object>> getSource(StageSpec stageSpec, StageStatisticsCollector stageStatisticsCollector) throws Exception;

    /**
     * Converts the records of one joiner input into a pair collection so they can be joined with the other
     * inputs. {@code runPipeline} calls this once per input of a joiner stage and joins the returned pair
     * collections with each other.
     *
     * @param stageSpec specification of the joiner stage
     * @param str name of the input stage whose records are being keyed
     * @param sparkCollection records emitted by that input stage
     * @param stageStatisticsCollector statistics collector for the joiner stage
     * @return the input records as key/value pairs (presumably keyed by the join key; confirm in implementations)
     * @throws Exception if the keyed collection cannot be created
     */
    protected abstract SparkPairCollection<Object, Object> addJoinKey(StageSpec stageSpec, String str, SparkCollection<Object> sparkCollection, StageStatisticsCollector stageStatisticsCollector) throws Exception;

    /**
     * Turns the joined inputs of a joiner stage into the stage's output records. {@code runPipeline} caches
     * the returned collection and uses it as the normal output of the joiner stage.
     *
     * @param stageSpec specification of the joiner stage
     * @param sparkPairCollection joined inputs; each value is the list of {@link JoinElement}s gathered for a key
     * @param stageStatisticsCollector statistics collector for the joiner stage
     * @return the output records of the joiner stage
     * @throws Exception if the merge cannot be set up
     */
    protected abstract SparkCollection<Object> mergeJoinResults(StageSpec stageSpec, SparkPairCollection<Object, List<JoinElement<Object>>> sparkPairCollection, StageStatisticsCollector stageStatisticsCollector) throws Exception;

    /**
     * Wires up the Spark computation for every stage of a pipeline phase and then runs the phase's sinks.
     *
     * <p>Stages are visited in the topological order of the phase's DAG. For each stage the outputs of its
     * input stages are looked up (and unioned, except for joiners and error transforms), the stage-specific
     * operation is applied, and whatever the stage emits is remembered for downstream stages. Batch sinks,
     * Spark sinks and connector sinks only register a store task; all store tasks are executed concurrently
     * once every stage has been wired up. Alert publishers publish directly while the stages are visited.
     *
     * @param pipelinePhase the phase to run; must have a DAG
     * @param str plugin type that identifies source stages of this phase
     * @param javaSparkExecutionContext Spark execution context used for macro evaluation and plugin contexts
     * @param map number of partitions per stage name, used by aggregators and joiners; a missing entry means
     *        no explicit partition count
     * @param pluginContext context used to instantiate plugins
     * @param map2 statistics collector per stage name; stages without an entry get a
     *        {@link NoopStageStatisticsCollector}
     * @throws IllegalStateException if the phase has no DAG, a stage without input is not a source, a joiner
     *         has no inputs, or a stage has an unsupported plugin type
     * @throws Exception if wiring up a stage fails, or (possibly wrapped in a {@link RuntimeException}) if a
     *         sink task fails
     */
    public void runPipeline(PipelinePhase pipelinePhase, String str, JavaSparkExecutionContext javaSparkExecutionContext, Map<String, Integer> map, PluginContext pluginContext, Map<String, StageStatisticsCollector> map2) throws Exception {
        DefaultMacroEvaluator macroEvaluator = new DefaultMacroEvaluator(new BasicArguments(javaSparkExecutionContext), javaSparkExecutionContext.getLogicalStartTime(), javaSparkExecutionContext, javaSparkExecutionContext.getNamespace());
        // Everything emitted by the stages processed so far, keyed by stage name.
        Map<String, EmittedRecords> emittedRecords = new HashMap<String, EmittedRecords>();
        if (pipelinePhase.getDag() == null) {
            throw new IllegalStateException("Pipeline phase has no connections.");
        }
        // Store tasks registered by sink stages; they are only executed after all stages are wired up.
        List<Runnable> sinkTasks = new ArrayList<Runnable>();
        for (String stageName : pipelinePhase.getDag().getTopologicalOrder()) {
            StageSpec stageSpec = pipelinePhase.getStage(stageName);
            String pluginType = stageSpec.getPluginType();
            EmittedRecords.Builder emittedBuilder = EmittedRecords.builder();

            // Errors and alerts only need to be extracted when some downstream stage consumes them.
            boolean hasErrorOutput = false;
            boolean hasAlertOutput = false;
            for (String outputStageName : pipelinePhase.getStageOutputs(stageSpec.getName())) {
                String outputPluginType = pipelinePhase.getStage(outputStageName).getPluginType();
                if (ErrorTransform.PLUGIN_TYPE.equals(outputPluginType)) {
                    hasErrorOutput = true;
                } else if (AlertPublisher.PLUGIN_TYPE.equals(outputPluginType)) {
                    hasAlertOutput = true;
                }
            }

            // Collect the records each input stage sends to this stage.
            SparkCollection<Object> stageData = null;
            Map<String, SparkCollection<Object>> inputDataCollections = new HashMap<String, SparkCollection<Object>>();
            Set<String> stageInputs = pipelinePhase.getStageInputs(stageName);
            for (String inputStageName : stageInputs) {
                StageSpec inputStageSpec = pipelinePhase.getStage(inputStageName);
                if (inputStageSpec == null) {
                    // The input is not a stage of this phase (presumably it belongs to another phase).
                    continue;
                }
                // Output ports are only consulted when neither end of the connection is a connector;
                // connections involving a connector always use the normal output.
                String port = null;
                if (!Constants.Connector.PLUGIN_TYPE.equals(inputStageSpec.getPluginType()) && !Constants.Connector.PLUGIN_TYPE.equals(pluginType)) {
                    port = inputStageSpec.getOutputPorts().get(stageName).getPort();
                }
                EmittedRecords inputRecords = emittedRecords.get(inputStageName);
                inputDataCollections.put(inputStageName, port == null ? inputRecords.outputRecords : inputRecords.outputPortRecords.get(port));
            }

            if (!inputDataCollections.isEmpty()) {
                Iterator<SparkCollection<Object>> inputIterator = inputDataCollections.values().iterator();
                stageData = inputIterator.next();
                // Joiners and error transforms handle their inputs individually, so they are not unioned.
                while (!BatchJoiner.PLUGIN_TYPE.equals(pluginType) && !ErrorTransform.PLUGIN_TYPE.equals(pluginType) && inputIterator.hasNext()) {
                    stageData = stageData.union(inputIterator.next());
                }
            }

            boolean isConnectorSource = Constants.Connector.PLUGIN_TYPE.equals(pluginType) && pipelinePhase.getSources().contains(stageName);
            boolean isConnectorSink = Constants.Connector.PLUGIN_TYPE.equals(pluginType) && pipelinePhase.getSinks().contains(stageName);
            StageStatisticsCollector collector = map2.get(stageName);
            if (collector == null) {
                collector = new NoopStageStatisticsCollector();
            }
            PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, javaSparkExecutionContext, collector);

            if (stageData == null) {
                // No input data: the stage has to be a source (or a connector acting as one).
                if (!str.equals(pluginType) && !isConnectorSource) {
                    throw new IllegalStateException(String.format("Stage '%s' has no input and is not a source.", stageName));
                }
                emittedBuilder = addEmitted(emittedBuilder, pipelinePhase, stageSpec, getSource(stageSpec, collector), hasErrorOutput, hasAlertOutput);
            } else if (BatchSink.PLUGIN_TYPE.equals(pluginType) || isConnectorSink) {
                sinkTasks.add(stageData.createStoreTask(stageSpec, Compat.convert(new BatchSinkFunction(pluginFunctionContext))));
            } else if (Transform.PLUGIN_TYPE.equals(pluginType)) {
                emittedBuilder = addEmitted(emittedBuilder, pipelinePhase, stageSpec, stageData.transform(stageSpec, collector), hasErrorOutput, hasAlertOutput);
            } else if (SplitterTransform.PLUGIN_TYPE.equals(pluginType)) {
                emittedBuilder = addEmitted(emittedBuilder, pipelinePhase, stageSpec, stageData.multiOutputTransform(stageSpec, collector), hasErrorOutput, hasAlertOutput);
            } else if (ErrorTransform.PLUGIN_TYPE.equals(pluginType)) {
                // Union the errors (not the normal output) of every input stage.
                SparkCollection<ErrorRecord<Object>> inputErrors = null;
                for (String inputStageName : stageInputs) {
                    SparkCollection<ErrorRecord<Object>> errorsFromInput = emittedRecords.get(inputStageName).errorRecords;
                    if (errorsFromInput != null) {
                        inputErrors = inputErrors == null ? errorsFromInput : inputErrors.union(errorsFromInput);
                    }
                }
                if (inputErrors != null) {
                    emittedBuilder = addEmitted(emittedBuilder, pipelinePhase, stageSpec, inputErrors.flatMap(stageSpec, Compat.convert(new ErrorTransformFunction(pluginFunctionContext))), hasErrorOutput, hasAlertOutput);
                }
            } else if (SparkCompute.PLUGIN_TYPE.equals(pluginType)) {
                emittedBuilder = emittedBuilder.setOutput(stageData.compute(stageSpec, (SparkCompute) pluginContext.newPluginInstance(stageName, macroEvaluator)));
            } else if (SparkSink.PLUGIN_TYPE.equals(pluginType)) {
                sinkTasks.add(stageData.createStoreTask(stageSpec, (SparkSink) pluginContext.newPluginInstance(stageName, macroEvaluator)));
            } else if (BatchAggregator.PLUGIN_TYPE.equals(pluginType)) {
                emittedBuilder = addEmitted(emittedBuilder, pipelinePhase, stageSpec, stageData.aggregate(stageSpec, map.get(stageName), collector), hasErrorOutput, hasAlertOutput);
            } else if (BatchJoiner.PLUGIN_TYPE.equals(pluginType)) {
                BatchJoiner joiner = (BatchJoiner) pluginContext.newPluginInstance(stageName, macroEvaluator);
                joiner.initialize((BatchJoinerRuntimeContext) pluginFunctionContext.createBatchRuntimeContext());

                // Key every input so the inputs can be joined with each other.
                Map<String, SparkPairCollection<Object, Object>> preJoinStreams = new HashMap<String, SparkPairCollection<Object, Object>>();
                for (Map.Entry<String, SparkCollection<Object>> inputEntry : inputDataCollections.entrySet()) {
                    String inputStageName = inputEntry.getKey();
                    preJoinStreams.put(inputStageName, addJoinKey(stageSpec, inputStageName, inputEntry.getValue(), collector));
                }
                Set<String> remainingInputs = new HashSet<String>();
                remainingInputs.addAll(inputDataCollections.keySet());
                Integer numPartitions = map.get(stageName);

                // Required inputs are inner-joined with each other first.
                SparkPairCollection<Object, List<JoinElement<Object>>> joinedInputs = null;
                for (String requiredInput : joiner.getJoinConfig().getRequiredInputs()) {
                    SparkPairCollection<Object, Object> preJoinStream = preJoinStreams.get(requiredInput);
                    if (joinedInputs == null) {
                        joinedInputs = preJoinStream.mapValues(new InitialJoinFunction(requiredInput));
                    } else {
                        JoinFlattenFunction joinFlattenFunction = new JoinFlattenFunction(requiredInput);
                        joinedInputs = numPartitions == null ? joinedInputs.join(preJoinStream).mapValues(joinFlattenFunction) : joinedInputs.join(preJoinStream, numPartitions.intValue()).mapValues(joinFlattenFunction);
                    }
                    remainingInputs.remove(requiredInput);
                }

                // Without any required input the remaining inputs are full-outer-joined with each other;
                // otherwise they are left-outer-joined onto the inner join of the required inputs.
                boolean isFullOuter = joinedInputs == null;
                for (String optionalInput : remainingInputs) {
                    SparkPairCollection<Object, Object> preJoinStream = preJoinStreams.get(optionalInput);
                    if (joinedInputs == null) {
                        joinedInputs = preJoinStream.mapValues(new InitialJoinFunction(optionalInput));
                    } else if (isFullOuter) {
                        OuterJoinFlattenFunction outerJoinFlattenFunction = new OuterJoinFlattenFunction(optionalInput);
                        joinedInputs = numPartitions == null ? joinedInputs.fullOuterJoin(preJoinStream).mapValues(outerJoinFlattenFunction) : joinedInputs.fullOuterJoin(preJoinStream, numPartitions.intValue()).mapValues(outerJoinFlattenFunction);
                    } else {
                        LeftJoinFlattenFunction leftJoinFlattenFunction = new LeftJoinFlattenFunction(optionalInput);
                        joinedInputs = numPartitions == null ? joinedInputs.leftOuterJoin(preJoinStream).mapValues(leftJoinFlattenFunction) : joinedInputs.leftOuterJoin(preJoinStream, numPartitions.intValue()).mapValues(leftJoinFlattenFunction);
                    }
                }
                if (joinedInputs == null) {
                    throw new IllegalStateException("There are no inputs into join stage " + stageName);
                }
                emittedBuilder = emittedBuilder.setOutput(mergeJoinResults(stageSpec, joinedInputs, collector).cache());
            } else if (Windower.PLUGIN_TYPE.equals(pluginType)) {
                emittedBuilder = emittedBuilder.setOutput(stageData.window(stageSpec, (Windower) pluginContext.newPluginInstance(stageName, macroEvaluator)));
            } else {
                if (!AlertPublisher.PLUGIN_TYPE.equals(pluginType)) {
                    throw new IllegalStateException(String.format("Stage %s is of unsupported plugin type %s.", stageName, pluginType));
                }
                // Union the alerts (not the normal output) of every input stage and publish them.
                SparkCollection<Alert> inputAlerts = null;
                for (String inputStageName : stageInputs) {
                    SparkCollection<Alert> alertsFromInput = emittedRecords.get(inputStageName).alertRecords;
                    if (alertsFromInput != null) {
                        inputAlerts = inputAlerts == null ? alertsFromInput : inputAlerts.union(alertsFromInput);
                    }
                }
                if (inputAlerts != null) {
                    inputAlerts.publishAlerts(stageSpec, collector);
                }
            }
            emittedRecords.put(stageName, emittedBuilder.build());
        }
        runSinkTasks(sinkTasks);
    }

    /**
     * Runs all sink tasks concurrently, one thread per task, and waits for every one of them to finish.
     *
     * <p>If tasks fail, the first failure is rethrown once all tasks have finished; later failures are
     * attached to it as suppressed exceptions. If the waiting thread is interrupted, its interrupt flag is
     * restored, waiting stops and the remaining tasks are cancelled through {@code shutdownNow()}.
     *
     * @param sinkTasks the store tasks to execute; may be empty, in which case nothing happens
     */
    private void runSinkTasks(List<Runnable> sinkTasks) {
        if (sinkTasks.isEmpty()) {
            // A fixed thread pool cannot be created with zero threads, and there is nothing to run anyway.
            return;
        }
        ExecutorService executorService = Executors.newFixedThreadPool(sinkTasks.size(), new ThreadFactoryBuilder().setNameFormat("pipeline-sink-task").build());
        Throwable failure = null;
        try {
            List<Future<?>> sinkFutures = new ArrayList<Future<?>>(sinkTasks.size());
            for (Runnable sinkTask : sinkTasks) {
                sinkFutures.add(executorService.submit(sinkTask));
            }
            for (Future<?> sinkFuture : sinkFutures) {
                try {
                    sinkFuture.get();
                } catch (ExecutionException e) {
                    // Fall back to the wrapper itself so a failure without a cause is never lost.
                    Throwable cause = e.getCause() == null ? e : e.getCause();
                    if (failure == null) {
                        failure = cause;
                    } else if (cause != failure) {
                        failure.addSuppressed(cause);
                    }
                } catch (InterruptedException e) {
                    // Preserve the interrupt for the caller and stop waiting for the remaining tasks.
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        } finally {
            executorService.shutdownNow();
        }
        if (failure != null) {
            throw Throwables.propagate(failure);
        }
    }

    /**
     * Decides whether the collections emitted by a stage should be cached.
     *
     * <p>Caching is requested when the stage feeds more than one stage, or when it feeds a stage that has
     * more than one input (presumably so that the stage is not recomputed for each consumer or union).
     *
     * @param pipelinePhase phase that contains the stage and its connections
     * @param stageSpec specification of the stage whose emitted collections are considered
     * @return {@code true} if the emitted collections should be cached, {@code false} otherwise
     */
    private boolean shouldCache(PipelinePhase pipelinePhase, StageSpec stageSpec) {
        Set<String> stageOutputs = pipelinePhase.getStageOutputs(stageSpec.getName());
        if (stageOutputs.size() > 1) {
            return true;
        }
        for (String outputStageName : stageOutputs) {
            if (pipelinePhase.getStageInputs(outputStageName).size() > 1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Splits the records produced by a stage into errors, alerts and output, and registers each part on the
     * given builder.
     *
     * <p>The stage records are cached up front when they will be filtered more than once (errors or alerts
     * are requested, or the stage has more than one output port). Every derived collection is additionally
     * cached when {@code shouldCache} says so. For splitter transforms one collection is registered per
     * output port; for every other plugin type a single output collection is set.
     *
     * @param builder builder that receives the derived collections
     * @param pipelinePhase phase that contains the stage
     * @param stageSpec specification of the stage that produced the records
     * @param sparkCollection records produced by the stage
     * @param z whether error records should be extracted
     * @param z2 whether alerts should be extracted
     * @return the same builder instance that was passed in
     */
    private EmittedRecords.Builder addEmitted(EmittedRecords.Builder builder, PipelinePhase pipelinePhase, StageSpec stageSpec, SparkCollection<RecordInfo<Object>> sparkCollection, boolean z, boolean z2) {
        SparkCollection<RecordInfo<Object>> stageData = sparkCollection;
        boolean filteredMoreThanOnce = z || z2 || stageSpec.getOutputPorts().size() > 1;
        if (filteredMoreThanOnce) {
            stageData = stageData.cache();
        }
        final boolean cacheDerived = shouldCache(pipelinePhase, stageSpec);

        if (z) {
            SparkCollection<ErrorRecord<Object>> errors = stageData.flatMap(stageSpec, Compat.convert(new ErrorPassFilter()));
            builder.setErrors(cacheDerived ? errors.cache() : errors);
        }

        if (z2) {
            SparkCollection<Alert> alerts = stageData.flatMap(stageSpec, Compat.convert(new AlertPassFilter()));
            builder.setAlerts(cacheDerived ? alerts.cache() : alerts);
        }

        if (!SplitterTransform.PLUGIN_TYPE.equals(stageSpec.getPluginType())) {
            // Anything but a splitter has a single, port-less output.
            SparkCollection<Object> output = stageData.flatMap(stageSpec, Compat.convert(new OutputPassFilter()));
            builder.setOutput(cacheDerived ? output.cache() : output);
            return builder;
        }

        // Splitter: one filtered collection per output port.
        for (StageSpec.Port outputPort : stageSpec.getOutputPorts().values()) {
            String port = outputPort.getPort();
            SparkCollection<Object> portRecords = stageData.flatMap(stageSpec, Compat.convert(new OutputPassFilter(port)));
            builder.addPort(port, cacheDerived ? portRecords.cache() : portRecords);
        }
        return builder;
    }
}
