package co.cask.cdap.etl.batch.spark;

import co.cask.cdap.api.TxRunnable;
import co.cask.cdap.api.data.DatasetContext;
import co.cask.cdap.api.dataset.lib.KeyValue;
import co.cask.cdap.api.metrics.Metrics;
import co.cask.cdap.api.plugin.PluginContext;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.api.spark.JavaSparkMain;
import co.cask.cdap.etl.api.Transform;
import co.cask.cdap.etl.api.batch.BatchAggregator;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.batch.BatchPhaseSpec;
import co.cask.cdap.etl.batch.PipelinePluginInstantiator;
import co.cask.cdap.etl.common.Constants;
import co.cask.cdap.etl.common.PipelinePhase;
import co.cask.cdap.etl.common.SetMultimapCodec;
import co.cask.cdap.etl.common.TransformExecutor;
import co.cask.cdap.etl.common.TransformResponse;
import co.cask.cdap.etl.planner.StageInfo;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.SetMultimap;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

/* loaded from: input_file:lib/cdap-etl-batch-3.4.3.jar:co/cask/cdap/etl/batch/spark/ETLSparkProgram.class */
/**
 * Spark program that executes a single batch phase of an ETL pipeline.
 *
 * <p>The phase specification is read as JSON from the program specification property
 * {@link Constants#PIPELINEID}. The source factory, the sink factory and the number of partitions used for the
 * aggregation shuffle are read, in that order, from the localized file {@code ETLSpark.config}. Records read from the
 * source are pushed through the phase's stages (optionally split around a {@link BatchAggregator} or a single
 * {@link SparkCompute} stage), tagged with the name of the sink they are destined for, and written to each sink.
 */
public class ETLSparkProgram implements JavaSparkMain, TxRunnable {
    private static final Gson GSON = new GsonBuilder().registerTypeAdapter(SetMultimap.class, new SetMultimapCodec()).create();

    // Driver-side state set in run(JavaSparkExecutionContext); transient so it is skipped if this instance is
    // ever serialized.
    private transient JavaSparkContext jsc;
    private transient JavaSparkExecutionContext sec;

    /**
     * Runs the stages of a phase on each input pair and emits {@code (sinkName, record)} tuples.
     *
     * <p>When an aggregator name is given, only the part of the phase starting at that aggregator
     * ({@link PipelinePhase#subsetFrom}) is executed. In this class that is how the output of the group-by
     * shuffle is processed.
     *
     * @param <T> type of the value in each input pair
     */
    public static final class MapFunction<T> extends SingleTypeRDDMapFunction<Tuple2<Object, T>, KeyValue<Object, T>> {

        @Nullable
        private final String aggregatorName;
        private final boolean isBeforeBreak;

        /**
         * @param executionContext context supplying plugins, metrics, logical start time and runtime arguments
         * @param pipelineStr JSON-serialized {@link BatchPhaseSpec}, or {@code null} to use the one stored in the
         *                    program specification
         * @param aggregatorName aggregator to start executing from, or {@code null} to execute the whole phase
         * @param isBeforeBreak flag forwarded to {@link SparkTransformExecutorFactory}
         */
        public MapFunction(JavaSparkExecutionContext executionContext, @Nullable String pipelineStr,
                           @Nullable String aggregatorName, boolean isBeforeBreak) {
            super(executionContext, pipelineStr);
            this.aggregatorName = aggregatorName;
            this.isBeforeBreak = isBeforeBreak;
        }

        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        protected TransformExecutor<KeyValue<Object, T>> initialize(BatchPhaseSpec batchPhaseSpec,
                                                                   PipelinePluginInstantiator pluginInstantiator)
            throws Exception {
            SparkTransformExecutorFactory executorFactory = new SparkTransformExecutorFactory(
                pluginContext, pluginInstantiator, metrics, logicalStartTime, runtimeArgs, isBeforeBreak);
            PipelinePhase phase = batchPhaseSpec.getPhase();
            if (aggregatorName != null) {
                phase = phase.subsetFrom(ImmutableSet.of(aggregatorName));
            }
            return executorFactory.create(phase);
        }

        /** Repackages a Spark {@code (key, value)} tuple as a {@link KeyValue} for the executor. */
        @Override
        public KeyValue<Object, T> computeInputForExecutor(Tuple2<Object, T> input) {
            return new KeyValue<>(input._1(), input._2());
        }
    }

