package org.apache.beam.sdk.extensions.sql.impl.schema.transform;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.BeamRecordCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamAggregationTransforms;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.values.BeamRecord;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.SqlKind;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.type.BasicSqlType;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.ImmutableBitSet;
import org.junit.Rule;
import org.junit.Test;

/* loaded from: input_file:org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamAggregationTransformTest.class */
/**
 * Exercises the building blocks in {@link BeamAggregationTransforms} one stage at a time:
 * key extraction ({@code AggregationGroupByKeyFn}), combining ({@code AggregationAdaptor})
 * and flattening the key/aggregate pair into one record ({@code MergeAggregationRecord}).
 *
 * <p>The input rows and their row type come from {@link BeamTransformBaseTest}.
 */
public class BeamAggregationTransformTest extends BeamTransformBaseTest {

    @Rule
    public TestPipeline p = TestPipeline.create();

    /** Aggregate calls under test, in the same order as {@link #aggregationFields()}. */
    private List<AggregateCall> aggCalls;

    /** Row type of the group-by key (a single {@code f_int} column). */
    private BeamRecordSqlType keyType;
    /** Row type of the aggregated values, without the key column. */
    private BeamRecordSqlType aggPartType;
    /** Row type of the merged output: key column followed by the aggregated columns. */
    private BeamRecordSqlType outputType;

    private BeamRecordCoder inRecordCoder;
    private BeamRecordCoder keyCoder;
    private BeamRecordCoder aggCoder;
    private BeamRecordCoder outRecordCoder;

    /**
     * Runs key extraction, group-by, aggregation and record merging as separate pipeline
     * steps and asserts the output of the key extraction, aggregation and merge steps.
     *
     * @throws ParseException if one of the expected timestamp literals cannot be parsed
     */
    @Test
    public void testCountPerElementBasic() throws ParseException {
        setupEnvironment();

        PCollection<BeamRecord> input = this.p.apply(Create.of(inputRows));

        // 1. Extract the group-by key (column 0) from every row.
        PCollection<KV<BeamRecord, BeamRecord>> keyedRows = input
                .apply("exGroupBy", WithKeys.of(
                        new BeamAggregationTransforms.AggregationGroupByKeyFn(-1, ImmutableBitSet.of(0))))
                .setCoder(KvCoder.of(this.keyCoder, this.inRecordCoder));

        // 2. Group the rows by key.
        PCollection<KV<BeamRecord, Iterable<BeamRecord>>> groupedRows = keyedRows
                .apply("groupBy", GroupByKey.<BeamRecord, BeamRecord>create())
                .setCoder(KvCoder.of(this.keyCoder, IterableCoder.of(this.inRecordCoder)));

        // 3. Run all aggregate calls over each group.
        PCollection<KV<BeamRecord, BeamRecord>> aggregatedRows = groupedRows
                .apply("aggregation", Combine.<BeamRecord, BeamRecord, BeamRecord>groupedValues(
                        new BeamAggregationTransforms.AggregationAdaptor(this.aggCalls, inputRowType)))
                .setCoder(KvCoder.of(this.keyCoder, this.aggCoder));

        // 4. Flatten each key/aggregate pair into a single output record.
        PCollection<BeamRecord> mergedRows = aggregatedRows.apply("mergeRecord",
                ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(this.outputType, this.aggCalls, -1)));
        mergedRows.setCoder(this.outRecordCoder);

