package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;

import java.lang.invoke.SerializedLambda;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.BigEndianShortCoder;
import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.ByteCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.DelegateCoder;
import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.FloatCoder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructField;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.joda.time.Instant;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Tests for {@code EncoderHelpers}.
 *
 * <p>Values of common Beam coders are round-tripped through the Spark {@link Encoder}s returned by
 * {@code EncoderHelpers}. Each value is round-tripped directly (serializer followed by deserializer)
 * and also by materializing a {@link Dataset} in a shared local Spark session.
 */
@RunWith(JUnit4.class)
public class EncoderHelpersTest {

    /** Shared Spark session for all tests, using master {@code local[1]} and no extra config. */
    @ClassRule
    @SuppressWarnings("unchecked") // generic array for the KV varargs parameter
    public static SparkSessionRule sessionRule =
        new SparkSessionRule("local[1]", (KV<String, String>[]) new KV[0]);

    private static final Encoder<GlobalWindow> windowEnc =
        EncoderHelpers.encoderOf(GlobalWindow.class);

    /**
     * Sample values keyed by the coder whose encoder is under test.
     *
     * <p>Every list ends with {@code null} so that nullable encoding is exercised as well.
     * BigDecimal samples are built via {@link #bigDecimalOf(long)} so that their precision and scale
     * match Spark's system-default decimal type.
     */
    private static final Map<Coder<?>, List<?>> BASIC_CASES =
        ImmutableMap.<Coder<?>, List<?>>builder()
            .put(BooleanCoder.of(), Arrays.asList(true, false, null))
            .put(ByteCoder.of(), Arrays.asList((byte) 1, null))
            .put(BigEndianShortCoder.of(), Arrays.asList((short) 1, null))
            .put(BigEndianIntegerCoder.of(), Arrays.asList(1, 2, 3, null))
            .put(VarIntCoder.of(), Arrays.asList(1, 2, 3, null))
            .put(BigEndianLongCoder.of(), Arrays.asList(1L, 2L, 3L, null))
            .put(VarLongCoder.of(), Arrays.asList(1L, 2L, 3L, null))
            .put(FloatCoder.of(), Arrays.asList(1.0f, 2.0f, null))
            .put(DoubleCoder.of(), Arrays.asList(1.0d, 2.0d, null))
            .put(StringUtf8Coder.of(), Arrays.asList("1", "2", null))
            .put(BigDecimalCoder.of(), Arrays.asList(bigDecimalOf(1), bigDecimalOf(2), null))
            .put(InstantCoder.of(), Arrays.asList(Instant.ofEpochMilli(1), null))
            .build();

    /**
     * A string wrapper whose class is private to this test.
     *
     * <p>It is used to check that {@code EncoderHelpers.encoderFor} also works for types that are not
     * publicly accessible. Encoding is delegated to {@link StringUtf8Coder} via {@link DelegateCoder}.
     */
    private static class PrivateString {
        private static final Coder<PrivateString> CODER =
            DelegateCoder.of(
                StringUtf8Coder.of(),
                privateString -> privateString.string,
                PrivateString::new,
                // Anonymous subclass so the type argument is captured at runtime.
                new TypeDescriptor<PrivateString>() {});

        private final String string;

        public PrivateString(String string) {
            this.string = string;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            return Objects.equals(this.string, ((PrivateString) obj).string);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.string);
        }