    /**
     * Runs the stages of a phase that lead up to an aggregator ({@link PipelinePhase#subsetTo}) and emits every
     * record they produce as a {@code (key, value)} pair, ready to be grouped by key.
     */
    public static final class PreGroupFunction extends TransformExecutorFunction<Tuple2<Object, Object>, KeyValue<Object, Object>, Object, Object> {
        private final String aggregatorName;

        /**
         * @param executionContext context supplying plugins, metrics, runtime arguments and the phase spec
         * @param aggregatorName name of the aggregator stage; must not be {@code null}, since it is passed to
         *                       {@code ImmutableSet.of} when the executor is created
         */
        public PreGroupFunction(JavaSparkExecutionContext executionContext, String aggregatorName) {
            super(executionContext, null);
            this.aggregatorName = aggregatorName;
        }

        /**
         * Flattens the results of all sinks into one list. The executor (created with {@code isBeforeBreak = true})
         * is expected to emit records that are already {@link Tuple2} instances; anything else fails the cast.
         */
        @Override
        @SuppressWarnings("unchecked")
        protected Iterable<Tuple2<Object, Object>> getOutput(TransformResponse transformResponse) {
            List<Tuple2<Object, Object>> output = new ArrayList<>();
            for (Collection<Object> sinkRecords : transformResponse.getSinksResults().values()) {
                for (Object record : sinkRecords) {
                    output.add((Tuple2<Object, Object>) record);
                }
            }
            return output;
        }

        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        protected TransformExecutor<KeyValue<Object, Object>> initialize(BatchPhaseSpec batchPhaseSpec,
                                                                        PipelinePluginInstantiator pluginInstantiator)
            throws Exception {
            SparkTransformExecutorFactory executorFactory = new SparkTransformExecutorFactory(
                pluginContext, pluginInstantiator, metrics, logicalStartTime, runtimeArgs, true);
            return executorFactory.create(batchPhaseSpec.getPhase().subsetTo(ImmutableSet.of(aggregatorName)));
        }

        /** Repackages a Spark {@code (key, value)} tuple as a {@link KeyValue} for the executor. */
        @Override
        public KeyValue<Object, Object> computeInputForExecutor(Tuple2<Object, Object> input) {
            return new KeyValue<>(input._1(), input._2());
        }
    }

    /**
     * Runs a whole phase on each input record and emits one {@code (sinkName, record)} tuple for every record the
     * executor reports for a sink.
     *
     * @param <IN> type of the records in the input RDD
     * @param <EXECUTOR_IN> type of the records handed to the {@link TransformExecutor}; the default
     *                      {@link #computeInputForExecutor} passes the input through unchanged, so subclasses whose
     *                      two types differ must override it
     */
    public static class SingleTypeRDDMapFunction<IN, EXECUTOR_IN> extends TransformExecutorFunction<IN, EXECUTOR_IN, String, Object> {

        /**
         * @param executionContext context supplying plugins, metrics, runtime arguments and the phase spec
         * @param pipelineStr JSON-serialized {@link BatchPhaseSpec}, or {@code null} to use the one stored in the
         *                    program specification
         */
        public SingleTypeRDDMapFunction(JavaSparkExecutionContext executionContext, @Nullable String pipelineStr) {
            super(executionContext, pipelineStr);
        }

        /** Tags every record emitted for a sink with that sink's name. */
        @Override
        protected Iterable<Tuple2<String, Object>> getOutput(TransformResponse transformResponse) {
            List<Tuple2<String, Object>> output = new ArrayList<>();
            for (Map.Entry<String, Collection<Object>> sinkEntry : transformResponse.getSinksResults().entrySet()) {
                String sinkName = sinkEntry.getKey();
                for (Object record : sinkEntry.getValue()) {
                    output.add(new Tuple2<String, Object>(sinkName, record));
                }
            }
            return output;
        }

        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        protected TransformExecutor<EXECUTOR_IN> initialize(BatchPhaseSpec batchPhaseSpec,
                                                           PipelinePluginInstantiator pluginInstantiator)
            throws Exception {
            SparkTransformExecutorFactory executorFactory = new SparkTransformExecutorFactory(
                pluginContext, pluginInstantiator, metrics, logicalStartTime, runtimeArgs, false);
            return (TransformExecutor<EXECUTOR_IN>) executorFactory.create(batchPhaseSpec.getPhase());
        }

