/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql;

import java.sql.Types;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType;
import org.apache.beam.sdk.extensions.sql.BeamSql;
import org.apache.beam.sdk.extensions.sql.BeamSqlDslBase;
import org.apache.beam.sdk.extensions.sql.BeamSqlUdf;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.BeamRecord;
import org.apache.beam.sdk.values.BeamRecordType;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.junit.Test;

public class BeamSqlDslUdfUdafTest
extends BeamSqlDslBase {
    @Test
    public void testUdaf() throws Exception {
        BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int2", "squaresum"), Arrays.asList(4, 4));
        BeamRecord record = new BeamRecord((BeamRecordType)resultType, new Object[]{0, 30});
        String sql1 = "SELECT f_int2, squaresum1(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";
        PCollection result1 = (PCollection)this.boundedInput1.apply("testUdaf1", (PTransform)BeamSql.query((String)sql1).withUdaf("squaresum1", (Combine.CombineFn)new SquareSum()));
        PAssert.that((PCollection)result1).containsInAnyOrder((Object[])new BeamRecord[]{record});
        String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum` FROM PCOLLECTION GROUP BY f_int2";
        PCollection result2 = (PCollection)PCollectionTuple.of((TupleTag)new TupleTag("PCOLLECTION"), (PCollection)this.boundedInput1).apply("testUdaf2", (PTransform)BeamSql.queryMulti((String)sql2).withUdaf("squaresum2", (Combine.CombineFn)new SquareSum()));
        PAssert.that((PCollection)result2).containsInAnyOrder((Object[])new BeamRecord[]{record});
        this.pipeline.run().waitUntilFinish();
    }

    @Test
    public void testUdf() throws Exception {
        BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int", "cubicvalue"), Arrays.asList(4, 4));
        BeamRecord record = new BeamRecord((BeamRecordType)resultType, new Object[]{2, 8});
        String sql1 = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";
        PCollection result1 = (PCollection)this.boundedInput1.apply("testUdf1", (PTransform)BeamSql.query((String)sql1).withUdf("cubic1", CubicInteger.class));
        PAssert.that((PCollection)result1).containsInAnyOrder((Object[])new BeamRecord[]{record});
        String sql2 = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";
        PCollection result2 = (PCollection)PCollectionTuple.of((TupleTag)new TupleTag("PCOLLECTION"), (PCollection)this.boundedInput1).apply("testUdf2", (PTransform)BeamSql.queryMulti((String)sql2).withUdf("cubic2", (SerializableFunction)new CubicIntegerFn()));
        PAssert.that((PCollection)result2).containsInAnyOrder((Object[])new BeamRecord[]{record});
        this.pipeline.run().waitUntilFinish();
    }

    public static class CubicIntegerFn
    implements SerializableFunction<Integer, Integer> {
        public Integer apply(Integer input) {
            return input * input * input;
        }
    }

    public static class CubicInteger
    implements BeamSqlUdf {
        public static Integer eval(Integer input) {
            return input * input * input;
        }
    }

    public static class SquareSum
    extends Combine.CombineFn<Integer, Integer, Integer> {
        public Integer createAccumulator() {
            return 0;
        }

        public Integer addInput(Integer accumulator, Integer input) {
            return accumulator + input * input;
        }

        public Integer mergeAccumulators(Iterable<Integer> accumulators) {
            int v = 0;
            Iterator<Integer> ite = accumulators.iterator();
            while (ite.hasNext()) {
                v += ite.next().intValue();
            }
            return v;
        }

        public Integer extractOutput(Integer accumulator) {
            return accumulator;
        }
    }
}

