package co.cask.cdap.etl.spark.batch;

import co.cask.cdap.api.data.DatasetContext;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.etl.api.AlertPublisher;
import co.cask.cdap.etl.api.AlertPublisherContext;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.api.streaming.Windower;
import co.cask.cdap.etl.common.Constants;
import co.cask.cdap.etl.common.DefaultAlertPublisherContext;
import co.cask.cdap.etl.common.DefaultStageMetrics;
import co.cask.cdap.etl.common.RecordInfo;
import co.cask.cdap.etl.common.StageStatisticsCollector;
import co.cask.cdap.etl.common.TrackedIterator;
import co.cask.cdap.etl.spark.Compat;
import co.cask.cdap.etl.spark.SparkCollection;
import co.cask.cdap.etl.spark.SparkPairCollection;
import co.cask.cdap.etl.spark.SparkPipelineRuntime;
import co.cask.cdap.etl.spark.function.AggregatorAggregateFunction;
import co.cask.cdap.etl.spark.function.AggregatorGroupByFunction;
import co.cask.cdap.etl.spark.function.CountingFunction;
import co.cask.cdap.etl.spark.function.MultiOutputTransformFunction;
import co.cask.cdap.etl.spark.function.PluginFunctionContext;
import co.cask.cdap.etl.spark.function.TransformFunction;
import co.cask.cdap.etl.spec.StageSpec;
import com.google.gson.Gson;
import javax.annotation.Nullable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;

/* loaded from: input_file:lib/hydrator-spark-core-4.3.3.jar:co/cask/cdap/etl/spark/batch/RDDCollection.class */
/**
 * {@link SparkCollection} implementation backed by a Spark {@link JavaRDD}.
 *
 * <p>Every operation that yields a new collection wraps the resulting RDD in a new {@code RDDCollection}. That new
 * collection shares this instance's execution context, Spark context, dataset context and sink factory (see
 * {@link #wrap(JavaRDD)}). Windowing is not supported and always throws.
 *
 * @param <T> the element type of the underlying RDD
 */
public class RDDCollection<T> implements SparkCollection<T> {
    // NOTE(review): not referenced anywhere in this class; retained to avoid leaving the Gson import unused.
    private static final Gson GSON = new Gson();
    private final JavaSparkExecutionContext sec;
    private final JavaSparkContext jsc;
    private final DatasetContext datasetContext;
    private final SparkBatchSinkFactory sinkFactory;
    private final JavaRDD<T> rdd;

    /**
     * Creates a collection around an existing RDD.
     *
     * @param javaSparkExecutionContext CDAP Spark execution context (metrics, data tracer, messaging, admin)
     * @param javaSparkContext the Spark context passed to plugin contexts created by this collection
     * @param datasetContext dataset context passed to plugin contexts created by this collection
     * @param sparkBatchSinkFactory factory used by {@link #store(StageSpec, PairFlatMapFunction)} to write output
     * @param javaRDD the RDD this collection wraps
     */
    public RDDCollection(JavaSparkExecutionContext javaSparkExecutionContext, JavaSparkContext javaSparkContext,
                         DatasetContext datasetContext, SparkBatchSinkFactory sparkBatchSinkFactory,
                         JavaRDD<T> javaRDD) {
        this.sec = javaSparkExecutionContext;
        this.jsc = javaSparkContext;
        this.datasetContext = datasetContext;
        this.sinkFactory = sparkBatchSinkFactory;
        this.rdd = javaRDD;
    }

    /** Returns the wrapped {@link JavaRDD}. */
    @Override
    public JavaRDD<T> getUnderlying() {
        return this.rdd;
    }

    /** Marks the underlying RDD for caching ({@link JavaRDD#cache()}) and returns a collection wrapping it. */
    @Override
    public SparkCollection<T> cache() {
        return wrap(this.rdd.cache());
    }

    /**
     * Returns the union of this collection and {@code sparkCollection}.
     *
     * <p>The other collection's underlying object is cast to a {@link JavaRDD}. Passing a collection backed by
     * anything else fails with a {@link ClassCastException}.
     */
    @Override
    public SparkCollection<T> union(SparkCollection<T> sparkCollection) {
        JavaRDD<T> other = (JavaRDD<T>) sparkCollection.getUnderlying();
        return wrap(this.rdd.union(other));
    }

