/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql;

import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.linq4j.function.Parameter;
import org.apache.beam.sdk.extensions.sql.BeamSqlDslBase;
import org.apache.beam.sdk.extensions.sql.BeamSqlUdf;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.extensions.sql.impl.ParseException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;

public class BeamSqlDslUdfUdafTest
extends BeamSqlDslBase {
    /**
     * Registers {@link SquareSum} as a SQL aggregate function and runs the same GROUP BY query
     * through two entry points: applied directly to {@code boundedInput1}, and applied to a
     * {@link PCollectionTuple} that tags the same input as {@code "PCOLLECTION"}.
     *
     * <p>Both results are asserted to contain exactly one row, {@code (f_int2=0, squaresum=30)}.
     */
    @Test
    public void testUdaf() throws Exception {
        // Expected output shared by both variants of the query.
        Schema expectedSchema = Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build();
        Row expectedRow = Row.withSchema((Schema)expectedSchema).addValues(new Object[]{0, 30}).build();

        // Variant 1: the transform is applied straight to the input collection.
        String directQuery = "SELECT f_int2, squaresum1(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";
        PTransform directTransform = (PTransform)SqlTransform.query((String)directQuery).registerUdaf("squaresum1", (Combine.CombineFn)new SquareSum());
        PCollection directResult = (PCollection)this.boundedInput1.apply("testUdaf1", directTransform);
        PAssert.that((PCollection)directResult).containsInAnyOrder((Object[])new Row[]{expectedRow});

        // Variant 2: the input is wrapped in a tuple under the tag "PCOLLECTION".
        String tupleQuery = "SELECT f_int2, squaresum2(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";
        PTransform tupleTransform = (PTransform)SqlTransform.query((String)tupleQuery).registerUdaf("squaresum2", (Combine.CombineFn)new SquareSum());
        PCollectionTuple taggedInput = PCollectionTuple.of((TupleTag)new TupleTag("PCOLLECTION"), (PCollection)this.boundedInput1);
        PCollection tupleResult = (PCollection)taggedInput.apply("testUdaf2", tupleTransform);
        PAssert.that((PCollection)tupleResult).containsInAnyOrder((Object[])new Row[]{expectedRow});

        this.pipeline.run().waitUntilFinish();
    }

    /**
     * Registers {@link SquareSquareSum}, which extends {@link SquareSum} rather than
     * {@code Combine.CombineFn} directly, as a SQL aggregate function and asserts that the
     * grouped query yields the single row {@code (f_int2=0, squaresum=354)}.
     */
    @Test
    public void testUdafMultiLevelDescendent() {
        Schema expectedSchema = Schema.builder().addInt32Field("f_int2").addInt32Field("squaresum").build();
        Row expectedRow = Row.withSchema((Schema)expectedSchema).addValues(new Object[]{0, 354}).build();

        String query = "SELECT f_int2, double_square_sum(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";
        PTransform transform = (PTransform)SqlTransform.query((String)query).registerUdaf("double_square_sum", (Combine.CombineFn)new SquareSquareSum());
        PCollection aggregated = (PCollection)this.boundedInput1.apply("testUdaf", transform);

        PAssert.that((PCollection)aggregated).containsInAnyOrder((Object[])new Row[]{expectedRow});
        this.pipeline.run().waitUntilFinish();
    }

    /**
     * Verifies that using a raw (unparameterized) {@code Combine.CombineFn} subclass as a UDAF is
     * rejected.
     *
     * <p>The test expects a {@link ParseException} whose cause carries a message containing
     * {@code "CombineFn must be parameterized"}. No output is asserted and the pipeline is never
     * run, so no expected row is built here.
     */
    @Test
    public void testRawCombineFnSubclass() {
        this.exceptions.expect(ParseException.class);
        this.exceptions.expectCause(ThrowableMessageMatcher.hasMessage((Matcher)Matchers.containsString((String)"CombineFn must be parameterized")));
        // The pipeline is never run in this test.
        // NOTE(review): presumably enforcement is disabled so the un-run pipeline does not fail the test — confirm.
        this.pipeline.enableAbandonedNodeEnforcement(false);

        String sql1 = "SELECT f_int2, squaresum(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";
        // The result is intentionally discarded: the expected exception should surface from this call.
        this.boundedInput1.apply("testUdaf", (PTransform)SqlTransform.query((String)sql1).registerUdaf("squaresum", (Combine.CombineFn)new RawCombineFn()));
    }

    /**
     * Exercises the three ways a scalar UDF is registered in this file and checks each result for
     * the row where {@code f_int = 2}:
     *
     * <ol>
     *   <li>{@link CubicInteger}, registered by class, applied directly to the input;</li>
     *   <li>{@link CubicIntegerFn}, registered as a function instance, applied to a tagged tuple;</li>
     *   <li>{@link UdfFnWithDefault}, registered by class and called with only its first
     *       argument.</li>
     * </ol>
     */
    @Test
    public void testUdf() throws Exception {
        // Expected row for the two cubic queries: 2 cubed is 8.
        Schema cubicSchema = Schema.builder().addInt32Field("f_int").addInt32Field("cubicvalue").build();
        Row cubicRow = Row.withSchema((Schema)cubicSchema).addValues(new Object[]{2, 8}).build();

        // 1) UDF registered by class, applied directly to the input collection.
        String classUdfQuery = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";
        PTransform classUdfTransform = (PTransform)SqlTransform.query((String)classUdfQuery).registerUdf("cubic1", CubicInteger.class);
        PCollection classUdfResult = (PCollection)this.boundedInput1.apply("testUdf1", classUdfTransform);
        PAssert.that((PCollection)classUdfResult).containsInAnyOrder((Object[])new Row[]{cubicRow});

        // 2) UDF registered as a SerializableFunction instance, applied to a tagged tuple.
        String fnUdfQuery = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";
        PTransform fnUdfTransform = (PTransform)SqlTransform.query((String)fnUdfQuery).registerUdf("cubic2", (SerializableFunction)new CubicIntegerFn());
        PCollectionTuple fnUdfInput = PCollectionTuple.of((TupleTag)new TupleTag("PCOLLECTION"), (PCollection)this.boundedInput1);
        PCollection fnUdfResult = (PCollection)fnUdfInput.apply("testUdf2", fnUdfTransform);
        PAssert.that((PCollection)fnUdfResult).containsInAnyOrder((Object[])new Row[]{cubicRow});

        // 3) UDF with an optional second parameter, called with one argument only.
        String defaultArgQuery = "SELECT f_int, substr(f_string) as sub_string FROM PCOLLECTION WHERE f_int = 2";
        PTransform defaultArgTransform = (PTransform)SqlTransform.query((String)defaultArgQuery).registerUdf("substr", UdfFnWithDefault.class);
        PCollectionTuple defaultArgInput = PCollectionTuple.of((TupleTag)new TupleTag("PCOLLECTION"), (PCollection)this.boundedInput1);
        PCollection defaultArgResult = (PCollection)defaultArgInput.apply("testUdf3", defaultArgTransform);
        Schema prefixSchema = Schema.builder().addInt32Field("f_int").addStringField("sub_string").build();
        Row prefixRow = Row.withSchema((Schema)prefixSchema).addValues(new Object[]{2, "s"}).build();
        PAssert.that((PCollection)defaultArgResult).containsInAnyOrder((Object[])new Row[]{prefixRow});

        this.pipeline.run().waitUntilFinish();
    }

    /**
     * Scalar UDF whose second parameter is declared optional through {@link Parameter}.
     */
    public static final class UdfFnWithDefault implements BeamSqlUdf {
        /**
         * Returns the leading characters of {@code s}.
         *
         * @param s the string to take a prefix of
         * @param n the prefix length; when {@code null}, a length of 1 is used
         * @return {@code s.substring(0, length)}
         */
        public static String eval(@Parameter(name="s") String s, @Parameter(name="n", optional=true) Integer n) {
            int prefixLength;
            if (n == null) {
                prefixLength = 1;
            } else {
                prefixLength = n;
            }
            return s.substring(0, prefixLength);
        }
    }

    /**
     * {@link SerializableFunction} form of the cubic UDF: maps an integer to its third power.
     */
    public static class CubicIntegerFn implements SerializableFunction<Integer, Integer> {
        /**
         * @param input the value to cube
         * @return {@code input} multiplied by itself twice, using {@code int} arithmetic
         */
        public Integer apply(Integer input) {
            int base = input;
            int squared = base * base;
            return squared * base;
        }
    }

    /**
     * {@link BeamSqlUdf} form of the cubic UDF, exposing a static {@code eval} method.
     */
    public static class CubicInteger implements BeamSqlUdf {
        /**
         * @param input the value to cube
         * @return {@code input} multiplied by itself twice, using {@code int} arithmetic
         */
        public static Integer eval(Integer input) {
            int base = input;
            int squared = base * base;
            return squared * base;
        }
    }

    /**
     * Second-level descendant of {@code Combine.CombineFn}: squares each input before handing it
     * to {@link SquareSum#addInput}, which squares it again before adding.
     */
    public static class SquareSquareSum extends SquareSum {
        /**
         * @param accumulator the running total so far
         * @param input the next value to fold in
         * @return the result of the parent's {@code addInput} applied to {@code input * input}
         */
        @Override
        public Integer addInput(Integer accumulator, Integer input) {
            Integer squaredInput = input * input;
            return super.addInput(accumulator, squaredInput);
        }
    }

    /**
     * A {@code Combine.CombineFn} subclass that deliberately extends the raw type, with no type
     * arguments. It is used by {@code testRawCombineFnSubclass}, which expects registering it as
     * a UDAF to fail with a "CombineFn must be parameterized" message.
     *
     * <p>Every method returns {@code null}. Keep the declaration raw: the test depends on the
     * missing type parameters.
     */
    public static class RawCombineFn
    extends Combine.CombineFn {
        /** Placeholder implementation; always returns {@code null}. */
        public Object createAccumulator() {
            return null;
        }

        /** Placeholder implementation; ignores both arguments and returns {@code null}. */
        public Object addInput(Object accumulator, Object input) {
            return null;
        }

        /** Placeholder implementation; ignores its argument and returns {@code null}. */
        public Object mergeAccumulators(Iterable accumulators) {
            return null;
        }

        /** Placeholder implementation; ignores its argument and returns {@code null}. */
        public Object extractOutput(Object accumulator) {
            return null;
        }
    }

    /**
     * Aggregate function that sums the squares of its integer inputs. The accumulator is the
     * running sum itself.
     */
    public static class SquareSum extends Combine.CombineFn<Integer, Integer, Integer> {
        /** @return the starting sum, {@code 0} */
        public Integer createAccumulator() {
            return 0;
        }

        /**
         * @param accumulator the running sum so far
         * @param input the next value to fold in
         * @return {@code accumulator} plus the square of {@code input}
         */
        public Integer addInput(Integer accumulator, Integer input) {
            int squared = input * input;
            return accumulator + squared;
        }

        /**
         * @param accumulators partial sums to combine
         * @return the sum of all values in {@code accumulators}; {@code 0} when there are none
         */
        public Integer mergeAccumulators(Iterable<Integer> accumulators) {
            int total = 0;
            for (Integer partialSum : accumulators) {
                total = total + partialSum.intValue();
            }
            return total;
        }

        /**
         * @param accumulator the final running sum
         * @return {@code accumulator} unchanged
         */
        public Integer extractOutput(Integer accumulator) {
            return accumulator;
        }
    }
}