        PAssert.that(keyedRows).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn());
        PAssert.that(aggregatedRows).containsInAnyOrder(prepareResultOfAggregationCombineFn());
        PAssert.that(mergedRows).containsInAnyOrder(prepareResultOfMergeAggregationRecord());

        this.p.run();
    }

    /** Initializes the aggregate calls first, then the row types and coders. */
    private void setupEnvironment() {
        prepareAggregationCalls();
        prepareTypeAndCoder();
    }

    /**
     * Builds the list of aggregate calls: one COUNT, then SUM/AVG/MAX/MIN over input
     * columns 1-5 and 8, and MAX/MIN only over the TIMESTAMP column 7.
     */
    private void prepareAggregationCalls() {
        this.aggCalls = new ArrayList<>();
        this.aggCalls.add(new AggregateCall(new SqlCountAggFunction(), false,
                Arrays.<Integer>asList(), sqlType(SqlTypeName.BIGINT), "count"));

        addSumAvgMaxMin(1, SqlTypeName.BIGINT);
        addSumAvgMaxMin(2, SqlTypeName.SMALLINT);
        addSumAvgMaxMin(3, SqlTypeName.TINYINT);
        addSumAvgMaxMin(4, SqlTypeName.FLOAT);
        addSumAvgMaxMin(5, SqlTypeName.DOUBLE);
        // Column 7 is a TIMESTAMP: only MAX and MIN are exercised for it.
        addMaxMin(7, SqlTypeName.TIMESTAMP);
        addSumAvgMaxMin(8, SqlTypeName.INTEGER);
    }

    /** Appends {@code sumN}, {@code avgN}, {@code maxN}, {@code minN} for input column N. */
    private void addSumAvgMaxMin(int fieldIndex, SqlTypeName typeName) {
        this.aggCalls.add(new AggregateCall(new SqlSumAggFunction(sqlType(typeName)), false,
                Arrays.asList(fieldIndex), sqlType(typeName), "sum" + fieldIndex));
        this.aggCalls.add(new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
                Arrays.asList(fieldIndex), sqlType(typeName), "avg" + fieldIndex));
        addMaxMin(fieldIndex, typeName);
    }

    /** Appends {@code maxN} followed by {@code minN} for input column N. */
    private void addMaxMin(int fieldIndex, SqlTypeName typeName) {
        this.aggCalls.add(new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
                Arrays.asList(fieldIndex), sqlType(typeName), "max" + fieldIndex));
        this.aggCalls.add(new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
                Arrays.asList(fieldIndex), sqlType(typeName), "min" + fieldIndex));
    }

    /** Creates a fresh {@link BasicSqlType} for {@code typeName} with the default type system. */
    private static BasicSqlType sqlType(SqlTypeName typeName) {
        return new BasicSqlType(RelDataTypeSystem.DEFAULT, typeName);
    }

    /** Derives the key, aggregate and output row types and their coders. */
    private void prepareTypeAndCoder() {
        this.inRecordCoder = inputRowType.getRecordCoder();

        this.keyType = initTypeOfSqlRow(Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER)));
        this.keyCoder = this.keyType.getRecordCoder();

        this.aggPartType = initTypeOfSqlRow(aggregationFields());
        this.aggCoder = this.aggPartType.getRecordCoder();

        this.outputType = prepareFinalRowType();
        this.outRecordCoder = this.outputType.getRecordCoder();
    }

    /**
     * Name and SQL type of every aggregated column, in the order the aggregate calls are
     * registered by {@link #prepareAggregationCalls()}.
     */
    private static List<KV<String, SqlTypeName>> aggregationFields() {
        return Arrays.asList(
                KV.of("count", SqlTypeName.BIGINT),
                KV.of("sum1", SqlTypeName.BIGINT),
                KV.of("avg1", SqlTypeName.BIGINT),
                KV.of("max1", SqlTypeName.BIGINT),
                KV.of("min1", SqlTypeName.BIGINT),
                KV.of("sum2", SqlTypeName.SMALLINT),
                KV.of("avg2", SqlTypeName.SMALLINT),
                KV.of("max2", SqlTypeName.SMALLINT),
                KV.of("min2", SqlTypeName.SMALLINT),
                KV.of("sum3", SqlTypeName.TINYINT),
                KV.of("avg3", SqlTypeName.TINYINT),
                KV.of("max3", SqlTypeName.TINYINT),
                KV.of("min3", SqlTypeName.TINYINT),
                KV.of("sum4", SqlTypeName.FLOAT),
                KV.of("avg4", SqlTypeName.FLOAT),
                KV.of("max4", SqlTypeName.FLOAT),
                KV.of("min4", SqlTypeName.FLOAT),
                KV.of("sum5", SqlTypeName.DOUBLE),
                KV.of("avg5", SqlTypeName.DOUBLE),
                KV.of("max5", SqlTypeName.DOUBLE),
                KV.of("min5", SqlTypeName.DOUBLE),
                KV.of("max7", SqlTypeName.TIMESTAMP),
                KV.of("min7", SqlTypeName.TIMESTAMP),
                KV.of("sum8", SqlTypeName.INTEGER),
                KV.of("avg8", SqlTypeName.INTEGER),
                KV.of("max8", SqlTypeName.INTEGER),
                KV.of("min8", SqlTypeName.INTEGER));
    }

    /**
     * Expected output of the key-extraction step: every input row paired with a key record
     * holding that row's column 0.
     */
    private List<KV<BeamRecord, BeamRecord>> prepareResultOfAggregationGroupByKeyFn() {
        // The assertion on this step is an exact match, so all input rows are expected here.
        List<KV<BeamRecord, BeamRecord>> expected = new ArrayList<>();
        for (BeamRecord row : inputRows) {
            expected.add(KV.of(keyRecordOf(row), row));
        }
        return expected;
    }

    /**
     * Expected output of the aggregation step: a single group, keyed by column 0 of the
     * first input row, holding all aggregated values.
     *
     * @throws ParseException if an expected timestamp literal cannot be parsed
     */
    private List<KV<BeamRecord, BeamRecord>> prepareResultOfAggregationCombineFn() throws ParseException {
        BeamRecord aggregates = new BeamRecord(this.aggPartType, expectedAggregationValues());
        return Arrays.asList(KV.of(keyRecordOf(inputRows.get(0)), aggregates));
    }

    /** Builds the key record (column 0 only) for {@code row}. */
    private BeamRecord keyRecordOf(BeamRecord row) {
        // Explicit <Object> so the list matches the record's value-list type on every
        // compiler level instead of relying on target-typed inference.
        return new BeamRecord(this.keyType, Arrays.<Object>asList(row.getInteger(0)));
    }

    /**
     * Expected aggregated values, one per entry of {@link #aggregationFields()} and in the
     * same order. Each literal is boxed to the Java type matching its column's SQL type.
     *
     * @throws ParseException if an expected timestamp literal cannot be parsed
     */
    private static List<Object> expectedAggregationValues() throws ParseException {
        return Arrays.<Object>asList(
                4L,
                10000L, 2500L, 4000L, 1000L,
                (short) 10, (short) 2, (short) 4, (short) 1,
                (byte) 10, (byte) 2, (byte) 4, (byte) 1,
                10.0f, 2.5f, 4.0f, 1.0f,
                10.0d, 2.5d, 4.0d, 1.0d,
                format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"),
                10, 2, 4, 1);
    }

    /** Builds the output row type: the {@code f_int} key column plus all aggregated columns. */
    private BeamRecordSqlType prepareFinalRowType() {
        RelDataTypeFactory.FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
        builder.add("f_int", SqlTypeName.INTEGER);
        for (KV<String, SqlTypeName> field : aggregationFields()) {
            builder.add(field.getKey(), field.getValue());
        }
        return CalciteUtils.toBeamRowType(builder.build());
    }

    /**
     * Expected output of the merge step: the key value {@code 1} followed by all
     * aggregated values.
     *
     * @throws ParseException if an expected timestamp literal cannot be parsed
     */
    private BeamRecord prepareResultOfMergeAggregationRecord() throws ParseException {
        List<Object> values = new ArrayList<>();
        values.add(1);
        values.addAll(expectedAggregationValues());
        return new BeamRecord(this.outputType, values);
    }
}