        /**
         * Passes the input record through unchanged. The cast is unchecked: it is only valid when {@code IN} and
         * {@code EXECUTOR_IN} are the same type at runtime.
         */
        @Override
        @SuppressWarnings("unchecked")
        protected EXECUTOR_IN computeInputForExecutor(IN in) {
            return (EXECUTOR_IN) in;
        }
    }

    /**
     * Base class for Spark functions that push each input record through a {@link TransformExecutor}.
     *
     * <p>The executor is held in a transient field, so it is not shipped with the function. It is created lazily
     * from the JSON phase specification on the first {@link #call} of each deserialized function instance.
     *
     * @param <IN> type of the records in the input RDD
     * @param <EXECUTOR_IN> type of the records handed to the executor
     * @param <KEY_OUT> key type of the emitted pairs
     * @param <VAL_OUT> value type of the emitted pairs
     */
    public static abstract class TransformExecutorFunction<IN, EXECUTOR_IN, KEY_OUT, VAL_OUT> implements PairFlatMapFunction<IN, KEY_OUT, VAL_OUT> {
        protected final PluginContext pluginContext;
        protected final Metrics metrics;
        protected final long logicalStartTime;
        protected final Map<String, String> runtimeArgs;
        protected final String pipelineStr;
        private transient TransformExecutor<EXECUTOR_IN> transformExecutor;

        /**
         * @param executionContext context supplying plugins, metrics, logical start time and runtime arguments
         * @param pipelineStr JSON-serialized {@link BatchPhaseSpec}, or {@code null} to use the value of the
         *                    {@link Constants#PIPELINEID} property of the program specification
         */
        public TransformExecutorFunction(JavaSparkExecutionContext executionContext, @Nullable String pipelineStr) {
            this.pluginContext = executionContext.getPluginContext();
            this.metrics = executionContext.getMetrics();
            this.logicalStartTime = executionContext.getLogicalStartTime();
            this.runtimeArgs = executionContext.getRuntimeArguments();
            this.pipelineStr = pipelineStr != null
                ? pipelineStr
                : executionContext.getSpecification().getProperty(Constants.PIPELINEID);
        }

        /** Runs one record through the executor, creating the executor on first use. */
        @Override
        public Iterable<Tuple2<KEY_OUT, VAL_OUT>> call(IN in) throws Exception {
            if (transformExecutor == null) {
                BatchPhaseSpec phaseSpec = GSON.fromJson(pipelineStr, BatchPhaseSpec.class);
                transformExecutor = initialize(phaseSpec, new PipelinePluginInstantiator(pluginContext, phaseSpec));
            }
            TransformResponse response = transformExecutor.runOneIteration(computeInputForExecutor(in));
            // getOutput copies the results into a new collection; keep it before resetEmitter() in case the
            // response is backed by emitter state (NOTE(review): not verifiable from this file).
            Iterable<Tuple2<KEY_OUT, VAL_OUT>> output = getOutput(response);
            transformExecutor.resetEmitter();
            return output;
        }

        /** Converts the executor's per-record response into the pairs emitted to Spark. */
        protected abstract Iterable<Tuple2<KEY_OUT, VAL_OUT>> getOutput(TransformResponse transformResponse);

        /** Creates the executor for the (part of the) phase this function runs. */
        protected abstract TransformExecutor<EXECUTOR_IN> initialize(BatchPhaseSpec batchPhaseSpec, PipelinePluginInstantiator pipelinePluginInstantiator) throws Exception;

        /** Adapts an input RDD record to the executor's input type. */
        protected abstract EXECUTOR_IN computeInputForExecutor(IN in);
    }

    /** Keeps only the {@code (sinkName, record)} pairs destined for one sink. */
    private static final class SinkNameFilter implements Function<Tuple2<String, Object>, Boolean> {
        private final String sinkName;

        SinkNameFilter(String sinkName) {
            this.sinkName = sinkName;
        }

        @Override
        public Boolean call(Tuple2<String, Object> record) throws Exception {
            return record._1().equals(sinkName);
        }
    }

    /** Unwraps the {@link KeyValue} carried in a {@code (sinkName, KeyValue)} pair into a single Spark pair. */
    private static final class KeyValueToTuple implements PairFlatMapFunction<Tuple2<String, Object>, Object, Object> {
        @Override
        public Iterable<Tuple2<Object, Object>> call(Tuple2<String, Object> record) throws Exception {
            KeyValue<?, ?> keyValue = (KeyValue<?, ?>) record._2();
            return Collections.singletonList(new Tuple2<Object, Object>(keyValue.getKey(), keyValue.getValue()));
        }
    }