    /** Flat-maps every record through a {@link TransformFunction} created for the given stage. */
    @Override
    public SparkCollection<RecordInfo<Object>> transform(StageSpec stageSpec,
                                                         StageStatisticsCollector stageStatisticsCollector) {
        PluginFunctionContext functionContext = new PluginFunctionContext(stageSpec, this.sec, stageStatisticsCollector);
        return wrap(this.rdd.flatMap(Compat.convert(new TransformFunction(functionContext))));
    }

    /** Flat-maps every record through a {@link MultiOutputTransformFunction} created for the given stage. */
    @Override
    public SparkCollection<RecordInfo<Object>> multiOutputTransform(StageSpec stageSpec,
                                                                    StageStatisticsCollector stageStatisticsCollector) {
        PluginFunctionContext functionContext = new PluginFunctionContext(stageSpec, this.sec, stageStatisticsCollector);
        return wrap(this.rdd.flatMap(Compat.convert(new MultiOutputTransformFunction(functionContext))));
    }

    /** Applies {@code flatMapFunction} to the underlying RDD; {@code stageSpec} is not used by this implementation. */
    @Override
    public <U> SparkCollection<U> flatMap(StageSpec stageSpec, FlatMapFunction<T, U> flatMapFunction) {
        return wrap(this.rdd.flatMap(flatMapFunction));
    }

    /**
     * Runs an aggregator stage as group-by-key followed by a per-group aggregate.
     *
     * <p>Records are keyed with an {@link AggregatorGroupByFunction} and grouped with {@code groupByKey}. That call
     * uses {@code num} partitions when it is non-null and Spark's default partitioning otherwise. Each group is then
     * flat-mapped through an {@link AggregatorAggregateFunction}. Both functions share one
     * {@link PluginFunctionContext}.
     *
     * @param stageSpec the aggregator stage
     * @param num number of partitions for the grouping, or {@code null} for Spark's default
     * @param stageStatisticsCollector statistics collector handed to the plugin function context
     */
    @Override
    public SparkCollection<RecordInfo<Object>> aggregate(StageSpec stageSpec, @Nullable Integer num,
                                                         StageStatisticsCollector stageStatisticsCollector) {
        PluginFunctionContext pluginFunctionContext =
            new PluginFunctionContext(stageSpec, this.sec, stageStatisticsCollector);
        // NOTE(review): raw types retained from the original; parameterize once the generic signatures of
        // Compat.convert and the Aggregator*Function classes are confirmed.
        JavaPairRDD keyed = this.rdd.flatMapToPair(Compat.convert(new AggregatorGroupByFunction(pluginFunctionContext)));
        JavaPairRDD grouped = num == null ? keyed.groupByKey() : keyed.groupByKey(num.intValue());
        return wrap(grouped.flatMap(Compat.convert(new AggregatorAggregateFunction(pluginFunctionContext))));
    }

    /** Converts this collection to a pair collection by applying {@code pairFlatMapFunction} to every record. */
    @Override
    public <K, V> SparkPairCollection<K, V> flatMapToPair(PairFlatMapFunction<T, K, V> pairFlatMapFunction) {
        return new PairRDDCollection(this.sec, this.jsc, this.datasetContext, this.sinkFactory,
                                     this.rdd.flatMapToPair(pairFlatMapFunction));
    }

    /**
     * Runs a {@link SparkCompute} stage over this collection.
     *
     * <p>The plugin is initialized with a new {@link BasicSparkExecutionPluginContext} before it runs.
     * <ul>
     *   <li>Its input is this RDD mapped through a {@link CountingFunction} tagged {@code RECORDS_IN} and then cached.
     *   <li>Its output is mapped through a {@link CountingFunction} tagged {@code RECORDS_OUT}, which is given the
     *       stage's data tracer.
     * </ul>
     *
     * @throws Exception if plugin initialization or the transform fails
     */
    @Override
    public <U> SparkCollection<U> compute(StageSpec stageSpec, SparkCompute<T, U> sparkCompute) throws Exception {
        String name = stageSpec.getName();
        BasicSparkExecutionPluginContext pluginContext = new BasicSparkExecutionPluginContext(
            this.sec, this.jsc, this.datasetContext, new SparkPipelineRuntime(this.sec), stageSpec);
        sparkCompute.initialize(pluginContext);

        JavaRDD<T> countedInput =
            this.rdd.map(new CountingFunction(name, this.sec.getMetrics(), Constants.Metrics.RECORDS_IN, null)).cache();
        JavaRDD<U> output = sparkCompute.transform(pluginContext, countedInput);
        return wrap(output.map(new CountingFunction(name, this.sec.getMetrics(), Constants.Metrics.RECORDS_OUT,
                                                    this.sec.getDataTracer(name))));
    }

