/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.schema.transform;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.BeamRecordCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner;
import org.apache.beam.sdk.extensions.sql.impl.schema.transform.BeamTransformBaseTest;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamAggregationTransforms;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.values.BeamRecord;
import org.apache.beam.sdk.values.BeamRecordType;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.SqlAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.SqlKind;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.type.BasicSqlType;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.ImmutableBitSet;
import org.junit.Rule;
import org.junit.Test;

/**
 * Tests the building blocks in {@link BeamAggregationTransforms} by chaining them into a
 * pipeline over the shared {@code inputRows} fixture and asserting the output of each stage:
 * group-by key extraction ({@link BeamAggregationTransforms.AggregationGroupByKeyFn}),
 * per-key aggregation ({@link BeamAggregationTransforms.AggregationAdaptor}) and flattening of
 * the (key, aggregates) pair into one record
 * ({@link BeamAggregationTransforms.MergeAggregationRecord}).
 */
public class BeamAggregationTransformTest extends BeamTransformBaseTest {
    /** Name and SQL type of the group-by key column (column 0 of the input rows). */
    private static final KV<String, SqlTypeName> KEY_COLUMN = KV.of("f_int", SqlTypeName.INTEGER);

    @Rule
    public TestPipeline p = TestPipeline.create();

    private List<AggregateCall> aggCalls;

    private BeamRecordSqlType keyType;
    private BeamRecordSqlType aggPartType;
    private BeamRecordSqlType outputType;

    private BeamRecordCoder inRecordCoder;
    private BeamRecordCoder keyCoder;
    private BeamRecordCoder aggCoder;
    private BeamRecordCoder outRecordCoder;

