package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOuputCoder;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Dataset;
import scala.Tuple2;

/* loaded from: input_file:org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.class */
class ParDoTranslatorBatch<InputT, OutputT> implements TransformTranslator<PTransform<PCollection<InputT>, PCollectionTuple>> {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch$DoFnFilterFunction.class */
    public static class DoFnFilterFunction implements FilterFunction<Tuple2<TupleTag<?>, WindowedValue<?>>> {
        private final TupleTag<?> key;

        DoFnFilterFunction(TupleTag<?> tupleTag) {
            this.key = tupleTag;
        }

        public boolean call(Tuple2<TupleTag<?>, WindowedValue<?>> tuple2) {
            return ((TupleTag) tuple2._1).equals(this.key);
        }
    }

    /**
     * Translates the current ParDo transform into Spark dataset operations.
     *
     * <p>The DoFn is first checked: splittable DoFns, stateful DoFns and DoFns requiring
     * time-sorted input are rejected with an {@link IllegalStateException}. The input dataset is
     * then run through a {@code DoFnFunction} with {@code mapPartitions}, yielding tuples of
     * (output tag, windowed value). With at most one output the tag is simply stripped; with
     * several outputs the tagged dataset is persisted and one filtered dataset is registered per
     * output.
     *
     * @param pTransform the transform being translated (not read directly; everything is taken
     *     from {@code translationContext})
     * @param translationContext source of the current transform, its input/outputs and datasets,
     *     and sink for the resulting output datasets
     */
    @Override // org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator
    public void translateTransform(PTransform<PCollection<InputT>, PCollectionTuple> pTransform, TranslationContext translationContext) {
        String stepName = translationContext.getCurrentTransform().getFullName();
        DoFn<InputT, OutputT> doFn = getDoFn(translationContext);

        // Fail fast on DoFn features this translator does not handle.
        Preconditions.checkState(!DoFnSignatures.isSplittable(doFn), "Not expected to directly translate splittable DoFn, should have been overridden: %s", doFn);
        Preconditions.checkState(!DoFnSignatures.isStateful(doFn), "States and timers are not supported for the moment.");
        Preconditions.checkState(!DoFnSignatures.requiresTimeSortedInput(doFn), "@RequiresTimeSortedInput is not supported for the moment");
        DoFnSchemaInformation schemaInfo = ParDoTranslation.getSchemaInformation(translationContext.getCurrentTransform());

        // Input side: dataset, windowing and coders.
        PCollection input = translationContext.getInput();
        Dataset inputDataset = translationContext.getDataset(input);
        Map<TupleTag<?>, PCollection<?>> outputs = translationContext.getOutputs();
        TupleTag<?> mainOutputTag = getTupleTag(translationContext);
        ArrayList<TupleTag> allOutputTags = new ArrayList(outputs.keySet());
        WindowingStrategy windowingStrategy = input.getWindowingStrategy();
        Coder inputCoder = input.getCoder();
        Coder<? extends BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();

        // Side inputs: remember each view's windowing strategy and broadcast its contents.
        List<PCollectionView<?>> sideInputs = getSideInputs(translationContext);
        HashMap sideInputStrategies = new HashMap();
        for (PCollectionView<?> view : sideInputs) {
            sideInputStrategies.put(view, view.getPCollection().getWindowingStrategy());
        }
        SideInputBroadcast broadcastSideInputs = createBroadcastSideInputs(sideInputs, translationContext);
        Map<TupleTag<?>, Coder<?>> outputCoders = translationContext.getOutputCoders();
        MetricsContainerStepMapAccumulator metricsAccumulator = MetricsAccumulator.getInstance();

        // Every output tag except the main one is an additional output.
        ArrayList additionalOutputTags = new ArrayList();
        for (TupleTag tag : allOutputTags) {
            if (tag.equals(mainOutputTag)) {
                continue;
            }
            additionalOutputTags.add(tag);
        }

        DoFnFunction doFnFunction = new DoFnFunction(metricsAccumulator, stepName, doFn, windowingStrategy, sideInputStrategies, translationContext.getSerializableOptions(), additionalOutputTags, mainOutputTag, inputCoder, outputCoders, broadcastSideInputs, schemaInfo, ParDoTranslation.getSideInputMapping(translationContext.getCurrentTransform()));
        Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> taggedOutputs = inputDataset.mapPartitions(doFnFunction, EncoderHelpers.fromBeamCoder(MultiOuputCoder.of(SerializableCoder.of(TupleTag.class), outputCoders, windowCoder)));

        if (outputs.size() > 1) {
            // Several outputs read the same tagged dataset, so persist it before filtering per tag.
            taggedOutputs.persist();
            for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
                pruneOutputFilteredByTag(translationContext, taggedOutputs, output, windowCoder);
            }
            return;
        }