    /**
     * Writes this collection through the sink factory.
     *
     * <p>Each record is first converted to key/value pairs with {@code pairFlatMapFunction}. The pairs are written
     * under the stage name, with {@code Object} as both the key and the value class.
     */
    @Override
    public void store(StageSpec stageSpec, PairFlatMapFunction<T, Object, Object> pairFlatMapFunction) {
        this.sinkFactory.writeFromRDD(this.rdd.flatMapToPair(pairFlatMapFunction), this.sec, stageSpec.getName(),
                                      Object.class, Object.class);
    }

    /**
     * Runs a {@link SparkSink} stage over this collection.
     *
     * <p>The sink receives this RDD mapped through a {@link CountingFunction} tagged {@code RECORDS_IN} and then
     * cached.
     *
     * @throws Exception if the sink fails
     */
    @Override
    public void store(StageSpec stageSpec, SparkSink<T> sparkSink) throws Exception {
        String name = stageSpec.getName();
        BasicSparkExecutionPluginContext pluginContext = new BasicSparkExecutionPluginContext(
            this.sec, this.jsc, this.datasetContext, new SparkPipelineRuntime(this.sec), stageSpec);
        JavaRDD<T> countedInput =
            this.rdd.map(new CountingFunction(name, this.sec.getMetrics(), Constants.Metrics.RECORDS_IN, null)).cache();
        sparkSink.run(pluginContext, countedInput);
    }

    /**
     * Publishes every element of this collection through the stage's {@link AlertPublisher}.
     *
     * <p>All elements are brought to the driver with {@link JavaRDD#collect()}. They are then handed to the publisher
     * as a {@link TrackedIterator} that reports {@code RECORDS_IN} metrics for the stage.
     *
     * <p>The publisher is always destroyed once it has been initialized, even if collecting or publishing fails.
     * Without this, a failure would skip {@code destroy()} and leak whatever {@code initialize} acquired.
     *
     * @throws Exception if the plugin cannot be created or initialized, or if publishing fails
     */
    @Override
    public void publishAlerts(StageSpec stageSpec, StageStatisticsCollector stageStatisticsCollector) throws Exception {
        AlertPublisher alertPublisher =
            (AlertPublisher) new PluginFunctionContext(stageSpec, this.sec, stageStatisticsCollector).createPlugin();
        AlertPublisherContext publisherContext = new DefaultAlertPublisherContext(
            new SparkPipelineRuntime(this.sec), stageSpec, this.sec.getMessagingContext(), this.sec.getAdmin());
        alertPublisher.initialize(publisherContext);
        try {
            DefaultStageMetrics stageMetrics = new DefaultStageMetrics(this.sec.getMetrics(), stageSpec.getName());
            alertPublisher.publish(
                new TrackedIterator(this.rdd.collect().iterator(), stageMetrics, Constants.Metrics.RECORDS_IN));
        } finally {
            alertPublisher.destroy();
        }
    }

    /**
     * Windowing is not available for RDD-backed collections.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SparkCollection<T> window(StageSpec stageSpec, Windower windower) {
        throw new UnsupportedOperationException("Windowing is not supported on RDDs.");
    }

    /** Wraps {@code javaRDD} in a new collection that shares this instance's contexts and sink factory. */
    private <U> RDDCollection<U> wrap(JavaRDD<U> javaRDD) {
        return new RDDCollection<>(this.sec, this.jsc, this.datasetContext, this.sinkFactory, javaRDD);
    }
}