        /** Readable form so that assertion failures show the wrapped value. */
        @Override
        public String toString() {
            return "PrivateString(" + string + ")";
        }
    }

    /**
     * Creates a Dataset in the shared session from {@code data} using {@code encoder}.
     *
     * <p>The Dataset's schema is printed to stdout as a debugging aid for failing tests.
     */
    private <T> Dataset<T> createDataset(List<T> data, Encoder<T> encoder) {
        Dataset<T> dataset = sessionRule.getSession().createDataset(data, encoder);
        dataset.printSchema();
        return dataset;
    }

    /** Each basic coder's encoder round-trips its sample values, directly and through a Dataset. */
    @Test
    @SuppressWarnings({"rawtypes", "unchecked"}) // Coder<?> keys cannot be tied to a single T
    public void testBeamEncoderMappings() {
        BASIC_CASES.forEach((coder, data) -> {
            Encoder encoder = EncoderHelpers.encoderFor(coder);
            serializeAndDeserialize(data.get(0), encoder);
            Dataset<?> dataset = createDataset(data, encoder);
            MatcherAssert.assertThat(dataset.collect(), Matchers.equalTo(data.toArray()));
        });
    }

    /** An encoder built from a coder of a private class round-trips through a Dataset. */
    @Test
    public void testBeamEncoderOfPrivateType() {
        List<PrivateString> data = Arrays.asList(new PrivateString("1"), new PrivateString("2"));
        Dataset<PrivateString> dataset =
            createDataset(data, EncoderHelpers.encoderFor(PrivateString.CODER));
        MatcherAssert.assertThat(dataset.collect(), Matchers.equalTo(data.toArray()));
    }

    /** Sample values wrapped in global-window {@link WindowedValue}s round-trip correctly. */
    @Test
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testBeamWindowedValueEncoderMappings() {
        BASIC_CASES.forEach((coder, data) -> {
            // Lists.transform returns a lazy view; WindowedValue equality is value-based, so
            // re-applying the function on each access is harmless here.
            List<?> windowed = Lists.transform(data, WindowedValue::valueInGlobalWindow);
            Encoder encoder =
                EncoderHelpers.windowedValueEncoder(EncoderHelpers.encoderFor(coder), windowEnc);
            serializeAndDeserialize(windowed.get(0), encoder);
            Dataset<?> dataset = createDataset(windowed, encoder);
            MatcherAssert.assertThat(dataset.collect(), Matchers.equalTo(windowed.toArray()));
        });
    }

    /** A single row holding all sample values of a coder as a collection round-trips intact. */
    @Test
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testCollectionEncoder() {
        BASIC_CASES.forEach((coder, data) -> {
            Encoder collectionEncoder =
                EncoderHelpers.collectionEncoder(EncoderHelpers.encoderFor(coder), true);
            List<?> rows = Collections.singletonList(Collections.unmodifiableCollection(data));
            Collection<?> result = (Collection<?>) createDataset(rows, collectionEncoder).head();
            MatcherAssert.assertThat(result, Matchers.equalTo(data));
        });
    }

    /**
     * Round-trips, for every basic coder, a map from each non-null sample value to itself.
     *
     * @param cls the map class requested from the map encoder; the decoded map must be an instance
     *     of it
     * @param decorator converts the plain map built by the test into the input map, e.g. to a
     *     {@link TreeMap}
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void testMapEncoder(Class<?> cls, Function<Map<?, ?>, Map<?, ?>> decorator) {
        BASIC_CASES.forEach((coder, data) -> {
            Encoder valueEncoder = EncoderHelpers.encoderFor(coder);
            Encoder mapEncoder = EncoderHelpers.mapEncoder(valueEncoder, valueEncoder, (Class) cls);
            // Null is excluded because it cannot be used as a key in every map type under test.
            Map<?, ?> plain =
                data.stream()
                    .filter(Predicates.notNull())
                    .collect(Collectors.toMap(Function.identity(), Function.identity()));
            Map<?, ?> expected = decorator.apply(plain);
            Map<?, ?> actual =
                (Map<?, ?>) createDataset(Collections.singletonList(expected), mapEncoder).head();
            MatcherAssert.assertThat(actual, Matchers.equalTo(expected));
            MatcherAssert.assertThat(actual, Matchers.instanceOf(cls));
        });
    }

    @Test
    public void testMapEncoder() {
        testMapEncoder(Map.class, Function.identity());
    }

    @Test
    public void testHashMapEncoder() {
        testMapEncoder(HashMap.class, Function.identity());
    }

    @Test
    public void testTreeMapEncoder() {
        testMapEncoder(TreeMap.class, TreeMap::new);
    }

    /** An encoder for a {@link ListCoder} of strings round-trips lists of different lengths. */
    @Test
    public void testBeamBinaryEncoder() {
        List<List<String>> data =
            Arrays.asList(
                Arrays.asList("a1", "a2", "a3"), Arrays.asList("b1", "b2"), Arrays.asList("c1"));
        Encoder<List<String>> encoder =
            EncoderHelpers.encoderFor(ListCoder.of(StringUtf8Coder.of()));
        serializeAndDeserialize(data.get(0), encoder);
        Dataset<List<String>> dataset = createDataset(data, encoder);
        MatcherAssert.assertThat(dataset.collect(), Matchers.equalTo(data.toArray()));
    }

    /**
     * A KV encoder yields a nullable {@code key}/{@code value} struct schema and round-trips null
     * keys and null values.
     */
    @Test
    public void testEncoderForKVCoder() {
        List<KV<Integer, String>> data =
            Arrays.asList(KV.of(1, "value1"), KV.of(null, "value2"), KV.of(3, null));
        Encoder<KV<Integer, String>> encoder =
            EncoderHelpers.kvEncoder(
                EncoderHelpers.encoderFor(VarIntCoder.of()),
                EncoderHelpers.encoderFor(StringUtf8Coder.of()));
        serializeAndDeserialize(data.get(0), encoder);

        Dataset<KV<Integer, String>> dataset = createDataset(data, encoder);
        MatcherAssert.assertThat(
            dataset.schema(),
            Matchers.equalTo(
                DataTypes.createStructType(
                    new StructField[] {
                        DataTypes.createStructField("key", DataTypes.IntegerType, true),
                        DataTypes.createStructField("value", DataTypes.StringType, true)
                    })));
        MatcherAssert.assertThat(dataset.collectAsList(), Matchers.equalTo(data));
    }

    /**
     * A one-of encoder over all basic encoders round-trips tuples of (encoder index, first sample
     * value of that encoder's coder).
     */
    @Test
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testOneOffEncoder() {
        List<Coder<?>> coders = ImmutableList.copyOf(BASIC_CASES.keySet());
        List<Encoder<?>> encoders =
            coders.stream()
                .<Encoder<?>>map(coder -> EncoderHelpers.encoderFor(coder))
                .collect(Collectors.toList());
        List<?> tuples =
            BASIC_CASES.entrySet().stream()
                .map(entry -> {
                    Object firstValue = entry.getValue().get(0);
                    return ScalaInterop.tuple(coders.indexOf(entry.getKey()), firstValue);
                })
                .collect(Collectors.toList());

        Encoder oneOfEncoder = EncoderHelpers.oneOfEncoder((List) encoders);
        Dataset<?> dataset = createDataset(tuples, oneOfEncoder);
        List<?> actual = dataset.collectAsList();
        MatcherAssert.assertThat(actual, Matchers.equalTo(tuples));
    }

    /**
     * Returns {@code value} as a BigDecimal with the precision and scale of Spark's {@code
     * DecimalType.SYSTEM_DEFAULT()}, so that it compares equal after a round trip through Spark.
     */
    private static BigDecimal bigDecimalOf(long value) {
        DecimalType systemDefault = DecimalType.SYSTEM_DEFAULT();
        return new BigDecimal(value, new MathContext(systemDefault.precision()))
            .setScale(systemDefault.scale());
    }

    /**
     * Serializes {@code value} to a Spark internal row and deserializes it back with the resolved
     * and bound form of {@code encoder}. Asserts that the result equals the input.
     *
     * <p>Assumes {@code encoder} is an {@link ExpressionEncoder}; any other type fails with a
     * ClassCastException.
     */
    @SuppressWarnings("unchecked")
    private static <T> void serializeAndDeserialize(T value, Encoder<T> encoder) {
        ExpressionEncoder<T> unbound = (ExpressionEncoder<T>) encoder;
        // Java cannot omit Scala default arguments, so pass the compiler-generated defaults.
        ExpressionEncoder<T> bound =
            unbound.resolveAndBind(
                unbound.resolveAndBind$default$1(), unbound.resolveAndBind$default$2());
        T roundTripped = bound.createDeserializer().apply(bound.createSerializer().apply(value));
        MatcherAssert.assertThat(roundTripped, Matchers.equalTo(value));
    }
}