    /**
     * Program entry point. It creates the {@link JavaSparkContext}, stores the execution context and then calls
     * {@code sec.execute(this)}, which is expected to call back into {@link #run(DatasetContext)}.
     */
    @Override
    public void run(JavaSparkExecutionContext javaSparkExecutionContext) throws Exception {
        this.jsc = new JavaSparkContext();
        this.sec = javaSparkExecutionContext;
        javaSparkExecutionContext.execute(this);
    }

    /**
     * Reads the phase configuration, builds the transformed RDD and writes it to every sink of the phase.
     *
     * <p>Sinks that are {@link SparkSink} plugins receive the record values directly. Every other sink is written
     * through the {@link SparkBatchSinkFactory}; its records must be {@link KeyValue} instances and are converted
     * to Spark pairs first.
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void run(DatasetContext datasetContext) throws Exception {
        BatchPhaseSpec phaseSpec =
            GSON.fromJson(sec.getSpecification().getProperty(Constants.PIPELINEID), BatchPhaseSpec.class);
        Set<StageInfo> aggregators = phaseSpec.getPhase().getStagesOfType(BatchAggregator.PLUGIN_TYPE);
        // Only the first aggregator is used; presumably the planner puts at most one in a phase.
        String aggregatorName = aggregators.isEmpty() ? null : aggregators.iterator().next().getName();

        // The config file holds, in order: the source factory, the sink factory and the number of partitions.
        SparkBatchSourceFactory sourceFactory;
        SparkBatchSinkFactory sinkFactory;
        int numPartitions;
        try (FileInputStream configStream =
                 new FileInputStream(sec.getLocalizationContext().getLocalFile("ETLSpark.config"))) {
            sourceFactory = SparkBatchSourceFactory.deserialize(configStream);
            sinkFactory = SparkBatchSinkFactory.deserialize(configStream);
            numPartitions = new DataInputStream(configStream).readInt();
        }

        JavaPairRDD<String, Object> sinkRecords =
            doTransform(sec, jsc, datasetContext, phaseSpec,
                        sourceFactory.createRDD(sec, jsc, Object.class, Object.class), aggregatorName, numPartitions);

        Set<String> sparkSinkNames = new HashSet<>();
        for (StageInfo sparkSinkStage : phaseSpec.getPhase().getStagesOfType(SparkSink.PLUGIN_TYPE)) {
            sparkSinkNames.add(sparkSinkStage.getName());
        }

        for (String sinkName : phaseSpec.getPhase().getSinks()) {
            JavaPairRDD<String, Object> sinkData = sinkRecords.filter(new SinkNameFilter(sinkName));
            if (sparkSinkNames.contains(sinkName)) {
                SparkSink sparkSink = (SparkSink) sec.getPluginContext().newPluginInstance(sinkName);
                sparkSink.run(new BasicSparkExecutionPluginContext(sec, jsc, datasetContext, sinkName),
                              sinkData.values());
            } else {
                sinkFactory.writeFromRDD(sinkData.flatMapToPair(new KeyValueToTuple()), sec, sinkName,
                                         Object.class, Object.class);
            }
        }
    }

    /**
     * Builds the RDD of {@code (sinkName, record)} pairs for the phase.
     *
     * <ul>
     *   <li>No SparkCompute and no aggregator: the whole phase runs in one {@link MapFunction}.</li>
     *   <li>No SparkCompute but an aggregator: the stages before it run in a {@link PreGroupFunction}. The output
     *       is grouped by key, using {@code numPartitions} partitions unless it is negative, and the remaining
     *       stages run in a {@link MapFunction}.</li>
     *   <li>A SparkCompute: see {@link #transformWithSparkCompute}. {@code aggregatorName} and
     *       {@code numPartitions} are not used in this case.</li>
     * </ul>
     * The result is cached because {@link #run(DatasetContext)} filters it once per sink.
     */
    private JavaPairRDD<String, Object> doTransform(JavaSparkExecutionContext executionContext,
                                                    JavaSparkContext sparkContext, DatasetContext datasetContext,
                                                    BatchPhaseSpec phaseSpec, JavaPairRDD<Object, Object> input,
                                                    @Nullable String aggregatorName, int numPartitions)
        throws Exception {
        Set<StageInfo> sparkComputes = phaseSpec.getPhase().getStagesOfType(SparkCompute.PLUGIN_TYPE);
        if (!sparkComputes.isEmpty()) {
            return transformWithSparkCompute(executionContext, sparkContext, datasetContext, phaseSpec, input,
                                             sparkComputes);
        }
        if (aggregatorName == null) {
            return input.flatMapToPair(new MapFunction<Object>(executionContext, null, null, false)).cache();
        }
        JavaPairRDD<Object, Object> preGrouped =
            input.flatMapToPair(new PreGroupFunction(executionContext, aggregatorName));
        JavaPairRDD<Object, Iterable<Object>> grouped =
            numPartitions < 0 ? preGrouped.groupByKey() : preGrouped.groupByKey(numPartitions);
        return grouped
            .flatMapToPair(new MapFunction<Iterable<Object>>(executionContext, null, aggregatorName, false))
            .cache();
    }

