/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql;

import com.google.auto.service.AutoService;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.beam.sdk.extensions.sql.BeamSqlDslBase;
import org.apache.beam.sdk.extensions.sql.BeamSqlUdf;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteTable;
import org.apache.beam.sdk.extensions.sql.impl.ParseException;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.extensions.sql.meta.provider.UdfUdafProvider;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable;
import org.apache.beam.sdk.extensions.sql.utils.DateTimeUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.calcite.v1_28_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.linq4j.function.Parameter;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.schema.TranslatableTable;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.joda.time.Instant;
import org.joda.time.ReadableInstant;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;

/**
 * Tests for user-defined scalar functions (UDFs), user-defined aggregate functions (UDAFs) and
 * table macros registered on a {@link SqlTransform}, plus UDF/UDAFs auto-loaded through a
 * {@link UdfUdafProvider} service.
 *
 * <p>NOTE(review): the expected values below assume {@code boundedInput1} (declared in
 * {@link BeamSqlDslBase}, not shown here) contains {@code f_int} values 1..4, all in the group
 * {@code f_int2 = 0}. The map/list UDAF expectations ("squareOf-1".."squareOf-4") reflect this.
 */
public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {

    /** A Java-class UDAF aggregates per GROUP BY key: 1 + 4 + 9 + 16 = 30. */
    @Test
    public void testUdaf() throws Exception {
        Schema resultType =
            Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build();
        Row row = Row.withSchema(resultType).addValues(0, 30).build();
        String sql =
            "SELECT f_int2, squaresum(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";

        PCollection<Row> result =
            boundedInput1.apply(
                "testUdaf", SqlTransform.query(sql).registerUdaf("squaresum", new SquareSum()));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A UDAF over Joda {@link Instant} values can aggregate a TIMESTAMP column. */
    @Test
    public void testTimestampUdaf() throws Exception {
        Schema resultType = Schema.builder().addDateTimeField("jodatime").build();
        Row row =
            Row.withSchema(resultType)
                .addValues(DateTimeUtils.parseTimestampWithoutTimeZone("2017-01-01 02:04:03"))
                .build();
        String sql = "SELECT MAX_JODA(f_timestamp) as jodatime FROM PCOLLECTION";

        PCollection<Row> result =
            boundedInput1.apply(
                "testJodaUdaf", SqlTransform.query(sql).registerUdaf("MAX_JODA", new JodaMax()));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A UDF taking and returning {@link Date} maps onto the SQL DATE logical type. */
    @Test
    public void testDateUdf() throws Exception {
        Schema resultType =
            Schema.builder()
                .addField("result_date", Schema.FieldType.logicalType(SqlTypes.DATE))
                .build();
        Row row = Row.withSchema(resultType).addValues(LocalDate.of(2016, 12, 31)).build();
        String sql = "SELECT PRE_DATE(f_date) as result_date FROM PCOLLECTION WHERE f_int=1";

        PCollection<Row> result =
            boundedInput1.apply(
                "testDateUdf",
                SqlTransform.query(sql).registerUdf("PRE_DATE", PreviousDate.class));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A UDF taking and returning {@link Time} maps onto the SQL TIME logical type. */
    @Test
    public void testTimeUdf() throws Exception {
        Schema resultType =
            Schema.builder()
                .addField("result_time", Schema.FieldType.logicalType(SqlTypes.TIME))
                .build();
        Row row = Row.withSchema(resultType).addValues(LocalTime.of(0, 1, 3)).build();
        String sql = "SELECT PRE_HOUR(f_time) as result_time FROM PCOLLECTION WHERE f_int=1";

        PCollection<Row> result =
            boundedInput1.apply(
                "testTimeUdf",
                SqlTransform.query(sql).registerUdf("PRE_HOUR", PreviousHour.class));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A UDF taking and returning {@link Timestamp} maps onto a DATETIME field. */
    @Test
    public void testTimestampUdf() throws Exception {
        Schema resultType = Schema.builder().addDateTimeField("result_time").build();
        Row row =
            Row.withSchema(resultType)
                .addValues(DateTimeUtils.parseTimestampWithoutTimeZone("2016-12-31 01:01:03"))
                .build();
        String sql = "SELECT PRE_DAY(f_timestamp) as result_time FROM PCOLLECTION WHERE f_int=1";

        PCollection<Row> result =
            boundedInput1.apply(
                "testTimestampUdf",
                SqlTransform.query(sql).registerUdf("PRE_DAY", PreviousDay.class));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A UDAF whose output type is a {@code Map<String, Integer>} yields a MAP field. */
    @Test
    public void testUdafWithMapOutput() throws Exception {
        Schema resultType =
            Schema.builder()
                .addInt32Field("f_int2")
                .addMapField(
                    "squareAndAccumulateInMap", Schema.FieldType.STRING, Schema.FieldType.INT32)
                .build();
        Map<String, Integer> resultMap = new HashMap<>();
        resultMap.put("squareOf-1", 1);
        resultMap.put("squareOf-2", 4);
        resultMap.put("squareOf-3", 9);
        resultMap.put("squareOf-4", 16);
        Row row = Row.withSchema(resultType).addValues(0, resultMap).build();
        String sql =
            "SELECT f_int2,squareAndAccumulateInMap(f_int) AS `squareAndAccumulateInMap` "
                + "FROM PCOLLECTION GROUP BY f_int2";

        PCollection<Row> result =
            boundedInput1.apply(
                "testUdafWithMapOutput",
                SqlTransform.query(sql)
                    .registerUdaf("squareAndAccumulateInMap", new SquareAndAccumulateInMap()));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A UDAF whose output type is a {@code List<Integer>} yields an ARRAY field. */
    @Test
    public void testUdafWithListOutput() throws Exception {
        Schema resultType =
            Schema.builder()
                .addInt32Field("f_int2")
                .addArrayField("squareAndAccumulateInList", Schema.FieldType.INT32)
                .build();
        Row row =
            Row.withSchema(resultType).addValue(0).addArray(Arrays.asList(1, 4, 9, 16)).build();
        String sql =
            "SELECT f_int2,squareAndAccumulateInList(f_int) AS `squareAndAccumulateInList` "
                + "FROM PCOLLECTION GROUP BY f_int2";

        PCollection<Row> result =
            boundedInput1.apply(
                "testUdafWithListOutput",
                SqlTransform.query(sql)
                    .registerUdaf("squareAndAccumulateInList", new SquareAndAccumulateInList()));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A scalar UDF returning {@code List<Long>} yields an ARRAY of INT64. */
    @Test
    public void testUdfWithListOutput() throws Exception {
        Schema resultType =
            Schema.builder().addArrayField("array_field", Schema.FieldType.INT64).build();
        Row row = Row.withSchema(resultType).addValue(Arrays.asList(1L)).build();
        String sql = "SELECT test_array(1)";

        PCollection<Row> result =
            boundedInput1.apply(
                "testArrayUdf",
                SqlTransform.query(sql).registerUdf("test_array", TestReturnTypeList.class));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A scalar UDF accepting a {@code List} parameter can be called with an ARRAY literal. */
    @Test
    public void testUdfWithListInput() throws Exception {
        Schema resultType = Schema.builder().addInt32Field("int_field").build();
        Row row = Row.withSchema(resultType).addValue(3).build();
        String sql = "select array_length(ARRAY[1, 2, 3])";

        PCollection<Row> result =
            boundedInput1.apply(
                "testArrayUdf",
                SqlTransform.query(sql).registerUdf("array_length", TestListLength.class));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /**
     * A UDAF that inherits its type parameters from a parent {@code CombineFn} subclass is still
     * accepted: 1 + 16 + 81 + 256 = 354.
     */
    @Test
    public void testUdafMultiLevelDescendent() {
        Schema resultType =
            Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build();
        Row row = Row.withSchema(resultType).addValues(0, 354).build();
        String sql =
            "SELECT f_int2, double_square_sum(f_int) AS `squaresum` "
                + "FROM PCOLLECTION GROUP BY f_int2";

        PCollection<Row> result =
            boundedInput1.apply(
                "testUdaf",
                SqlTransform.query(sql).registerUdaf("double_square_sum", new SquareSquareSum()));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** Registering an unparameterized (raw) {@code CombineFn} fails with a ParseException. */
    @Test
    public void testRawCombineFnSubclass() {
        exceptions.expect(ParseException.class);
        exceptions.expectCause(
            ThrowableMessageMatcher.hasMessage(
                Matchers.containsString("CombineFn must be parameterized")));
        // The query is expected to fail during expansion, so the pipeline is never run.
        pipeline.enableAbandonedNodeEnforcement(false);
        String sql =
            "SELECT f_int2, squaresum(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";

        boundedInput1.apply(
            "testUdaf", SqlTransform.query(sql).registerUdaf("squaresum", new RawCombineFn()));
    }

    /** A {@link BeamSqlUdf} class with a static {@code eval} method works as a scalar UDF. */
    @Test
    public void testBeamSqlUdf() throws Exception {
        Schema resultType =
            Schema.builder().addInt32Field("f_int").addInt32Field("cubicvalue").build();
        Row row = Row.withSchema(resultType).addValues(2, 8).build();
        String sql = "SELECT f_int, cubic(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";

        PCollection<Row> result =
            boundedInput1.apply(
                "testUdf", SqlTransform.query(sql).registerUdf("cubic", CubicInteger.class));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** A {@link SerializableFunction} instance works as a scalar UDF. */
    @Test
    public void testSerializableFunctionUdf() throws Exception {
        Schema resultType =
            Schema.builder().addInt32Field("f_int").addInt32Field("cubicvalue").build();
        Row row = Row.withSchema(resultType).addValues(2, 8).build();
        String sql = "SELECT f_int, cubic(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";

        PCollection<Row> result =
            PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1)
                .apply(
                    "testUdf",
                    SqlTransform.query(sql).registerUdf("cubic", new CubicIntegerFn()));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** An optional {@code @Parameter} may be omitted at the call site. */
    @Test
    public void testBeamSqlUdfWithDefaultParameters() throws Exception {
        String sql =
            "SELECT f_int, substr(f_string) as sub_string FROM PCOLLECTION WHERE f_int = 2";

        PCollection<Row> result =
            PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1)
                .apply(
                    "testUdf",
                    SqlTransform.query(sql).registerUdf("substr", UdfFnWithDefault.class));

        Schema subStrSchema =
            Schema.builder().addInt32Field("f_int").addStringField("sub_string").build();
        Row subStrRow = Row.withSchema(subStrSchema).addValues(2, "s").build();
        PAssert.that(result).containsInAnyOrder(subStrRow);
        pipeline.run().waitUntilFinish();
    }

    /** A UDF returning a {@link TranslatableTable} can be used as a table function. */
    @Test
    public void testTableMacroUdf() throws Exception {
        String sql = "SELECT * FROM table(range_udf(0, 3))";
        Schema schema = Schema.of(Schema.Field.of("f0", Schema.FieldType.INT32));

        PCollection<Row> rows =
            pipeline.apply(SqlTransform.query(sql).registerUdf("range_udf", RangeUdf.class));

        PAssert.that(rows)
            .containsInAnyOrder(
                Row.withSchema(schema).addValue(0).build(),
                Row.withSchema(schema).addValue(1).build(),
                Row.withSchema(schema).addValue(2).build());
        pipeline.run().waitUntilFinish();
    }

    /**
     * Functions exposed by {@link UdfUdafProviderTest} are available without explicit
     * registration: (1^3)^2 + (2^3)^2 + (3^3)^2 + (4^3)^2 = 4890.
     */
    @Test
    public void testAutoLoadedUdfUdaf() throws Exception {
        Schema resultType =
            Schema.builder()
                .addInt32Field("f_int2")
                .addInt32Field("autoload_squarecubicsum")
                .build();
        Row row = Row.withSchema(resultType).addValues(0, 4890).build();
        String sql =
            "SELECT f_int2, autoload_squaresum(autoload_cubic(f_int)) "
                + "AS `autoload_squarecubicsum` FROM PCOLLECTION GROUP BY f_int2";

        PCollection<Row> result = boundedInput1.apply("testUdaf", SqlTransform.query(sql));

        PAssert.that(result).containsInAnyOrder(row);
        pipeline.run().waitUntilFinish();
    }

    /** UDAF collecting the square of every input into a list, sorted ascending on output. */
    public static class SquareAndAccumulateInList
        extends Combine.CombineFn<Integer, List<Integer>, List<Integer>> {
        @Override
        public List<Integer> createAccumulator() {
            return new ArrayList<>();
        }

        @Override
        public List<Integer> addInput(List<Integer> accumulator, Integer input) {
            accumulator.add(input * input);
            return accumulator;
        }

        @Override
        public List<Integer> mergeAccumulators(Iterable<List<Integer>> accumulators) {
            List<Integer> merged = createAccumulator();
            for (List<Integer> accumulator : accumulators) {
                merged.addAll(accumulator);
            }
            return merged;
        }

        @Override
        public List<Integer> extractOutput(List<Integer> accumulator) {
            // Sort so the result is independent of the order in which inputs were combined.
            Collections.sort(accumulator);
            return accumulator;
        }
    }

    /** UDAF mapping {@code "squareOf-<n>"} to {@code n * n} for every input {@code n}. */
    public static class SquareAndAccumulateInMap
        extends Combine.CombineFn<Integer, Map<String, Integer>, Map<String, Integer>> {
        @Override
        public Map<String, Integer> createAccumulator() {
            return new HashMap<>();
        }

        @Override
        public Map<String, Integer> addInput(Map<String, Integer> accumulator, Integer input) {
            accumulator.put("squareOf-" + input, input * input);
            return accumulator;
        }

        @Override
        public Map<String, Integer> mergeAccumulators(
            Iterable<Map<String, Integer>> accumulators) {
            Map<String, Integer> merged = createAccumulator();
            for (Map<String, Integer> accumulator : accumulators) {
                merged.putAll(accumulator);
            }
            return merged;
        }

        @Override
        public Map<String, Integer> extractOutput(Map<String, Integer> accumulator) {
            return accumulator;
        }
    }

    /** Table macro producing a single INT32 column {@code f0} with values in [start, end). */
    public static final class RangeUdf implements BeamSqlUdf {
        public static TranslatableTable eval(int startInclusive, int endExclusive) {
            Schema schema = Schema.of(Schema.Field.of("f0", Schema.FieldType.INT32));
            Object[] values = IntStream.range(startInclusive, endExclusive).boxed().toArray();
            return BeamCalciteTable.of(new TestBoundedTable(schema).addRows(values));
        }
    }

    /** Scalar UDF returning the number of elements in its list argument. */
    public static final class TestListLength implements BeamSqlUdf {
        public static Integer eval(List<Long> i) {
            return i.size();
        }
    }

    /** Scalar UDF wrapping its argument in a one-element list. */
    public static final class TestReturnTypeList implements BeamSqlUdf {
        public static List<Long> eval(Long i) {
            return Arrays.asList(i);
        }
    }

    /**
     * Scalar UDF subtracting one day (86,400,000 ms) from a timestamp. It uses {@code java.sql}
     * types; the test expects a DATETIME field in the result.
     */
    public static final class PreviousDay implements BeamSqlUdf {
        public static Timestamp eval(Timestamp time) {
            return new Timestamp(time.getTime() - 86400000L);
        }
    }

    /** Scalar UDF subtracting one hour (3,600,000 ms) from a {@code java.sql.Time}. */
    public static final class PreviousHour implements BeamSqlUdf {
        public static Time eval(Time time) {
            return new Time(time.getTime() - 3600000L);
        }
    }

    /** Scalar UDF subtracting one day (86,400,000 ms) from a {@code java.sql.Date}. */
    public static final class PreviousDate implements BeamSqlUdf {
        public static Date eval(Date date) {
            return new Date(date.getTime() - 86400000L);
        }
    }

    /**
     * Scalar UDF returning the first {@code n} characters of {@code s}. The optional {@code n}
     * defaults to 1 when the caller omits it (it then arrives as {@code null}).
     */
    public static final class UdfFnWithDefault implements BeamSqlUdf {
        public static String eval(
            @Parameter(name = "s") String s, @Parameter(name = "n", optional = true) Integer n) {
            return s.substring(0, n == null ? 1 : n);
        }
    }

    /** {@link SerializableFunction}-based UDF returning the cube of its input. */
    public static class CubicIntegerFn implements SerializableFunction<Integer, Integer> {
        @Override
        public Integer apply(Integer input) {
            return input * input * input;
        }
    }

    /** {@link BeamSqlUdf}-based UDF returning the cube of its input. */
    public static class CubicInteger implements BeamSqlUdf {
        public static Integer eval(Integer input) {
            return input * input * input;
        }
    }

    /**
     * Sum of fourth powers. Each input is squared here and squared again by {@link SquareSum}.
     * The class only inherits its generic parameters indirectly, which is what
     * {@code testUdafMultiLevelDescendent} exercises.
     */
    public static class SquareSquareSum extends SquareSum {
        @Override
        public Integer addInput(Integer accumulator, Integer input) {
            return super.addInput(accumulator, input * input);
        }
    }

    /**
     * Deliberately raw (unparameterized) {@code CombineFn}. Registering it must be rejected,
     * which {@code testRawCombineFnSubclass} verifies. All methods are stubs returning null.
     */
    @SuppressWarnings("rawtypes")
    public static class RawCombineFn extends Combine.CombineFn {
        @Override
        public Object createAccumulator() {
            return null;
        }

        @Override
        public Object addInput(Object accumulator, Object input) {
            return null;
        }

        @Override
        public Object mergeAccumulators(Iterable accumulators) {
            return null;
        }

        @Override
        public Object extractOutput(Object accumulator) {
            return null;
        }
    }

    /**
     * UDAF returning the latest Joda {@link Instant}. The epoch (0 ms) is the starting value,
     * so inputs that are all before the epoch would yield the epoch itself.
     */
    public static class JodaMax extends Combine.CombineFn<Instant, Instant, Instant> {
        @Override
        public Instant createAccumulator() {
            return new Instant(0L);
        }

        @Override
        public Instant addInput(Instant accumulator, Instant input) {
            return accumulator.isBefore(input) ? input : accumulator;
        }

        @Override
        public Instant mergeAccumulators(Iterable<Instant> accumulators) {
            Instant max = createAccumulator();
            for (Instant accumulator : accumulators) {
                max = accumulator.isBefore(max) ? max : accumulator;
            }
            return max;
        }

        @Override
        public Instant extractOutput(Instant accumulator) {
            return accumulator;
        }
    }

    /** UDAF returning the sum of squares of its Integer inputs (0 for no inputs). */
    public static class SquareSum extends Combine.CombineFn<Integer, Integer, Integer> {
        @Override
        public Integer createAccumulator() {
            return 0;
        }

        @Override
        public Integer addInput(Integer accumulator, Integer input) {
            return accumulator + input * input;
        }

        @Override
        public Integer mergeAccumulators(Iterable<Integer> accumulators) {
            int sum = 0;
            for (Integer accumulator : accumulators) {
                sum += accumulator;
            }
            return sum;
        }

        @Override
        public Integer extractOutput(Integer accumulator) {
            return accumulator;
        }
    }

    /**
     * Service-loaded provider (via {@code @AutoService}) exposing {@code autoload_cubic} and
     * {@code autoload_squaresum}. {@code testAutoLoadedUdfUdaf} relies on these without
     * registering them explicitly.
     */
    @AutoService(UdfUdafProvider.class)
    public static class UdfUdafProviderTest implements UdfUdafProvider {
        @Override
        public Map<String, Class<? extends BeamSqlUdf>> getBeamSqlUdfs() {
            return ImmutableMap.of("autoload_cubic", CubicInteger.class);
        }

        // The raw CombineFn value type is dictated by the UdfUdafProvider interface.
        @Override
        @SuppressWarnings("rawtypes")
        public Map<String, Combine.CombineFn> getUdafs() {
            return ImmutableMap.of("autoload_squaresum", new SquareSum());
        }
    }
}

