package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.functions;
import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag;

/* loaded from: input_file:org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.class */
/**
 * Translates a {@link ParDo.MultiOutput} for batch pipelines of the Spark structured streaming
 * runner.
 *
 * <p>The {@link DoFn} is evaluated with a single {@code mapPartitions} pass over the input dataset.
 * The main output is always kept. Any additional output for which {@code Context#isLeave} returns
 * {@code true} (presumably outputs that are not consumed downstream — verify) is dropped first.
 *
 * <ul>
 *   <li>If only one output remains, the DoFn results become that output's dataset directly.
 *   <li>Otherwise every remaining output tag is assigned a column index and the DoFn emits {@code
 *       (columnIndex, value)} pairs. The combined result is persisted and then split into one
 *       dataset per tag:
 *       <ul>
 *         <li>With storage level {@code MEMORY_ONLY}, the split is done on an RDD via {@code
 *             flatMap}.
 *         <li>For any other level, the result is a {@code Dataset} encoded with {@link
 *             EncoderHelpers#oneOfEncoder}. Each tag's dataset selects the rows where that tag's
 *             column is non-null.
 *       </ul>
 * </ul>
 */
class ParDoTranslatorBatch<InputT, OutputT> extends TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>> {
    private static final ClassTag<WindowedValue<Object>> WINDOWED_VALUE_CTAG = ClassTag.apply(WindowedValue.class);
    private static final ClassTag<Tuple2<Integer, WindowedValue<Object>>> TUPLE2_CTAG = ClassTag.apply(Tuple2.class);