    /**
     * Runs a phase whose single source feeds directly into its single {@link SparkCompute} stage.
     *
     * <p>The stages up to the SparkCompute ({@link PipelinePhase#subsetTo}) run in a {@link MapFunction}, and the
     * values of its output are handed to the SparkCompute plugin. The stages from the SparkCompute onwards
     * ({@link PipelinePhase#subsetFrom}) then run on the plugin's output in a {@link SingleTypeRDDMapFunction}.
     *
     * @throws IllegalArgumentException if the phase contains Transform stages or more than one SparkCompute, if it
     *                                  has more than one source, or if the source's only output is not the
     *                                  SparkCompute
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static JavaPairRDD<String, Object> transformWithSparkCompute(JavaSparkExecutionContext executionContext,
                                                                         JavaSparkContext sparkContext,
                                                                         DatasetContext datasetContext,
                                                                         BatchPhaseSpec phaseSpec,
                                                                         JavaPairRDD<Object, Object> input,
                                                                         Set<StageInfo> sparkComputes)
        throws Exception {
        PipelinePhase phase = phaseSpec.getPhase();
        Set<StageInfo> transforms = phase.getStagesOfType(Transform.PLUGIN_TYPE);
        Preconditions.checkArgument(transforms.isEmpty(),
                                    "Found non-empty set of transform plugins when expecting none: %s", transforms);
        Preconditions.checkArgument(sparkComputes.size() == 1, "Expected only 1 SparkCompute: %s", sparkComputes);
        String computeName = Iterables.getOnlyElement(sparkComputes).getName();
        Set<String> sources = phase.getSources();
        Preconditions.checkArgument(sources.size() == 1, "Expected only 1 source stage: %s", sources);
        Set<String> sourceOutputs = phase.getStageOutputs(Iterables.getOnlyElement(sources));
        Preconditions.checkArgument(sourceOutputs.size() == 1, "Expected only 1 stage after source stage: %s",
                                    sourceOutputs);
        Preconditions.checkArgument(computeName.equals(Iterables.getOnlyElement(sourceOutputs)),
                                    "Expected the single stage after the source stage to be the spark compute: %s",
                                    computeName);

        SparkCompute sparkCompute = (SparkCompute) new PipelinePluginInstantiator(
            executionContext.getPluginContext(), phaseSpec).newPluginInstance(computeName);
        BasicSparkExecutionPluginContext computeContext =
            new BasicSparkExecutionPluginContext(executionContext, sparkContext, datasetContext, computeName);

        // Stages leading up to the SparkCompute; only the record values (not the sink names) reach the plugin.
        BatchPhaseSpec beforeCompute = withPhase(phaseSpec, phase.subsetTo(ImmutableSet.of(computeName)));
        JavaPairRDD<String, Object> computeInput = input
            .flatMapToPair(new MapFunction<Object>(executionContext, GSON.toJson(beforeCompute), null, true))
            .cache();

        // Stages from the SparkCompute onwards, applied to the plugin's output.
        return sparkCompute.transform(computeContext, computeInput.values())
            .flatMapToPair(new SingleTypeRDDMapFunction<Object, Object>(
                executionContext,
                GSON.toJson(withPhase(phaseSpec, phase.subsetFrom(ImmutableSet.of(computeName))))))
            .cache();
    }

    /** Returns a new spec equal to {@code spec} except that its pipeline phase is {@code phase}. */
    private static BatchPhaseSpec withPhase(BatchPhaseSpec spec, PipelinePhase phase) {
        return new BatchPhaseSpec(spec.getPhaseName(), phase, spec.getResources(), spec.isStageLoggingEnabled(),
                                  spec.getConnectorDatasets());
    }
}