        // At most one output: no filtering needed, just drop the tag from each tuple.
        PValue onlyOutput = (PValue) outputs.entrySet().iterator().next().getValue();
        translationContext.putDatasetWildcard(onlyOutput, taggedOutputs.map(tuple -> (WindowedValue) tuple._2, EncoderHelpers.fromBeamCoder(WindowedValue.getFullCoder(outputs.get(mainOutputTag).getCoder(), windowCoder))));
    }

    /**
     * Builds the broadcast holder for all side inputs of the transform.
     *
     * <p>For each view, the side-input dataset is collected to a list, every windowed value is
     * encoded to a byte array with the view's full windowed-value coder, and the resulting list is
     * broadcast through the Spark context. Each broadcast is registered under the id of the view's
     * internal tag together with the coder used to encode it.
     *
     * @param list the side-input views to broadcast
     * @param translationContext provides the Spark session and the side-input datasets
     * @return the holder containing one broadcast per view
     */
    private static SideInputBroadcast createBroadcastSideInputs(List<PCollectionView<?>> list, TranslationContext translationContext) {
        JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(translationContext.getSparkSession().sparkContext());
        SideInputBroadcast broadcasts = new SideInputBroadcast();
        for (PCollectionView<?> view : list) {
            WindowedValue.FullWindowedValueCoder windowedCoder = WindowedValue.getFullCoder(view.getPCollection().getCoder(), view.getPCollection().getWindowingStrategy().getWindowFn().windowCoder());
            List rows = translationContext.getSideInputDataSet(view).collectAsList();
            ArrayList encodedRows = new ArrayList();
            for (Object row : rows) {
                encodedRows.add(CoderHelpers.toByteArray((WindowedValue) row, windowedCoder));
            }
            broadcasts.add(view.getTagInternal().getId(), sparkContext.broadcast(encodedRows), windowedCoder);
        }
        return broadcasts;
    }

    /**
     * Reads the side-input views of the transform currently being translated.
     *
     * @param translationContext supplies the current transform
     * @return the side inputs reported by {@code ParDoTranslation.getSideInputs}
     * @throws RuntimeException wrapping the {@link IOException} if the lookup fails
     */
    private List<PCollectionView<?>> getSideInputs(TranslationContext translationContext) {
        try {
            return ParDoTranslation.getSideInputs(translationContext.getCurrentTransform());
        } catch (IOException ioException) {
            // Re-throw unchecked, keeping the original failure as the cause.
            throw new RuntimeException(ioException);
        }
    }

    /**
     * Reads the main output tag of the transform currently being translated.
     *
     * @param translationContext supplies the current transform
     * @return the tag reported by {@code ParDoTranslation.getMainOutputTag}
     * @throws RuntimeException wrapping the {@link IOException} if the lookup fails
     */
    private TupleTag<?> getTupleTag(TranslationContext translationContext) {
        try {
            return ParDoTranslation.getMainOutputTag(translationContext.getCurrentTransform());
        } catch (IOException ioException) {
            // Re-throw unchecked, keeping the original failure as the cause.
            throw new RuntimeException(ioException);
        }
    }

    /**
     * Reads the {@link DoFn} of the transform currently being translated.
     *
     * @param translationContext supplies the current transform
     * @return the DoFn reported by {@code ParDoTranslation.getDoFn}
     * @throws RuntimeException wrapping the {@link IOException} if the lookup fails
     */
    private DoFn<InputT, OutputT> getDoFn(TranslationContext translationContext) {
        try {
            return ParDoTranslation.getDoFn(translationContext.getCurrentTransform());
        } catch (IOException ioException) {
            // Re-throw unchecked, keeping the original failure as the cause.
            throw new RuntimeException(ioException);
        }
    }

    /**
     * Registers the dataset for a single tagged output.
     *
     * <p>The tagged dataset is filtered down to the tuples carrying the entry's tag, the tag is
     * dropped from each remaining tuple, and the resulting dataset of windowed values is stored in
     * the context under the entry's output collection.
     *
     * @param translationContext receives the resulting dataset
     * @param dataset all outputs of the ParDo as (tag, windowed value) tuples
     * @param entry the output tag and the output collection it feeds
     * @param coder window coder used to build the windowed-value coder of the output
     */
    private void pruneOutputFilteredByTag(TranslationContext translationContext, Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> dataset, Map.Entry<TupleTag<?>, PCollection<?>> entry, Coder<? extends BoundedWindow> coder) {
        PValue output = (PValue) entry.getValue();
        DoFnFilterFunction onlyThisTag = new DoFnFilterFunction(entry.getKey());
        Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> taggedForOutput = dataset.filter(onlyThisTag);
        translationContext.putDatasetWildcard(output, taggedForOutput.map(tuple -> (WindowedValue) tuple._2, EncoderHelpers.fromBeamCoder(WindowedValue.getFullCoder(entry.getValue().getCoder(), coder))));
    }

    /**
     * Lambda deserialization hook, shown here in decompiled form.
     *
     * <p>Given the serialized form of a lambda, this matches its implementation method name against
     * the two lambdas of this class (one in {@code pruneOutputFilteredByTag}, one in
     * {@code translateTransform}), verifies the recorded functional interface and signatures, and
     * returns a fresh lambda that extracts the second component of a tuple.
     *
     * <p>NOTE(review): the method is marked synthetic and its shape (hash switch followed by a
     * selector switch, {@code boolean z = -1}) looks like decompiler output of compiler-generated
     * code rather than hand-written source — confirm before editing or compiling it as-is.
     *
     * @param serializedLambda the serialized description of the lambda to recreate
     * @return the recreated lambda
     * @throws IllegalArgumentException if the description matches neither known lambda
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // Selector for the second switch; stays at its initial value when no name matches.
        boolean z = -1;
        // First stage: dispatch on the name's hash, then confirm with equals().
        switch (implMethodName.hashCode()) {
            case -844435198:
                if (implMethodName.equals("lambda$pruneOutputFilteredByTag$b05e36e3$1")) {
                    z = false;
                    break;
                }
                break;
            case 1074022973:
                if (implMethodName.equals("lambda$translateTransform$ba2f078e$1")) {
                    z = true;
                    break;
                }
                break;
        }
        // Second stage: validate the rest of the serialized description before recreating the lambda.
        switch (z) {
            case false:
                // Lambda from pruneOutputFilteredByTag: a static impl method (kind 6) implementing MapFunction.call.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch") && serializedLambda.getImplMethodSignature().equals("(Lscala/Tuple2;)Lorg/apache/beam/sdk/util/WindowedValue;")) {
                    return tuple2 -> {
                        return (WindowedValue) tuple2._2;
                    };
                }
                break;
            case true:
                // Lambda from translateTransform: same checks and same body as the case above.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch") && serializedLambda.getImplMethodSignature().equals("(Lscala/Tuple2;)Lorg/apache/beam/sdk/util/WindowedValue;")) {
                    return tuple22 -> {
                        return (WindowedValue) tuple22._2;
                    };
                }
                break;
        }
        // Nothing matched: refuse to deserialize an unknown lambda.
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