    /**
     * Checks that the transform's {@link DoFn} only uses features this translator supports.
     *
     * <p>This never returns {@code false}. Unsupported features fail fast instead:
     *
     * <ul>
     *   <li>splittable DoFns;
     *   <li>state or timers;
     *   <li>{@code onWindowExpiration};
     *   <li>{@code @RequiresTimeSortedInput}.
     * </ul>
     *
     * <p>Each of these makes Guava {@code checkState} throw {@link IllegalStateException}. Side
     * inputs are additionally validated by {@code SparkSideInputReader.validateMaterializations}.
     *
     * @param multiOutput the transform to check
     * @return {@code true} once all checks pass
     */
    @Override
    public boolean canTranslate(ParDo.MultiOutput<InputT, OutputT> multiOutput) {
        DoFn<InputT, OutputT> doFn = multiOutput.getFn();
        DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn);
        Preconditions.checkState(
                !signature.processElement().isSplittable(),
                "Not expected to directly translate splittable DoFn, should have been overridden: %s",
                doFn);
        Preconditions.checkState(
                !signature.usesState() && !signature.usesTimers(),
                "States and timers are not supported for the moment.");
        Preconditions.checkState(
                signature.onWindowExpiration() == null, "onWindowExpiration is not supported: %s", doFn);
        Preconditions.checkState(
                !signature.processElement().requiresTimeSortedInput(),
                "@RequiresTimeSortedInput is not supported for the moment");
        SparkSideInputReader.validateMaterializations(multiOutput.getSideInputs().values());
        return true;
    }

    /**
     * Evaluates the DoFn over the input dataset and registers one dataset per retained output.
     *
     * @param multiOutput the transform being translated
     * @param context translation context providing the input/output datasets and options
     * @throws IOException as declared by the translator contract
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void translate(ParDo.MultiOutput<InputT, OutputT> multiOutput, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) throws IOException {
        // Unchecked: getInput() is typed `? extends InputT`. Its elements are only consumed here.
        PCollection<InputT> input = (PCollection<InputT>) context.getInput();
        Dataset<WindowedValue<InputT>> inputDs = context.getDataset(input);
        SideInputReader sideInputReader = createSideInputReader(multiOutput.getSideInputs().values(), context);
        MetricsAccumulator metrics = MetricsAccumulator.getInstance(context.getSparkSession());
        TupleTag<OutputT> mainOutputTag = multiOutput.getMainOutputTag();

        // Always keep the main output. Drop every other output for which isLeave(...) is true, so
        // that it gets no column in the persisted multi-output result.
        Map<TupleTag<?>, PCollection<?>> outputs = Maps.filterEntries(
                context.getOutputs(),
                entry -> entry != null
                        && (entry.getKey().equals(mainOutputTag) || !context.isLeave(entry.getValue())));

        if (outputs.size() <= 1) {
            // Only the main output is left: the DoFn results are that output's dataset.
            PCollection<OutputT> output = context.getOutput(mainOutputTag);
            context.putDataset(
                    output,
                    inputDs.mapPartitions(
                            DoFnPartitionIteratorFactory.singleOutput(
                                    context.getCurrentTransform(),
                                    context.getOptionsSupplier(),
                                    input,
                                    sideInputReader,
                                    metrics),
                            context.windowedEncoder(output.getCoder())));
            return;
        }

        Map<String, Integer> tagColumnIdx = tagsColumnIndex(outputs.keySet());
        List<Encoder<WindowedValue<Object>>> encoders = createEncoders(outputs, tagColumnIdx, context);
        // Kept raw, as in the compiled class. The element type of the mapped result is pinned by
        // TUPLE2_CTAG / oneOfEncoder below, not by this local.
        DoFnPartitionIteratorFactory doFnMapper = DoFnPartitionIteratorFactory.multiOutput(
                context.getCurrentTransform(),
                context.getOptionsSupplier(),
                input,
                sideInputReader,
                metrics,
                tagColumnIdx);

        SparkCommonPipelineOptions options = context.getOptions().as(SparkCommonPipelineOptions.class);
        StorageLevel storageLevel = StorageLevel.fromString(options.getStorageLevel());

        if (StorageLevel.MEMORY_ONLY().equals(storageLevel)) {
            // Persist the (columnIndex, value) pairs as an RDD (persist() uses MEMORY_ONLY), then
            // split them per tag with flatMap.
            RDD<Tuple2<Integer, WindowedValue<Object>>> allTagsRdd =
                    inputDs.rdd().mapPartitions(doFnMapper, false, TUPLE2_CTAG);
            allTagsRdd.persist();
            for (TupleTag<?> tag : outputs.keySet()) {
                int columnIdx = columnIndexOf(tagColumnIdx, tag);
                RDD<WindowedValue<Object>> tagValues =
                        allTagsRdd.flatMap(selectByColumnIdx(columnIdx), WINDOWED_VALUE_CTAG);
                context.putDataset(
                        context.getOutput((TupleTag<Object>) tag),
                        context.getSparkSession().createDataset(tagValues, encoders.get(columnIdx)),
                        false);
            }
            return;
        }

        // Any other storage level: persist a Dataset with one column per tag (see oneOfEncoder).
        // For each tag, select the rows where its column is set.
        Dataset<Tuple2<Integer, WindowedValue<Object>>> allTagsDs =
                inputDs.mapPartitions(doFnMapper, EncoderHelpers.oneOfEncoder(encoders));
        allTagsDs.persist(storageLevel);
        for (TupleTag<?> tag : outputs.keySet()) {
            int columnIdx = columnIndexOf(tagColumnIdx, tag);
            // Raw cast: Column.as(Encoder) yields TypedColumn<Object, U>, but select(...) needs the
            // dataset's row type as the first type argument.
            TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> column =
                    (TypedColumn) functions.col(Integer.toString(columnIdx)).as(encoders.get(columnIdx));
            context.putDataset(
                    context.getOutput((TupleTag<Object>) tag),
                    allTagsDs.filter(column.isNotNull()).select(column),
                    false);
        }
    }

    /**
     * Returns a function that keeps a pair's value only when its column index equals {@code idx}.
     *
     * <p>The function is passed to {@code RDD.flatMap}, so it has to be serializable. The compiled
     * class carried a {@code $deserializeLambda$} hook for this lambda; that hook is synthesized by
     * javac and must not be written by hand.
     *
     * @param idx the column index to select
     * @return a function mapping {@code (index, value)} to a one-element collection holding {@code
     *     value} when {@code index == idx}, or to an empty collection otherwise
     */
    static <T> ScalaInterop.Fun1<Tuple2<Integer, T>, TraversableOnce<T>> selectByColumnIdx(int idx) {
        return pair -> pair._1 == idx ? ScalaInterop.listOf(pair._2) : ScalaInterop.emptyList();
    }

    /**
     * Looks up the column index that {@link #tagsColumnIndex} assigned to {@code tag}.
     *
     * @throws IllegalStateException (via {@code checkStateNotNull}) if the tag has no column
     */
    private static int columnIndexOf(Map<String, Integer> tagColumnIdx, TupleTag<?> tag) {
        return org.apache.beam.sdk.util.Preconditions.checkStateNotNull(tagColumnIdx.get(tag.getId()), "Unknown tag");
    }

    /**
     * Assigns column indices 0, 1, 2, ... to the tag ids, in the iteration order of {@code tags}.
     *
     * @param tags the output tags to index
     * @return a map from tag id to column index
     */
    private Map<String, Integer> tagsColumnIndex(Collection<TupleTag<?>> tags) {
        Map<String, Integer> index = Maps.newHashMapWithExpectedSize(tags.size());
        for (TupleTag<?> tag : tags) {
            index.put(tag.getId(), index.size());
        }
        return index;
    }

    /**
     * Builds the windowed encoder of every output, so that {@code encoders.get(i)} is the encoder of
     * column {@code i}.
     *
     * <p>Each encoder is written to its slot by index. The previous {@code List.add(index, e)} on an
     * empty list threw {@code IndexOutOfBoundsException} unless entries arrived in ascending column
     * order. Even when it did not throw, it shifted earlier elements.
     *
     * @param outputs the retained outputs, keyed by tag
     * @param columnIndex the tag-id to column mapping from {@link #tagsColumnIndex}
     * @param context translation context used to create the encoders
     * @return the encoders, indexed by column
     */
    private List<Encoder<WindowedValue<Object>>> createEncoders(Map<TupleTag<?>, PCollection<?>> outputs, Map<String, Integer> columnIndex, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) {
        int size = outputs.size();
        List<Encoder<WindowedValue<Object>>> encoders = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            encoders.add(null);
        }
        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : outputs.entrySet()) {
            int idx = org.apache.beam.sdk.util.Preconditions.checkStateNotNull(columnIndex.get(entry.getKey().getId()));
            // Unchecked: each output's element type is erased to Object for the shared column layout.
            @SuppressWarnings({"unchecked", "rawtypes"})
            Encoder<WindowedValue<Object>> encoder = (Encoder) context.windowedEncoder(entry.getValue().getCoder());
            encoders.set(idx, encoder);
        }
        return encoders;
    }

    /**
     * Creates the side-input reader for {@code views}.
     *
     * <p>Each side input is fetched from the context as a broadcast, keyed by the view's tag id. If
     * no broadcast is available, it is loaded with {@code SideInputValues.loader}.
     *
     * @param views the side-input views of the transform
     * @param context translation context providing the broadcasts
     * @return {@code SparkSideInputReader.empty()} when there are no views, otherwise a reader over
     *     the broadcasts
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T> SideInputReader createSideInputReader(Collection<PCollectionView<?>> views, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) {
        if (views.isEmpty()) {
            return SparkSideInputReader.empty();
        }
        // Values are whatever getSideInputBroadcast returns. The map stays raw so this file does not
        // need to name the broadcast type.
        Map broadcasts = Maps.newHashMapWithExpectedSize(views.size());
        for (PCollectionView<?> view : views) {
            PCollection<T> pCol = (PCollection<T>) org.apache.beam.sdk.util.Preconditions.checkStateNotNull(view.getPCollection());
            broadcasts.put(view.getTagInternal().getId(), context.getSideInputBroadcast(pCol, SideInputValues.loader(pCol)));
        }
        return SparkSideInputReader.create(broadcasts);
    }
}
