package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.catalyst.SerializerBuildHelper;
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.Coalesce;
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GetStructField;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.IsNotNull;
import org.apache.spark.sql.catalyst.expressions.IsNull;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.catalyst.expressions.MapKeys;
import org.apache.spark.sql.catalyst.expressions.MapValues;
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.ObjectType;
import org.apache.spark.util.MutablePair;
import org.joda.time.Instant;
import scala.Option;
import scala.Some;
import scala.Tuple2;
import scala.collection.IndexedSeq;
import scala.collection.JavaConverters;
import scala.collection.Seq;

/**
 * Builders for Spark {@link Encoder encoders} of Beam types.
 *
 * <p>Each encoder is a pair of Spark catalyst expressions: a serializer that reads the object bound
 * at ordinal 0 (see {@code rootRef}) and a deserializer that reads column 0 (see {@code rootCol}).
 * Composite encoders (windowed values, KVs, collections, maps, mutable pairs and one-of unions) are
 * built by splicing the serializer / deserializer expressions of their component encoders.
 */
public class EncoderHelpers {
    private static final DataType OBJECT_TYPE = new ObjectType(Object.class);
    private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class);
    private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class);
    private static final DataType KV_TYPE = new ObjectType(KV.class);
    private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class);
    private static final DataType LIST_TYPE = new ObjectType(List.class);

    /**
     * Boxed primitive element types. Collections and maps with elements of these types are converted
     * as plain Spark arrays instead of element by element through the element encoder.
     */
    private static final Set<Class<?>> PRIMITIVE_TYPES =
            ImmutableSet.of(
                    Boolean.class,
                    Byte.class,
                    Short.class,
                    Integer.class,
                    Long.class,
                    Float.class,
                    Double.class);

    /** Lazily populated cache of default encoders, keyed by the encoded class. */
    private static final Map<Class<?>, Encoder<?>> DEFAULT_ENCODERS = new ConcurrentHashMap<>();

    /**
     * Creates the default encoder for a class, or returns {@code null} if there is none. Since
     * {@link Map#computeIfAbsent} does not record {@code null} results, classes without a default
     * encoder are simply looked up again on the next call.
     */
    private static final Function<Class<?>, Encoder<?>> ENCODER_FACTORY =
            cls -> {
                if (cls.equals(PaneInfo.class)) {
                    return paneInfoEncoder();
                } else if (cls.equals(GlobalWindow.class)) {
                    return binaryEncoder(GlobalWindow.Coder.INSTANCE, false);
                } else if (cls.equals(IntervalWindow.class)) {
                    return binaryEncoder(IntervalWindow.IntervalWindowCoder.of(), false);
                } else if (cls.equals(Instant.class)) {
                    return instantEncoder();
                } else if (cls.equals(String.class)) {
                    return Encoders.STRING();
                } else if (cls.equals(Boolean.class)) {
                    return Encoders.BOOLEAN();
                } else if (cls.equals(Integer.class)) {
                    return Encoders.INT();
                } else if (cls.equals(Long.class)) {
                    return Encoders.LONG();
                } else if (cls.equals(Float.class)) {
                    return Encoders.FLOAT();
                } else if (cls.equals(Double.class)) {
                    return Encoders.DOUBLE();
                } else if (cls.equals(BigDecimal.class)) {
                    return Encoders.DECIMAL();
                } else if (cls.equals(byte[].class)) {
                    return Encoders.BINARY();
                } else if (cls.equals(Byte.class)) {
                    return Encoders.BYTE();
                } else if (cls.equals(Short.class)) {
                    return Encoders.SHORT();
                }
                return null;
            };

    /**
     * Static helpers called from the catalyst expressions built in {@link EncoderHelpers}. Those
     * expressions reference the methods by name (as strings), so names and erased signatures of the
     * public methods must stay stable.
     */
    public static class Utils {
        /** Decodes a {@link PaneInfo} using {@link PaneInfo.PaneInfoCoder}. */
        public static PaneInfo paneInfoFromBytes(byte[] bytes) {
            return (PaneInfo) CoderHelpers.fromByteArray(bytes, PaneInfo.PaneInfoCoder.of());
        }

        /** Encodes a {@link PaneInfo} using {@link PaneInfo.PaneInfoCoder}. */
        public static byte[] paneInfoToBytes(PaneInfo paneInfo) {
            return CoderHelpers.toByteArray(paneInfo, PaneInfo.PaneInfoCoder.of());
        }

        /**
         * Returns the max timestamp of the only window in {@code windows}.
         *
         * @throws java.util.NoSuchElementException if {@code windows} is empty
         * @throws IllegalArgumentException if {@code windows} has more than one element
         */
        public static Instant maxTimestamp(Iterable<BoundedWindow> windows) {
            return Iterables.getOnlyElement(windows).maxTimestamp();
        }

        /** Copies {@code arrayData} (elements of {@code elementType}) into a fixed-size list. */
        public static List<Object> copyToList(ArrayData arrayData, DataType elementType) {
            return Arrays.asList(arrayData.toObjectArray(elementType));
        }

        /** Views {@code arrayData} as a Scala {@link Seq} of objects. */
        public static Seq<Object> toSeq(ArrayData arrayData) {
            return arrayData.toSeq(OBJECT_TYPE);
        }

        /** Converts a Java collection to a Scala {@link Seq}, wrapping lists without copying. */
        public static Seq<Object> toSeq(Collection<Object> collection) {
            if (collection instanceof List) {
                return JavaConverters.asScalaBuffer((List<Object>) collection);
            }
            return JavaConverters.collectionAsScalaIterable(collection).toSeq();
        }

        /** Zips parallel key / value arrays into a new {@link TreeMap}. */
        public static TreeMap<Object, Object> toTreeMap(
                ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
            return fillMap(new TreeMap<Object, Object>(), keys, values, keyType, valueType);
        }

        /** Zips parallel key / value arrays into a new, pre-sized {@link HashMap}. */
        public static HashMap<Object, Object> toMap(
                ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
            return fillMap(
                    Maps.<Object, Object>newHashMapWithExpectedSize(keys.numElements()),
                    keys,
                    values,
                    keyType,
                    valueType);
        }

        /** Puts the i-th key with the i-th value into {@code map} and returns {@code map}. */
        private static <MapT extends Map<Object, Object>> MapT fillMap(
                MapT map, ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
            IndexedSeq<Object> keySeq = keys.toSeq(keyType);
            IndexedSeq<Object> valueSeq = values.toSeq(valueType);
            for (int i = 0; i < keySeq.size(); i++) {
                map.put(keySeq.apply(i), valueSeq.apply(i));
            }
            return map;
        }
    }

    /**
     * Returns the cached default encoder for {@code cls}, creating it on first use, or {@code null}
     * if {@code ENCODER_FACTORY} has no default encoder for that class.
     */
    // Safe: ENCODER_FACTORY only ever maps a class to an encoder of that same class.
    @SuppressWarnings("unchecked")
    private static <T> Encoder<T> getOrCreateDefaultEncoder(Class<? super T> cls) {
        return (Encoder<T>) DEFAULT_ENCODERS.computeIfAbsent(cls, ENCODER_FACTORY);
    }

    /**
     * Returns the default encoder for {@code cls}.
     *
     * @throws IllegalArgumentException if no default encoder exists for {@code cls}
     */
    public static <T> Encoder<T> encoderOf(Class<? super T> cls) {
        Encoder<T> encoder = getOrCreateDefaultEncoder(cls);
        if (encoder == null) {
            throw new IllegalArgumentException("No default coder available for class " + cls);
        }
        return encoder;
    }

    /**
     * Returns the default encoder for the raw type encoded by {@code coder} if one exists, otherwise
     * a nullable binary encoder that (de)serializes values with {@code coder}.
     */
    public static <T> Encoder<T> encoderFor(Coder<T> coder) {
        Encoder<T> encoder = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType());
        return encoder != null ? encoder : binaryEncoder(coder, true);
    }

    /**
     * Encoder for {@link WindowedValue} as a struct of {@code value}, {@code timestamp},
     * {@code windows} and {@code pane}.
     *
     * @param valueEnc encoder of the wrapped value
     * @param windowEnc encoder of a single window
     */
    public static <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedValueEncoder(
            Encoder<T> valueEnc, Encoder<W> windowEnc) {
        Encoder<Instant> timestampEnc = encoderOf(Instant.class);
        Encoder<PaneInfo> paneEnc = encoderOf(PaneInfo.class);
        Encoder<Collection<W>> windowsEnc = collectionEncoder(windowEnc);
        Expression serializer =
                serializeWindowedValue(
                        rootRef(WINDOWED_VALUE, true), valueEnc, timestampEnc, windowsEnc, paneEnc);
        Expression deserializer =
                deserializeWindowedValue(
                        rootCol(serializer.dataType()), valueEnc, timestampEnc, windowsEnc, paneEnc);
        return EncoderFactory.create(serializer, deserializer, WindowedValue.class);
    }

    /**
     * Encoder for a tagged union {@code Tuple2<index, value>}: a struct with one field per encoder
     * (named {@code "0"}, {@code "1"}, ...), of which only the field at {@code index} is non-null.
     *
     * @throws IllegalArgumentException if {@code encoders} is empty
     */
    public static <T> Encoder<Tuple2<Integer, T>> oneOfEncoder(List<Encoder<T>> encoders) {
        // An empty union would produce an argument-less Coalesce, which Spark rejects later on.
        Preconditions.checkArgument(!encoders.isEmpty(), "oneOfEncoder requires at least one encoder");
        Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders);
        Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders);
        return EncoderFactory.create(serializer, deserializer, Tuple2.class);
    }

    /** Encoder for {@link KV} as a struct of {@code key} and {@code value}. */
    public static <K, V> Encoder<KV<K, V>> kvEncoder(Encoder<K> keyEnc, Encoder<V> valueEnc) {
        Expression serializer = serializeKV(rootRef(KV_TYPE, true), keyEnc, valueEnc);
        Expression deserializer = deserializeKV(rootCol(serializer.dataType()), keyEnc, valueEnc);
        return EncoderFactory.create(serializer, deserializer, KV.class);
    }

    /** Encoder for collections with nullable elements; decodes to a {@link List}. */
    public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> elementEnc) {
        return collectionEncoder(elementEnc, true);
    }

    /**
     * Encoder for collections; decodes to a {@link List}.
     *
     * @param elementEnc encoder of the elements
     * @param nullable whether elements may be {@code null}
     */
    public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> elementEnc, boolean nullable) {
        Expression serializer =
                serializeSeq(rootRef(new ObjectType(Collection.class), true), elementEnc, nullable);
        Expression deserializer =
                deserializeSeq(rootCol(serializer.dataType()), elementEnc, nullable, true);
        return EncoderFactory.create(serializer, deserializer, Collection.class);
    }

    /**
     * Encoder for maps decoded as {@code cls}.
     *
     * @param cls {@link TreeMap}, or a type {@link HashMap} is assignable to (such as {@link Map})
     * @throws IllegalArgumentException if {@code cls} is not supported
     */
    public static <MapT extends Map<K, V>, K, V> Encoder<MapT> mapEncoder(
            Encoder<K> keyEnc, Encoder<V> valueEnc, Class<MapT> cls) {
        Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), keyEnc, valueEnc);
        Expression deserializer =
                mapDeserializer(rootCol(serializer.dataType()), keyEnc, valueEnc, cls);
        return EncoderFactory.create(serializer, deserializer, cls);
    }

    /** Encoder for {@link MutablePair} as a struct of {@code _1} and {@code _2}. */
    public static <T1, T2> Encoder<MutablePair<T1, T2>> mutablePairEncoder(
            Encoder<T1> firstEnc, Encoder<T2> secondEnc) {
        Expression serializer =
                serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), firstEnc, secondEnc);
        Expression deserializer =
                deserializeMutablePair(rootCol(serializer.dataType()), firstEnc, secondEnc);
        return EncoderFactory.create(serializer, deserializer, MutablePair.class);
    }

    /** Encodes {@link PaneInfo} as binary via {@link Utils#paneInfoToBytes}. */
    private static Encoder<PaneInfo> paneInfoEncoder() {
        ObjectType paneType = new ObjectType(PaneInfo.class);
        Expression serializer =
                EncoderFactory.invokeIfNotNull(
                        Utils.class, "paneInfoToBytes", DataTypes.BinaryType, rootRef(paneType, false));
        Expression deserializer =
                EncoderFactory.invokeIfNotNull(
                        Utils.class, "paneInfoFromBytes", paneType, rootCol(DataTypes.BinaryType));
        return EncoderFactory.create(serializer, deserializer, PaneInfo.class);
    }

    /** Encodes {@link Instant} as its epoch millis (long); {@code null} stays {@code null}. */
    private static Encoder<Instant> instantEncoder() {
        ObjectType instantType = new ObjectType(Instant.class);
        Expression ref = rootRef(instantType, true);
        Expression col = rootCol(DataTypes.LongType);
        Expression serializer =
                nullSafe(ref, EncoderFactory.invoke(ref, "getMillis", DataTypes.LongType, false, new Expression[0]));
        Expression deserializer =
                nullSafe(col, EncoderFactory.invoke(Instant.class, "ofEpochMilli", instantType, col));
        return EncoderFactory.create(serializer, deserializer, Instant.class);
    }

    /**
     * Encodes values as binary using {@code coder}. The object side is typed as {@code Object}
     * rather than {@code T}; presumably so generated code never references {@code T} directly.
     */
    private static <T> Encoder<T> binaryEncoder(Coder<T> coder, boolean nullable) {
        // Built directly: raw Coder.class cannot be typed as Class<? extends Coder<T>> without an
        // unchecked cast. Equivalent to lit(coder, Coder.class).
        Literal coderLit = Literal.fromObject(coder, new ObjectType(Coder.class));
        Expression serializer =
                EncoderFactory.invokeIfNotNull(
                        CoderHelpers.class,
                        "toByteArray",
                        DataTypes.BinaryType,
                        rootRef(OBJECT_TYPE, nullable),
                        coderLit);
        Expression deserializer =
                EncoderFactory.invokeIfNotNull(
                        CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(DataTypes.BinaryType), coderLit);
        return EncoderFactory.create(serializer, deserializer, coder.getEncodedTypeDescriptor().getRawType());
    }

    private static <T, W extends BoundedWindow> Expression serializeWindowedValue(
            Expression in,
            Encoder<T> valueEnc,
            Encoder<Instant> timestampEnc,
            Encoder<Collection<W>> windowsEnc,
            Encoder<PaneInfo> paneEnc) {
        return serializerObject(
                in,
                ScalaInterop.tuple("value", serializeField(in, valueEnc, "getValue")),
                ScalaInterop.tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")),
                ScalaInterop.tuple("windows", serializeField(in, windowsEnc, "getWindows")),
                ScalaInterop.tuple("pane", serializeField(in, paneEnc, "getPane")));
    }

    /** Serializes {@code in} as a struct with the given (name, serializer) fields. */
    @SafeVarargs
    private static Expression serializerObject(Expression in, Tuple2<String, Expression>... fields) {
        return SerializerBuildHelper.createSerializerForObject(in, ScalaInterop.seqOf(fields));
    }

    /**
     * Deserializes a windowed value struct. A {@code null} timestamp is replaced by the max timestamp
     * of the single window, and a {@code null} pane yields a {@code null} result.
     */
    private static <T, W extends BoundedWindow> Expression deserializeWindowedValue(
            Expression in,
            Encoder<T> valueEnc,
            Encoder<Instant> timestampEnc,
            Encoder<Collection<W>> windowsEnc,
            Encoder<PaneInfo> paneEnc) {
        Expression value = deserializeField(in, valueEnc, 0, "value");
        Expression windows = deserializeField(in, windowsEnc, 2, "windows");
        Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp");
        Expression pane = deserializeField(in, paneEnc, 3, "pane");
        Expression timestampOrMax =
                ifNotNull(
                        timestamp,
                        EncoderFactory.invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows));
        return nullSafe(
                pane,
                EncoderFactory.invoke(
                        WindowedValue.class, "of", WINDOWED_VALUE, value, timestampOrMax, windows, pane));
    }

    private static <K, V> Expression serializeMutablePair(
            Expression in, Encoder<K> firstEnc, Encoder<V> secondEnc) {
        return serializerObject(
                in,
                ScalaInterop.tuple("_1", serializeField(in, firstEnc, "_1")),
                ScalaInterop.tuple("_2", serializeField(in, secondEnc, "_2")));
    }

    private static <K, V> Expression deserializeMutablePair(
            Expression in, Encoder<K> firstEnc, Encoder<V> secondEnc) {
        return EncoderFactory.invoke(
                MutablePair.class,
                "apply",
                MUTABLE_PAIR_TYPE,
                deserializeField(in, firstEnc, 0, "_1"),
                deserializeField(in, secondEnc, 1, "_2"));
    }

    private static <K, V> Expression serializeKV(Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
        return serializerObject(
                in,
                ScalaInterop.tuple("key", serializeField(in, keyEnc, "getKey")),
                ScalaInterop.tuple("value", serializeField(in, valueEnc, "getValue")));
    }

    private static <K, V> Expression deserializeKV(Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
        return EncoderFactory.invoke(
                KV.class,
                "of",
                KV_TYPE,
                deserializeField(in, keyEnc, 0, "key"),
                deserializeField(in, valueEnc, 1, "value"));
    }

    /** Serializes a {@code Tuple2<index, value>} into a struct with one field per encoder. */
    public static <T> Expression serializeOneOf(Expression in, List<Encoder<T>> encoders) {
        Expression index = EncoderFactory.invoke(in, "_1", DataTypes.IntegerType, false, new Expression[0]);
        // CreateNamedStruct takes alternating (name, value) expressions.
        Expression[] namesAndValues = new Expression[encoders.size() * 2];
        for (int i = 0; i < encoders.size(); i++) {
            namesAndValues[2 * i] = lit(String.valueOf(i));
            namesAndValues[2 * i + 1] = serializeOneOfField(in, index, encoders.get(i), i);
        }
        return new CreateNamedStruct(ScalaInterop.seqOf(namesAndValues));
    }

    /** Deserializes the first non-null struct field into a {@code Tuple2<index, value>}. */
    public static <T> Expression deserializeOneOf(Expression in, List<Encoder<T>> encoders) {
        Expression[] candidates = new Expression[encoders.size()];
        for (int i = 0; i < encoders.size(); i++) {
            candidates[i] = deserializeOneOfField(in, encoders.get(i), i);
        }
        return new Coalesce(ScalaInterop.seqOf(candidates));
    }

    /** Field {@code i}: the serialized value if the tuple's index equals {@code i}, else null. */
    private static <T> Expression serializeOneOfField(
            Expression in, Expression index, Encoder<T> enc, int i) {
        Expression value = EncoderFactory.invoke(in, "_2", deserializedType(enc), false, new Expression[0]);
        return new If(new EqualTo(index, lit(i)), serialize(value, enc), lit(null, serializedType(enc)));
    }

    /** Field {@code i}: null if absent, else {@code Tuple2(i, deserialized value)}. */
    private static <T> Expression deserializeOneOfField(Expression in, Encoder<T> enc, int i) {
        GetStructField field = new GetStructField(in, i, Option.empty());
        return new If(
                new IsNull(field),
                lit(null, TUPLE2_TYPE),
                EncoderFactory.newInstance(Tuple2.class, TUPLE2_TYPE, lit(i), deserialize(field, enc)));
    }

    /**
     * Serializes the result of calling {@code getter} on {@code in}, using the type and nullability
     * of the bound reference inside the encoder's own serializer.
     */
    private static <T> Expression serializeField(Expression in, Encoder<T> enc, String getter) {
        Expression boundRef =
                (Expression) serializer(enc).collect(ScalaInterop.match(BoundReference.class)).head();
        Expression fieldValue =
                EncoderFactory.invoke(in, getter, boundRef.dataType(), boundRef.nullable(), new Expression[0]);
        return serialize(fieldValue, enc);
    }

    private static <T> Expression deserializeField(Expression in, Encoder<T> enc, int ordinal, String name) {
        return deserialize(new GetStructField(in, ordinal, new Some<>(name)), enc);
    }

    private static <K, V> Expression mapSerializer(Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
        return SerializerBuildHelper.createSerializerForMap(
                in,
                new SerializerBuildHelper.MapElementInformation(
                        deserializedType(keyEnc), false, key -> serialize(key, keyEnc)),
                new SerializerBuildHelper.MapElementInformation(
                        deserializedType(valueEnc), false, value -> serialize(value, valueEnc)));
    }

    /**
     * Deserializes a Spark map into a {@link TreeMap} if {@code cls} is {@code TreeMap}, otherwise
     * into a {@link HashMap}.
     *
     * @throws IllegalArgumentException if {@code cls} is neither {@code TreeMap} nor a type
     *     {@code HashMap} is assignable to
     */
    private static <MapT extends Map<K, V>, K, V> Expression mapDeserializer(
            Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc, Class<MapT> cls) {
        boolean isTreeMap = cls.equals(TreeMap.class);
        Preconditions.checkArgument(
                isTreeMap || cls.isAssignableFrom(HashMap.class),
                "Unsupported map class %s, expected TreeMap or a type HashMap is assignable to",
                cls);
        return EncoderFactory.invoke(
                Utils.class,
                isTreeMap ? "toTreeMap" : "toMap",
                new ObjectType(cls),
                deserializeSeq(new MapKeys(in), keyEnc, false, false),
                deserializeSeq(new MapValues(in), valueEnc, false, false),
                mapItemType(keyEnc),
                mapItemType(valueEnc));
    }

    /** Element type literal passed to {@link Utils#toMap} / {@link Utils#toTreeMap}. */
    private static Literal mapItemType(Encoder<?> enc) {
        DataType itemType = isPrimitiveEnc(enc) ? serializedType(enc) : deserializedType(enc);
        return lit(itemType, DataType.class);
    }

    /**
     * Serializes a collection to a Spark array; primitive elements are copied via
     * {@code Collection.toArray()}, other elements are serialized one by one.
     */
    private static <T> Expression serializeSeq(Expression in, Encoder<T> elementEnc, boolean nullable) {
        if (isPrimitiveEnc(elementEnc)) {
            Expression array = EncoderFactory.invoke(in, "toArray", new ObjectType(Object[].class), false, new Expression[0]);
            return SerializerBuildHelper.createSerializerForGenericArray(array, serializedType(elementEnc), nullable);
        }
        Expression seq = EncoderFactory.invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in);
        return MapObjects$.MODULE$.apply(
                element -> serialize(element, elementEnc), seq, deserializedType(elementEnc), nullable, Option.empty());
    }

    /**
     * Deserializes a Spark array.
     *
     * @param asJavaList if {@code true} the result is a {@link List}; otherwise primitive arrays are
     *     returned unchanged and other elements are mapped into Spark's default collection
     */
    private static <T> Expression deserializeSeq(
            Expression in, Encoder<T> elementEnc, boolean nullable, boolean asJavaList) {
        DataType elementType = serializedType(elementEnc);
        if (isPrimitiveEnc(elementEnc)) {
            return asJavaList
                    ? EncoderFactory.invoke(Utils.class, "copyToList", LIST_TYPE, in, lit(elementType, DataType.class))
                    : in;
        }
        return MapObjects$.MODULE$.apply(
                element -> deserialize(element, elementEnc),
                in,
                elementType,
                nullable,
                asJavaList ? Option.apply(List.class) : Option.empty());
    }

    private static <T> boolean isPrimitiveEnc(Encoder<T> enc) {
        return PRIMITIVE_TYPES.contains(enc.clsTag().runtimeClass());
    }

    /**
     * Inlines the encoder's serializer, replacing its bound reference with {@code in}.
     * NOTE(review): likely intended to be private; kept public to avoid breaking external callers.
     */
    public static <T> Expression serialize(Expression in, Encoder<T> enc) {
        return serializer(enc).transformUp(ScalaInterop.replace(BoundReference.class, in));
    }

    /**
     * Inlines the encoder's deserializer, replacing its column reference with {@code in}.
     * NOTE(review): likely intended to be private; kept public to avoid breaking external callers.
     */
    public static <T> Expression deserialize(Expression in, Encoder<T> enc) {
        return deserializer(enc).transformUp(ScalaInterop.replace(GetColumnByOrdinal.class, in));
    }

    private static <T> Expression serializer(Encoder<T> enc) {
        return ((ExpressionEncoder<T>) enc).objSerializer();
    }

    private static <T> Expression deserializer(Encoder<T> enc) {
        return ((ExpressionEncoder<T>) enc).objDeserializer();
    }

    private static <T> DataType serializedType(Encoder<T> enc) {
        return serializer(enc).dataType();
    }

    private static <T> DataType deserializedType(Encoder<T> enc) {
        return deserializer(enc).dataType();
    }

    /** Reference to the input object at ordinal 0. */
    private static Expression rootRef(DataType type, boolean nullable) {
        return new BoundReference(0, type, nullable);
    }

    /** Reference to input column 0. */
    private static Expression rootCol(DataType type) {
        return new GetColumnByOrdinal(0, type);
    }

    /** Evaluates to {@code null} if {@code in} is null, otherwise to {@code expr}. */
    private static Expression nullSafe(Expression in, Expression expr) {
        return new If(new IsNull(in), lit(null, expr.dataType()), expr);
    }

    /** Evaluates to {@code expr} if it is non-null, otherwise to {@code fallback}. */
    private static Expression ifNotNull(Expression expr, Expression fallback) {
        return new If(new IsNotNull(expr), expr, fallback);
    }

    private static <T> Expression lit(T value) {
        return Literal$.MODULE$.apply(value);
    }

    private static <T> Expression lit(T value, DataType type) {
        return new Literal(value, type);
    }

    private static <T> Literal lit(T value, Class<? extends T> cls) {
        return Literal.fromObject(value, new ObjectType(cls));
    }
}
