package org.apache.pinot.core.query.reduce;

import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Unit tests for {@link PostAggregationHandler}.
 *
 * <p>Each scenario builds a handler from a query plus the data schema of the rows it is fed. In these
 * tests the input rows list group-by columns first, followed by aggregation results. Each scenario then
 * checks the result schema and the rows produced for the query's SELECT list. Those rows may reorder
 * columns, drop columns used only for grouping/ordering, or evaluate expressions over aggregations.
 */
public class PostAggregationHandlerTest {
    private static final DataSchema.ColumnDataType INT = DataSchema.ColumnDataType.INT;
    private static final DataSchema.ColumnDataType LONG = DataSchema.ColumnDataType.LONG;
    private static final DataSchema.ColumnDataType DOUBLE = DataSchema.ColumnDataType.DOUBLE;

    @Test
    public void testPostAggregation() {
        verifySingleAggregationPassesThrough();
        verifyGroupByPassesThrough();
        verifyGroupByOnlyColumnIsDropped();
        verifyOrderByOnlyAggregationIsDropped();
        verifyAggregationArithmetic();
        verifyExpressionMixingAggregationsAndGroupByColumn();
    }

    /** SELECT SUM(m1): the input schema and rows are returned unchanged. */
    private static void verifySingleAggregationPassesThrough() {
        DataSchema dataSchema = schema(new String[]{"sum(m1)"}, DOUBLE);
        PostAggregationHandler handler = newHandler("SELECT SUM(m1) FROM testTable", dataSchema);
        Assert.assertEquals(handler.getResultDataSchema(), dataSchema);
        Assert.assertEquals(handler.getResult(new Object[]{1.0}), new Object[]{1.0});
        Assert.assertEquals(handler.getResult(new Object[]{2.0}), new Object[]{2.0});
    }

    /** SELECT d1, SUM(m1) ... GROUP BY d1: the SELECT list matches the input layout, so nothing changes. */
    private static void verifyGroupByPassesThrough() {
        DataSchema dataSchema = schema(new String[]{"d1", "sum(m1)"}, INT, DOUBLE);
        PostAggregationHandler handler =
            newHandler("SELECT d1, SUM(m1) FROM testTable GROUP BY d1", dataSchema);
        Assert.assertEquals(handler.getResultDataSchema(), dataSchema);
        Assert.assertEquals(handler.getResult(new Object[]{1, 2.0}), new Object[]{1, 2.0});
        Assert.assertEquals(handler.getResult(new Object[]{3, 4.0}), new Object[]{3, 4.0});
    }

    /**
     * SELECT SUM(m1), d2 ... GROUP BY d1, d2: the group-by-only column d1 is dropped and the remaining
     * columns are reordered to follow the SELECT list.
     */
    private static void verifyGroupByOnlyColumnIsDropped() {
        PostAggregationHandler handler = newHandler("SELECT SUM(m1), d2 FROM testTable GROUP BY d1, d2",
            schema(new String[]{"d1", "d2", "sum(m1)"}, INT, LONG, DOUBLE));
        assertSchema(handler.getResultDataSchema(), new String[]{"sum(m1)", "d2"}, DOUBLE, LONG);
        Assert.assertEquals(handler.getResult(new Object[]{1, 2L, 3.0}), new Object[]{3.0, 2L});
        Assert.assertEquals(handler.getResult(new Object[]{4, 5L, 6.0}), new Object[]{6.0, 5L});
    }

    /**
     * Same query as above plus ORDER BY MAX(m1): the order-by-only aggregation max(m1) is present in the
     * input rows but does not appear in the result.
     */
    private static void verifyOrderByOnlyAggregationIsDropped() {
        PostAggregationHandler handler =
            newHandler("SELECT SUM(m1), d2 FROM testTable GROUP BY d1, d2 ORDER BY MAX(m1)",
                schema(new String[]{"d1", "d2", "sum(m1)", "max(m1)"}, INT, LONG, DOUBLE, DOUBLE));
        assertSchema(handler.getResultDataSchema(), new String[]{"sum(m1)", "d2"}, DOUBLE, LONG);
        Assert.assertEquals(handler.getResult(new Object[]{1, 2L, 3.0, 4.0}), new Object[]{3.0, 2L});
        Assert.assertEquals(handler.getResult(new Object[]{5, 6L, 7.0, 8.0}), new Object[]{7.0, 6L});
    }

    /** SELECT SUM(m1) + MAX(m2): the two aggregation results are combined into a single DOUBLE column. */
    private static void verifyAggregationArithmetic() {
        PostAggregationHandler handler = newHandler("SELECT SUM(m1) + MAX(m2) FROM testTable",
            schema(new String[]{"sum(m1)", "max(m2)"}, DOUBLE, DOUBLE));
        assertSchema(handler.getResultDataSchema(), new String[]{"plus(sum(m1),max(m2))"}, DOUBLE);
        // 1 + 2 = 3 and 3 + 4 = 7
        Assert.assertEquals(handler.getResult(new Object[]{1.0, 2.0}), new Object[]{3.0});
        Assert.assertEquals(handler.getResult(new Object[]{3.0, 4.0}), new Object[]{7.0});
    }

    /**
     * SELECT (SUM(m1) + MAX(m2) - d1) / 2, d2 ... ORDER BY MAX(m1): the expression mixes aggregation
     * results with a group-by column, while d1 (inside the expression only) and max(m1) (order-by only)
     * are not emitted as standalone columns.
     */
    private static void verifyExpressionMixingAggregationsAndGroupByColumn() {
        PostAggregationHandler handler = newHandler(
            "SELECT (SUM(m1) + MAX(m2) - d1) / 2, d2 FROM testTable GROUP BY d1, d2 ORDER BY MAX(m1)",
            schema(new String[]{"d1", "d2", "sum(m1)", "max(m2)", "max(m1)"}, INT, LONG, DOUBLE, DOUBLE, DOUBLE));
        assertSchema(handler.getResultDataSchema(),
            new String[]{"divide(minus(plus(sum(m1),max(m2)),d1),'2')", "d2"}, DOUBLE, LONG);
        // (3 + 4 - 1) / 2 = 3.0
        Assert.assertEquals(handler.getResult(new Object[]{1, 2L, 3.0, 4.0, 5.0}), new Object[]{3.0, 2L});
        // (8 + 9 - 6) / 2 = 5.5
        Assert.assertEquals(handler.getResult(new Object[]{6, 7L, 8.0, 9.0, 10.0}), new Object[]{5.5, 7L});
    }

    /** Parses {@code query} and builds a handler over input rows described by {@code dataSchema}. */
    private static PostAggregationHandler newHandler(String query, DataSchema dataSchema) {
        QueryContext queryContext = QueryContextConverterUtils.getQueryContext(query);
        return new PostAggregationHandler(queryContext, dataSchema);
    }

    /** Builds a {@link DataSchema} from parallel column-name and column-type lists. */
    private static DataSchema schema(String[] columnNames, DataSchema.ColumnDataType... columnDataTypes) {
        return new DataSchema(columnNames, columnDataTypes);
    }

    /** Asserts that {@code actual} has exactly the given column names and types, in order. */
    private static void assertSchema(DataSchema actual, String[] expectedNames,
        DataSchema.ColumnDataType... expectedTypes) {
        Assert.assertEquals(actual.size(), expectedNames.length);
        Assert.assertEquals(actual.getColumnNames(), expectedNames);
        Assert.assertEquals(actual.getColumnDataTypes(), expectedTypes);
    }
}
