/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.schema.transform;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.Lists;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.SqlKind;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.type.BasicSqlType;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Pair;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.extensions.sql.impl.schema.transform.BeamTransformBaseTest;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamAggregationTransforms;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.junit.Rule;
import org.junit.Test;

/**
 * Tests the three stages of {@link BeamAggregationTransforms} used to evaluate a GROUP BY:
 * extracting the group-by key, combining the grouped values with the aggregate calls, and merging
 * key and aggregates into a single output row.
 *
 * <p>{@code inputSchema}, {@code inputRows} and {@code FORMAT} are inherited from {@link
 * BeamTransformBaseTest}.
 */
public class BeamAggregationTransformTest
extends BeamTransformBaseTest {
    @Rule
    public TestPipeline p = TestPipeline.create();

    /** Aggregate calls paired with their output field names, in {@link #aggPartType} field order. */
    private List<Pair<AggregateCall, String>> aggCalls;
    /** Schema of the group-by key row. */
    private Schema keyType;
    /** Schema of the row holding the aggregate results. */
    private Schema aggPartType;
    /** Schema of the merged output row: key fields followed by aggregate fields. */
    private Schema outputType;
    private Coder<Row> inRecordCoder;
    private Coder<Row> keyCoder;
    private Coder<Row> aggCoder;

    /**
     * Groups the input by column 0 and applies COUNT plus SUM/AVG/MAX/MIN over input columns
     * 1-5 and 8 and MAX/MIN over column 7, asserting the output of each of the three stages.
     */
    @Test
    public void testCountPerElementBasic() throws ParseException {
        setupEnvironment();

        PCollection<Row> input = p.apply(Create.of(inputRows));

        // 1. Key every row by column 0. keyType is reused (rather than rebuilding a key schema)
        //    so the produced key rows share their schema with keyCoder and the expectations.
        PCollection<KV<Row, Row>> exGroupByStream =
                input
                        .apply(
                                "exGroupBy",
                                WithKeys.of(
                                        new BeamAggregationTransforms.AggregationGroupByKeyFn(
                                                keyType, -1, ImmutableBitSet.of(0))))
                        .setCoder(KvCoder.of(keyCoder, inRecordCoder));

        // 2. Group rows sharing a key.
        PCollection<KV<Row, Iterable<Row>>> groupedStream =
                exGroupByStream
                        .apply("groupBy", GroupByKey.<Row, Row>create())
                        .setCoder(KvCoder.of(keyCoder, IterableCoder.of(inRecordCoder)));

        // 3. Run all aggregate calls over each group.
        PCollection<KV<Row, Row>> aggregatedStream =
                groupedStream
                        .apply(
                                "aggregation",
                                Combine.groupedValues(
                                        new BeamAggregationTransforms.AggregationAdaptor(
                                                aggCalls, inputSchema)))
                        .setCoder(KvCoder.of(keyCoder, aggCoder));

        // 4. Flatten each (key, aggregates) pair into one output row.
        PCollection<Row> mergedStream =
                aggregatedStream.apply(
                        "mergeRecord",
                        ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(outputType, -1)));
        mergedStream.setRowSchema(outputType);

        PAssert.that(exGroupByStream).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn());
        PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn());
        PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRow());

        p.run();
    }

    /** Initialises the aggregate calls, schemas and coders used by the test. */
    private void setupEnvironment() {
        prepareAggregationCalls();
        prepareTypeAndCoder();
    }

    /**
     * Builds {@link #aggCalls}. The call order must match the field order of {@link #aggPartType},
     * and each call is named after its function and input column (e.g. {@code sum1}).
     */
    private void prepareAggregationCalls() {
        this.aggCalls = new ArrayList<>();
        this.aggCalls.add(Pair.of(
                new AggregateCall(new SqlCountAggFunction("COUNT"), false, new ArrayList<>(),
                        sqlType(SqlTypeName.BIGINT), "count"),
                "count"));
        addSumAvgMaxMin(1, SqlTypeName.BIGINT);
        addSumAvgMaxMin(2, SqlTypeName.SMALLINT);
        addSumAvgMaxMin(3, SqlTypeName.TINYINT);
        addSumAvgMaxMin(4, SqlTypeName.FLOAT);
        addSumAvgMaxMin(5, SqlTypeName.DOUBLE);
        addMaxMin(7, SqlTypeName.TIMESTAMP);
        addSumAvgMaxMin(8, SqlTypeName.INTEGER);
    }

    /** Appends SUM, AVG, MAX and MIN over {@code column}, all typed {@code typeName}. */
    private void addSumAvgMaxMin(int column, SqlTypeName typeName) {
        BasicSqlType type = sqlType(typeName);
        List<Integer> args = Arrays.asList(column);
        String sum = "sum" + column;
        String avg = "avg" + column;
        this.aggCalls.add(Pair.of(
                new AggregateCall(new SqlSumAggFunction(type), false, args, type, sum), sum));
        this.aggCalls.add(Pair.of(
                new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, args, type, avg), avg));
        addMaxMin(column, typeName);
    }

    /** Appends MAX and MIN over {@code column}, typed {@code typeName}. */
    private void addMaxMin(int column, SqlTypeName typeName) {
        BasicSqlType type = sqlType(typeName);
        List<Integer> args = Arrays.asList(column);
        String max = "max" + column;
        String min = "min" + column;
        this.aggCalls.add(Pair.of(
                new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, args, type, max), max));
        this.aggCalls.add(Pair.of(
                new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, args, type, min), min));
    }

    /** Shorthand for a SQL type in the default type system. */
    private static BasicSqlType sqlType(SqlTypeName typeName) {
        return new BasicSqlType(RelDataTypeSystem.DEFAULT, typeName);
    }

    /** Builds the key, aggregate and output schemas and the row coders for each. */
    private void prepareTypeAndCoder() {
        this.inRecordCoder = rowCoder(inputSchema);
        this.keyType = Schema.builder().addInt32Field("f_int").build();
        this.keyCoder = rowCoder(this.keyType);
        this.aggPartType = Schema.builder()
                .addInt64Field("count")
                .addInt64Field("sum1").addInt64Field("avg1").addInt64Field("max1").addInt64Field("min1")
                .addInt16Field("sum2").addInt16Field("avg2").addInt16Field("max2").addInt16Field("min2")
                .addByteField("sum3").addByteField("avg3").addByteField("max3").addByteField("min3")
                .addFloatField("sum4").addFloatField("avg4").addFloatField("max4").addFloatField("min4")
                .addDoubleField("sum5").addDoubleField("avg5").addDoubleField("max5").addDoubleField("min5")
                .addDateTimeField("max7").addDateTimeField("min7")
                .addInt32Field("sum8").addInt32Field("avg8").addInt32Field("max8").addInt32Field("min8")
                .build();
        this.aggCoder = rowCoder(this.aggPartType);
        this.outputType = prepareFinalSchema();
    }

    /** Returns a coder for rows of {@code schema} that maps rows to themselves. */
    private static Coder<Row> rowCoder(Schema schema) {
        return SchemaCoder.of(schema, SerializableFunctions.identity(), SerializableFunctions.identity());
    }

    /** Expected output of the key-extraction stage: each input row paired with its key row. */
    private List<KV<Row, Row>> prepareResultOfAggregationGroupByKeyFn() {
        return inputRows.stream()
                .map(row -> KV.of(keyRowOf(row), row))
                .collect(Collectors.toList());
    }

    /** Expected output of the combine stage: the single group's key and its aggregates. */
    private List<KV<Row, Row>> prepareResultOfAggregationCombineFn() throws ParseException {
        Row aggregates = Row.withSchema(this.aggPartType).addValues(expectedAggregateValues()).build();
        return Arrays.asList(KV.of(keyRowOf(inputRows.get(0)), aggregates));
    }

    /**
     * Output schema of the merge stage. It is derived from {@link #keyType} and
     * {@link #aggPartType} so the three schemas cannot drift apart. Both must already be set.
     */
    private Schema prepareFinalSchema() {
        return Schema.builder()
                .addFields(this.keyType.getFields())
                .addFields(this.aggPartType.getFields())
                .build();
    }

    /** Expected output of the merge stage: the key value followed by the aggregate values. */
    private Row prepareResultOfMergeAggregationRow() throws ParseException {
        Object[] aggregates = expectedAggregateValues();
        Object[] values = new Object[aggregates.length + 1];
        values[0] = inputRows.get(0).getInt32(0);
        System.arraycopy(aggregates, 0, values, 1, aggregates.length);
        return Row.withSchema(this.outputType).addValues(values).build();
    }

    /** Builds the key row ({@link #keyType}) for an input row from its column 0. */
    private Row keyRowOf(Row row) {
        return Row.withSchema(this.keyType).addValues(row.getInt32(0)).build();
    }

    /**
     * Expected aggregate values for the single group, in {@link #aggPartType} field order. A fresh
     * array is returned on every call so callers may not corrupt each other's expectations.
     */
    private Object[] expectedAggregateValues() throws ParseException {
        return new Object[] {
            4L,                                          // count
            10000L, 2500L, 4000L, 1000L,                 // sum1, avg1, max1, min1
            (short) 10, (short) 2, (short) 4, (short) 1, // sum2, avg2, max2, min2
            (byte) 10, (byte) 2, (byte) 4, (byte) 1,     // sum3, avg3, max3, min3
            10.0f, 2.5f, 4.0f, 1.0f,                     // sum4, avg4, max4, min4
            10.0, 2.5, 4.0, 1.0,                         // sum5, avg5, max5, min5
            FORMAT.parseDateTime("2017-01-01 02:04:03"), // max7
            FORMAT.parseDateTime("2017-01-01 01:01:03"), // min7
            10, 2, 4, 1                                  // sum8, avg8, max8, min8
        };
    }
}