    /**
     * Runs key extraction, grouping, aggregation and record merging over the input fixture
     * and checks the output of each stage against hand-computed expectations.
     */
    @Test
    public void testCountPerElementBasic() throws ParseException {
        setupEnvironment();

        PCollection<BeamRecord> input = p.apply(Create.of(inputRows));

        // 1. Key every record by its group-by fields; bit 0 selects column 0 (f_int).
        //    NOTE(review): the -1 passed here and to MergeAggregationRecord in step 4 presumably
        //    means "no window field" — confirm against BeamAggregationTransforms.
        PCollection<KV<BeamRecord, BeamRecord>> exGroupByStream = input
            .apply("exGroupBy", WithKeys.of(
                new BeamAggregationTransforms.AggregationGroupByKeyFn(-1, ImmutableBitSet.of(0))))
            .setCoder(KvCoder.<BeamRecord, BeamRecord>of(keyCoder, inRecordCoder));

        // 2. Group the keyed records.
        PCollection<KV<BeamRecord, Iterable<BeamRecord>>> groupedStream = exGroupByStream
            .apply("groupBy", GroupByKey.<BeamRecord, BeamRecord>create())
            .setCoder(KvCoder.<BeamRecord, Iterable<BeamRecord>>of(
                keyCoder, IterableCoder.<BeamRecord>of(inRecordCoder)));

        // 3. Evaluate every call in aggCalls over each group.
        PCollection<KV<BeamRecord, BeamRecord>> aggregatedStream = groupedStream
            .apply("aggregation", Combine.<BeamRecord, BeamRecord, BeamRecord>groupedValues(
                new BeamAggregationTransforms.AggregationAdaptor(aggCalls, inputRowType)))
            .setCoder(KvCoder.<BeamRecord, BeamRecord>of(keyCoder, aggCoder));

        // 4. Flatten each (key, aggregates) pair into a single output record.
        PCollection<BeamRecord> mergedStream = aggregatedStream
            .apply("mergeRecord", ParDo.of(
                new BeamAggregationTransforms.MergeAggregationRecord(outputType, aggCalls, -1)));
        mergedStream.setCoder(outRecordCoder);

        PAssert.that(exGroupByStream).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn());
        PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn());
        PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRecord());

        p.run();
    }

    /** Builds {@link #aggCalls}, then derives the row types and coders from them. */
    private void setupEnvironment() {
        prepareAggregationCalls();
        prepareTypeAndCoder();
    }

    /**
     * Builds the aggregation calls under test. Each output column is named after its function
     * and input column index (e.g. {@code sum1} is SUM over column 1). Column 6 is not
     * aggregated, and column 7 only gets MAX/MIN.
     */
    private void prepareAggregationCalls() {
        aggCalls = new ArrayList<>();
        aggCalls.add(new AggregateCall(new SqlCountAggFunction(), false,
            Arrays.<Integer>asList(), sqlType(SqlTypeName.BIGINT), "count"));
        addSumAvgMaxMin(1, SqlTypeName.BIGINT);
        addSumAvgMaxMin(2, SqlTypeName.SMALLINT);
        addSumAvgMaxMin(3, SqlTypeName.TINYINT);
        addSumAvgMaxMin(4, SqlTypeName.FLOAT);
        addSumAvgMaxMin(5, SqlTypeName.DOUBLE);
        addMaxMin(7, SqlTypeName.TIMESTAMP);
        addSumAvgMaxMin(8, SqlTypeName.INTEGER);
    }

    /** Appends SUM, AVG, MAX and MIN (in that order) over {@code column}, all typed {@code type}. */
    private void addSumAvgMaxMin(int column, SqlTypeName type) {
        aggCalls.add(aggCall(new SqlSumAggFunction(sqlType(type)), column, type, "sum"));
        aggCalls.add(aggCall(new SqlAvgAggFunction(SqlKind.AVG), column, type, "avg"));
        addMaxMin(column, type);
    }

    /** Appends MAX and MIN (in that order) over {@code column}, both typed {@code type}. */
    private void addMaxMin(int column, SqlTypeName type) {
        aggCalls.add(aggCall(new SqlMinMaxAggFunction(SqlKind.MAX), column, type, "max"));
        aggCalls.add(aggCall(new SqlMinMaxAggFunction(SqlKind.MIN), column, type, "min"));
    }

    /** Creates a non-distinct call of {@code function} over one column, named prefix + column. */
    private static AggregateCall aggCall(
            SqlAggFunction function, int column, SqlTypeName type, String namePrefix) {
        return new AggregateCall(
            function, false, Arrays.asList(column), sqlType(type), namePrefix + column);
    }

    /** Returns a new Calcite type for {@code typeName} in the default type system. */
    private static RelDataType sqlType(SqlTypeName typeName) {
        return new BasicSqlType(RelDataTypeSystem.DEFAULT, typeName);
    }

    /**
     * Derives the key, aggregate and output row types and their coders. The aggregate columns
     * are read from {@link #aggCalls}, so this must run after {@link #prepareAggregationCalls()}.
     */
    private void prepareTypeAndCoder() {
        inRecordCoder = inputRowType.getRecordCoder();

        keyType = initTypeOfSqlRow(Arrays.asList(KEY_COLUMN));
        keyCoder = keyType.getRecordCoder();

        aggPartType = initTypeOfSqlRow(aggregationColumns());
        aggCoder = aggPartType.getRecordCoder();

        outputType = prepareFinalRowType();
        outRecordCoder = outputType.getRecordCoder();
    }

    /** Lists the (name, SQL type) of every aggregate output column, in {@link #aggCalls} order. */
    private List<KV<String, SqlTypeName>> aggregationColumns() {
        List<KV<String, SqlTypeName>> columns = new ArrayList<>();
        for (AggregateCall call : aggCalls) {
            columns.add(KV.of(call.getName(), call.getType().getSqlTypeName()));
        }
        return columns;
    }

    /** Builds the merged output row type: the key column followed by every aggregate column. */
    private BeamRecordSqlType prepareFinalRowType() {
        RelDataTypeFactory.FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
        builder.add(KEY_COLUMN.getKey(), KEY_COLUMN.getValue());
        for (KV<String, SqlTypeName> column : aggregationColumns()) {
            builder.add(column.getKey(), column.getValue());
        }
        return CalciteUtils.toBeamRowType(builder.build());
    }

    /** Wraps column 0 of {@code row} in a record of {@link #keyType}. */
    private BeamRecord keyRecordOf(BeamRecord row) {
        return new BeamRecord(keyType, Arrays.<Object>asList(row.getInteger(0)));
    }

    /** Expected output of step 1: every input row paired with a key record of its column 0. */
    private List<KV<BeamRecord, BeamRecord>> prepareResultOfAggregationGroupByKeyFn() {
        List<KV<BeamRecord, BeamRecord>> expected = new ArrayList<>();
        for (BeamRecord row : inputRows) {
            expected.add(KV.of(keyRecordOf(row), row));
        }
        return expected;
    }

    /** Expected output of step 3: a single group, keyed like the first input row. */
    private List<KV<BeamRecord, BeamRecord>> prepareResultOfAggregationCombineFn()
            throws ParseException {
        return Arrays.asList(KV.of(
            keyRecordOf(inputRows.get(0)),
            new BeamRecord(aggPartType, expectedAggregateValues())));
    }

    /** Expected output of step 4: the key value (1) followed by the aggregate values. */
    private BeamRecord prepareResultOfMergeAggregationRecord() throws ParseException {
        List<Object> values = new ArrayList<>();
        values.add(1);
        values.addAll(expectedAggregateValues());
        return new BeamRecord(outputType, values);
    }

    /**
     * Hand-computed aggregate values for the single group, in {@link #aggCalls} order. Each
     * numeric block is SUM, AVG, MAX, MIN; the timestamp block is MAX, MIN.
     */
    private List<Object> expectedAggregateValues() throws ParseException {
        return Arrays.<Object>asList(
            4L,                                          // count
            10000L, 2500L, 4000L, 1000L,                 // column 1, BIGINT
            (short) 10, (short) 2, (short) 4, (short) 1, // column 2, SMALLINT
            (byte) 10, (byte) 2, (byte) 4, (byte) 1,     // column 3, TINYINT
            10.0f, 2.5f, 4.0f, 1.0f,                     // column 4, FLOAT
            10.0, 2.5, 4.0, 1.0,                         // column 5, DOUBLE
            format.parse("2017-01-01 02:04:03"),         // column 7, TIMESTAMP max
            format.parse("2017-01-01 01:01:03"),         // column 7, TIMESTAMP min
            10, 2, 4, 1);                                // column 8, INTEGER
    }
}

