package org.apache.drill.exec.fn.impl;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.math.BigDecimal;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.categories.EasyOutOfMemory;
import org.apache.drill.categories.OperatorTest;
import org.apache.drill.categories.PlannerTest;
import org.apache.drill.categories.SqlFunctionTest;
import org.apache.drill.categories.UnlikelyTest;
import org.apache.drill.common.exceptions.UserRemoteException;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.common.types.Types;
import org.apache.drill.exec.physical.rowSet.TestFillEmpties;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.exec.proto.UserBitShared;
import org.apache.drill.exec.record.RecordBatchLoader;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.metadata.SchemaBuilder;
import org.apache.drill.exec.rpc.user.QueryDataBatch;
import org.apache.drill.exec.util.Text;
import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableList;
import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableMap;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.drill.shaded.guava.com.google.common.collect.Maps;
import org.apache.drill.test.ClusterFixture;
import org.apache.drill.test.ClusterTest;
import org.apache.drill.test.TestBuilder;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;

@Category({SqlFunctionTest.class, OperatorTest.class, PlannerTest.class, EasyOutOfMemory.class})
/* loaded from: input_file:org/apache/drill/exec/fn/impl/TestAggregateFunctions.class */
public class TestAggregateFunctions extends ClusterTest {

    // JUnit rule through which a test declares the exception type and message it expects to be thrown.
    @Rule
    public ExpectedException thrown = ExpectedException.none();

    /**
     * Starts the cluster used by the tests in this class and copies the {@code agg}
     * resource directory into the test root directory.
     */
    @BeforeClass
    public static void setUp() throws Exception {
        startCluster(ClusterFixture.builder(dirTestWatcher));
        dirTestWatcher.copyResourceToRoot(Paths.get("agg"));
    }

    /** Checks COUNT over {@code t.x.y} and over the {@code integer} column of {@code input2.json}. */
    @Test
    public void testCountOnNullableColumn() throws Exception {
        String sql = "select count(t.x.y)  as cnt1, count(`integer`) as cnt2 from cp.`jsoninput/input2.json` t";
        testBuilder()
            .sqlQuery(sql)
            .ordered()
            .baselineColumns("cnt1", "cnt2")
            .baselineValues(3L, 4L)
            .go();
    }

    /** Checks COUNT(DISTINCT) over the {@code bool_val} column of {@code sys.options_old}. */
    @Test
    public void testCountDistinctOnBoolColumn() throws Exception {
        String sql = "select count(distinct `bool_val`) as cnt from `sys`.`options_old`";
        testBuilder()
            .sqlQuery(sql)
            .ordered()
            .baselineColumns("cnt")
            .baselineValues(2L)
            .go();
    }

    /**
     * Checks that MAX over {@code employee_id * 0.0} returns 0.0 while the decimal data type
     * is disabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so that it is restored exactly
     * once, whether the check passes or fails.
     */
    @Test
    public void testMaxWithZeroInput() throws Exception {
        try {
            client.alterSession("planner.enable_decimal_data_type", false);
            testBuilder()
                .sqlQuery("select max(employee_id * 0.0) as max_val from cp.`employee.json`")
                .unOrdered()
                .baselineColumns("max_val")
                .baselineValues(0.0d)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Runs the DRILL-2092 aggregation query twice, once with each hash join setting, and
     * compares the result against the rows read from {@code result.tsv}. Currently ignored.
     */
    @Test
    @Ignore
    public void testDrill2092() throws Exception {
        String query = "select a1, b1, count(distinct c1) as dist1, \n"
            + "sum(c1) as sum1, count(c1) as cnt1, count(*) as cnt \n"
            + "from cp.`agg/bugs/drill2092/input.json` \n"
            + "group by a1, b1 order by a1, b1";

        // The baseline file stores nulls as the literal text 'null'; every column is mapped back to a nullable bigint.
        String baselineQuery =
            "select case when columns[0]='null' then cast(null as bigint) else cast(columns[0] as bigint) end as a1, \n"
            + "case when columns[1]='null' then cast(null as bigint) else cast(columns[1] as bigint) end as b1, \n"
            + "case when columns[2]='null' then cast(null as bigint) else cast(columns[2] as bigint) end as dist1, \n"
            + "case when columns[3]='null' then cast(null as bigint) else cast(columns[3] as bigint) end as sum1, \n"
            + "case when columns[4]='null' then cast(null as bigint) else cast(columns[4] as bigint) end as cnt1, \n"
            + "case when columns[5]='null' then cast(null as bigint) else cast(columns[5] as bigint) end as cnt \n"
            + "from cp.`agg/bugs/drill2092/result.tsv`";

        String[] hashJoinSettings = {
            "alter system set `planner.enable_hashjoin` = true",
            "alter system set `planner.enable_hashjoin` = false"
        };
        for (String hashJoinSetting : hashJoinSettings) {
            testBuilder()
                .sqlQuery(query)
                .ordered()
                .optionSettingQueriesForTestQuery(hashJoinSetting)
                .sqlBaselineQuery(baselineQuery)
                .build()
                .run();
        }
    }

    /**
     * Joins {@code orders} with a grouped, ordered and limited sub-query over {@code lineitem}
     * and expects 100 joined rows, with {@code planner.slice_target} set to 1000 for the query.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testDrill2170() throws Exception {
        String query = "select count(*) as cnt from cp.`tpch/orders.parquet` o inner join\n"
            + "(select l_orderkey, sum(l_quantity), sum(l_extendedprice) \n"
            + "from cp.`tpch/lineitem.parquet` \n"
            + "group by l_orderkey order by 3 limit 100) sq \n"
            + "on sq.l_orderkey = o.o_orderkey";

        testBuilder()
            .sqlQuery(query)
            .ordered()
            .optionSettingQueriesForTestQuery("alter system set `planner.slice_target` = 1000")
            .baselineColumns("cnt")
            .baselineValues(100L)
            .go();
    }

    /**
     * Groups, filters (HAVING) and orders by the same {@code concat(...)} expression and
     * expects the three nation names that sort after 'UNITED'.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testGBExprWithDrillFunc() throws Exception {
        String query = "select concat(n_name, cast(n_nationkey as varchar(10))) as name, count(*) as cnt "
            + "from cp.`tpch/nation.parquet` "
            + "group by concat(n_name, cast(n_nationkey as varchar(10))) "
            + "having concat(n_name, cast(n_nationkey as varchar(10))) > 'UNITED'"
            + "order by concat(n_name, cast(n_nationkey as varchar(10)))";

        testBuilder()
            .ordered()
            .sqlQuery(query)
            .baselineColumns("name", "cnt")
            .baselineValues("UNITED KINGDOM23", 1L)
            .baselineValues("UNITED STATES24", 1L)
            .baselineValues("VIETNAM21", 1L)
            .build()
            .run();
    }

    /**
     * Runs a nested GROUP BY, where the outer grouping keys are a subset of the inner ones,
     * with {@code planner.slice_target} set to 1 and multiphase aggregation first disabled and
     * then enabled. Both runs must count 2000 groups.
     *
     * <p>Both session options are reset in a {@code finally} block so that a failing assertion
     * cannot leak the modified planner settings into other tests sharing this client.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testDRILLNestedGBWithSubsetKeys() throws Exception {
        String sql = " select count(*) as cnt from (select l_partkey from\n"
            + "   (select l_partkey, l_suppkey from cp.`tpch/lineitem.parquet`\n"
            + "      group by l_partkey, l_suppkey) \n"
            + "   group by l_partkey )";
        String multiphaseOption = PlannerSettings.MULTIPHASE.getOptionName();

        try {
            client.alterSession("planner.slice_target", 1);
            for (boolean multiphaseEnabled : new boolean[] {false, true}) {
                client.alterSession(multiphaseOption, multiphaseEnabled);
                testBuilder()
                    .ordered()
                    .sqlQuery(sql)
                    .baselineColumns("cnt")
                    .baselineValues(2000L)
                    .build()
                    .run();
            }
        } finally {
            client.resetSession("planner.slice_target");
            client.resetSession(multiphaseOption);
        }
    }

    /** Checks AVG applied to the result of {@code length(b1)} over {@code nullable1.json}. */
    @Test
    public void testAvgWithNullableScalarFunction() throws Exception {
        String query = " select avg(length(b1)) as col from cp.`jsoninput/nullable1.json`";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col")
            .baselineValues(3.0d)
            .go();
    }

    /**
     * Checks COUNT combined with AVG in one query over {@code nullable3.json}: first with the
     * two aggregates on different columns, then with both on the same column.
     */
    @Test
    public void testCountWithAvg() throws Exception {
        String differentColumns = "select count(a) col1, avg(b) col2 from cp.`jsoninput/nullable3.json`";
        testBuilder()
            .sqlQuery(differentColumns)
            .unOrdered()
            .baselineColumns("col1", "col2")
            .baselineValues(2L, 3.0d)
            .go();

        String sameColumn = "select count(a) col1, avg(a) col2 from cp.`jsoninput/nullable3.json`";
        testBuilder()
            .sqlQuery(sameColumn)
            .unOrdered()
            .baselineColumns("col1", "col2")
            .baselineValues(2L, 1.0d)
            .go();
    }

    /** Checks AVG over {@code employee_id} cast to bigint. */
    @Test
    public void testAvgOnKnownType() throws Exception {
        String query = "select avg(cast(employee_id as bigint)) as col from cp.`employee.json`";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col")
            .baselineValues(578.9982683982684d)
            .go();
    }

    /** Checks STDDEV_SAMP over {@code employee_id} cast to int. */
    @Test
    public void testStddevOnKnownType() throws Exception {
        String query = "select stddev_samp(cast(employee_id as int)) as col from cp.`employee.json`";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col")
            .baselineValues(333.56708470261117d)
            .go();
    }

    /**
     * Checks VAR_SAMP over {@code employee_id} cast to decimal(28, 20), decimal(28, 0) and
     * integer with the decimal data type enabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so it is restored exactly once
     * regardless of the test outcome.
     */
    @Test
    public void testVarSampDecimal() throws Exception {
        String query = "select var_samp(cast(employee_id as decimal(28, 20))) as dec20,\n"
            + "var_samp(cast(employee_id as decimal(28, 0))) as dec6,\n"
            + "var_samp(cast(employee_id as integer)) as d\n"
            + "from cp.`employee.json`";
        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("dec20", "dec6", "d")
                .baselineValues(
                    new BigDecimal("111266.99999699895713760532"),
                    new BigDecimal("111266.999997"),
                    111266.99999699896d)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks VAR_POP over {@code employee_id} cast to decimal(28, 20), decimal(28, 0) and
     * integer with the decimal data type enabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so it is restored exactly once
     * regardless of the test outcome.
     */
    @Test
    public void testVarPopDecimal() throws Exception {
        String query = "select var_pop(cast(employee_id as decimal(28, 20))) as dec20,\n"
            + "var_pop(cast(employee_id as decimal(28, 0))) as dec6,\n"
            + "var_pop(cast(employee_id as integer)) as d\n"
            + "from cp.`employee.json`";
        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("dec20", "dec6", "d")
                .baselineValues(
                    new BigDecimal("111170.66493206649050804895"),
                    new BigDecimal("111170.664932"),
                    111170.66493206649d)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks STDDEV_SAMP over {@code employee_id} cast to decimal(28, 20), decimal(28, 0) and
     * integer with the decimal data type enabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so it is restored exactly once
     * regardless of the test outcome.
     */
    @Test
    public void testStddevSampDecimal() throws Exception {
        String query = "select stddev_samp(cast(employee_id as decimal(28, 20))) as dec20,\n"
            + "stddev_samp(cast(employee_id as decimal(28, 0))) as dec6,\n"
            + "stddev_samp(cast(employee_id as integer)) as d\n"
            + "from cp.`employee.json`";
        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("dec20", "dec6", "d")
                .baselineValues(
                    new BigDecimal("333.56708470261114349632"),
                    new BigDecimal("333.567085"),
                    333.56708470261117d)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks STDDEV_POP over {@code employee_id} cast to decimal(28, 20), decimal(28, 0) and
     * integer with the decimal data type enabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so it is restored exactly once
     * regardless of the test outcome.
     */
    @Test
    public void testStddevPopDecimal() throws Exception {
        String query = "select stddev_pop(cast(employee_id as decimal(28, 20))) as dec20,\n"
            + "stddev_pop(cast(employee_id as decimal(28, 0))) as dec6,\n"
            + "stddev_pop(cast(employee_id as integer)) as d\n"
            + "from cp.`employee.json`";
        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("dec20", "dec6", "d")
                .baselineValues(
                    new BigDecimal("333.42265209800381903633"),
                    new BigDecimal("333.422652"),
                    333.4226520980038d)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks SUM over {@code employee_id} cast to decimal(9, 0), decimal(12, 3) and integer
     * with the decimal data type enabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so it is restored exactly once
     * regardless of the test outcome.
     */
    @Test
    public void testSumDecimal() throws Exception {
        String query = "select sum(cast(employee_id as decimal(9, 0))) as colDecS0,\n"
            + "sum(cast(employee_id as decimal(12, 3))) as colDecS3,\n"
            + "sum(cast(employee_id as integer)) as colInt\n"
            + "from cp.`employee.json`";
        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("colDecS0", "colDecS3", "colInt")
                .baselineValues(BigDecimal.valueOf(668743L), new BigDecimal("668743.000"), 668743L)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks AVG over {@code employee_id} cast to decimal(28, 20), decimal(28, 0) and integer
     * with the decimal data type enabled for the session.
     *
     * <p>The session option is reset in a {@code finally} block so it is restored exactly once
     * regardless of the test outcome.
     */
    @Test
    public void testAvgDecimal() throws Exception {
        String query = "select avg(cast(employee_id as decimal(28, 20))) as colDec20,\n"
            + "avg(cast(employee_id as decimal(28, 0))) as colDec6,\n"
            + "avg(cast(employee_id as integer)) as colInt\n"
            + "from cp.`employee.json`";
        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("colDec20", "colDec6", "colInt")
                .baselineValues(
                    new BigDecimal("578.99826839826839826840"),
                    new BigDecimal("578.998268"),
                    578.9982683982684d)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks the output schema of decimal aggregates for a {@code limit 0} query, once with
     * the limit-0 optimization enabled and once with it disabled. Both runs must report the
     * same VARDECIMAL precision and scale for every aggregate column.
     *
     * <p>Both session options are reset in a {@code finally} block so they are restored
     * regardless of the test outcome.
     */
    @Test
    public void testSumAvgDecimalLimit0() throws Exception {
        String query = "select\n"
            + "sum(cast(employee_id as decimal(9, 3))) sum_col,\n"
            + "avg(cast(employee_id as decimal(9, 3))) avg_col,\n"
            + "stddev_pop(cast(employee_id as decimal(9, 3))) stddev_pop_col,\n"
            + "stddev_samp(cast(employee_id as decimal(9, 3))) stddev_samp_col,\n"
            + "var_pop(cast(employee_id as decimal(9, 3))) var_pop_col,\n"
            + "var_samp(cast(employee_id as decimal(9, 3))) var_samp_col,\n"
            + "max(cast(employee_id as decimal(9, 3))) max_col,\n"
            + "min(cast(employee_id as decimal(9, 3))) min_col\n"
            + "from cp.`employee.json` limit 0";

        List<Pair<SchemaPath, TypeProtos.MajorType>> expectedSchema = ImmutableList.of(
            Pair.of(SchemaPath.getSimplePath("sum_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 38, 3)),
            Pair.of(SchemaPath.getSimplePath("avg_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 38, 6)),
            Pair.of(SchemaPath.getSimplePath("stddev_pop_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 38, 6)),
            Pair.of(SchemaPath.getSimplePath("stddev_samp_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 38, 6)),
            Pair.of(SchemaPath.getSimplePath("var_pop_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 38, 6)),
            Pair.of(SchemaPath.getSimplePath("var_samp_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 38, 6)),
            Pair.of(SchemaPath.getSimplePath("max_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 9, 3)),
            Pair.of(SchemaPath.getSimplePath("min_col"),
                Types.withPrecisionAndScale(TypeProtos.MinorType.VARDECIMAL, TypeProtos.DataMode.OPTIONAL, 9, 3)));

        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            // Same schema is expected with the limit-0 optimization on and off.
            for (boolean limit0Optimization : new boolean[] {true, false}) {
                client.alterSession("planner.enable_limit0_optimization", limit0Optimization);
                testBuilder()
                    .sqlQuery(query)
                    .schemaBaseLine(expectedSchema)
                    .go();
            }
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
            client.resetSession("planner.enable_limit0_optimization");
        }
    }

    /**
     * Writes a two-record JSON file in which the second record has no {@code a} field, then
     * checks decimal aggregates grouped by {@code a}: one group with real values and one
     * group whose aggregates are all null.
     *
     * <p>The file is written with try-with-resources so the writer is always closed, and the
     * session option is reset in a {@code finally} block regardless of the test outcome.
     */
    @Test
    public void testAggGroupByWithNullDecimal() throws Exception {
        String fileName = "table.json";
        File tableFile = new File(dirTestWatcher.getRootDir(), fileName);
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(tableFile))) {
            writer.write("{\"a\": 1, \"b\": 0}");
            writer.write("{\"b\": 2}");
        }

        String query = "select sum(cast(a as decimal(9,0))) as s,\n"
            + "avg(cast(a as decimal(9,0))) as av,\n"
            + "var_samp(cast(a as decimal(9,0))) as varSamp,\n"
            + "var_pop(cast(a as decimal(9,0))) as varPop,\n"
            + "stddev_pop(cast(a as decimal(9,0))) as stddevPop,\n"
            + "stddev_samp(cast(a as decimal(9,0))) as stddevSamp,"
            + "max(cast(a as decimal(9,0))) as mx,"
            + "min(cast(a as decimal(9,0))) as mn from dfs.`%s` t group by a";

        try {
            client.alterSession("planner.enable_decimal_data_type", true);
            testBuilder()
                .sqlQuery(query, fileName)
                .unOrdered()
                .baselineColumns("s", "av", "varSamp", "varPop", "stddevPop", "stddevSamp", "mx", "mn")
                .baselineValues(
                    BigDecimal.valueOf(1L),
                    new BigDecimal("1.000000"),
                    new BigDecimal("0.000000"),
                    new BigDecimal("0.000000"),
                    new BigDecimal("0.000000"),
                    new BigDecimal("0.000000"),
                    BigDecimal.valueOf(1L),
                    BigDecimal.valueOf(1L))
                .baselineValues(null, null, null, null, null, null, null, null)
                .go();
        } finally {
            client.resetSession("planner.enable_decimal_data_type");
        }
    }

    /**
     * Checks aggregates over an empty input (filter {@code 1 = 0}): COUNT is expected to be 0
     * while AVG and SUM are expected to be null.
     */
    @Test
    public void countEmptyNullableInput() throws Exception {
        String query = "select count(employee_id) col1, avg(employee_id) col2, sum(employee_id) col3 from cp.`employee.json` where 1 = 0";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col1", "col2", "col3")
            .baselineValues(0L, null, null)
            .go();
    }

    /**
     * Expects SUM over five columns of an empty input (filter {@code 1 = 0}) to return null
     * for each column. Ignored pending DRILL-4473.
     */
    @Test
    @Ignore("DRILL-4473")
    public void sumEmptyNonexistentNullableInput() throws Exception {
        String query = "select sum(int_col) col1, sum(bigint_col) col2, sum(float4_col) col3, sum(float8_col) col4, sum(interval_year_col) col5 from cp.`employee.json` where 1 = 0";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col1", "col2", "col3", "col4", "col5")
            .baselineValues(null, null, null, null, null)
            .go();
    }

    /**
     * Expects AVG over five columns of an empty input (filter {@code 1 = 0}) to return null
     * for each column. Ignored pending DRILL-4473.
     */
    @Test
    @Ignore("DRILL-4473")
    public void avgEmptyNonexistentNullableInput() throws Exception {
        String query = "select avg(int_col) col1, avg(bigint_col) col2, avg(float4_col) col3, avg(float8_col) col4, avg(interval_year_col) col5 from cp.`employee.json` where 1 = 0";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col1", "col2", "col3", "col4", "col5")
            .baselineValues(null, null, null, null, null)
            .go();
    }

    /**
     * Expects STDDEV_POP over five columns of an empty input (filter {@code 1 = 0}) to return
     * null for each column.
     */
    @Test
    public void stddevEmptyNonexistentNullableInput() throws Exception {
        String query = "select stddev_pop(int_col) col1, stddev_pop(bigint_col) col2, stddev_pop(float4_col) col3, stddev_pop(float8_col) col4, stddev_pop(interval_year_col) col5 from cp.`employee.json` where 1 = 0";
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col1", "col2", "col3", "col4", "col5")
            .baselineValues(null, null, null, null, null)
            .go();
    }

    /**
     * Checks that MIN and MAX over every column of {@code alltypes_required.parquet} (except
     * {@code col_bln}) return null when the input is empty (filter {@code 1 = 0}).
     *
     * <p>The column names are discovered from the schema of a {@code limit 0} query, and one
     * select list per aggregate function is assembled from them.
     */
    @Test
    public void minMaxEmptyNonNullableInput() throws Exception {
        QueryDataBatch schemaBatch = queryBuilder()
            .sql("select * from cp.`parquet/alltypes_required.parquet` limit 0")
            .results()
            .get(0);

        // One select-list buffer per aggregate function under test.
        Map<String, StringBuilder> selectLists = Maps.newHashMap();
        selectLists.put("min", new StringBuilder());
        selectLists.put("max", new StringBuilder());

        Map<String, Object> expectedRow = Maps.newHashMap();
        try {
            for (UserBitShared.SerializedField field : schemaBatch.getHeader().getDef().getFieldList()) {
                String columnName = field.getNamePart().getName();
                // NOTE(review): col_bln is skipped, presumably because min/max are not applicable to it - confirm.
                if (columnName.equals("col_bln")) {
                    continue;
                }
                expectedRow.put(String.format("`%s`", columnName), null);
                for (Map.Entry<String, StringBuilder> function : selectLists.entrySet()) {
                    function.getValue()
                        .append(function.getKey())
                        .append("(")
                        .append(columnName)
                        .append(") ")
                        .append(columnName)
                        .append(",");
                }
            }
        } finally {
            // Release the batch even if reading the schema fails.
            schemaBatch.release();
        }

        List<Map<String, Object>> expectedRecords = Lists.newArrayList();
        expectedRecords.add(expectedRow);

        for (StringBuilder selectList : selectLists.values()) {
            // Drop the trailing comma left by the last appended column.
            selectList.setLength(selectList.length() - 1);
            testBuilder()
                .sqlQuery("select %s from cp.`parquet/alltypes_required.parquet` where 1 = 0", selectList.toString())
                .unOrdered()
                .baselineRecords(expectedRecords)
                .go();
        }
    }

    /**
     * For both the required and the optional all-types parquet tables, reads the first row,
     * then checks that {@code single_value} applied to every column of a one-row sub-query
     * returns exactly that row's values.
     *
     * <p>The loader and the query batch are released in a {@code finally} block so they are
     * freed even when reading the first row fails.
     */
    @Test
    public void testSingleValueFunction() throws Exception {
        List<String> tableNames = Arrays.asList(
            "cp.`parquet/alltypes_required.parquet`",
            "cp.`parquet/alltypes_optional.parquet`");

        for (String tableName : tableNames) {
            QueryDataBatch firstRowBatch = queryBuilder().sql("select * from %s limit 1", tableName).results().get(0);

            // One select-list buffer per aggregate function under test.
            Map<String, StringBuilder> selectLists = new HashMap<>();
            selectLists.put("single_value", new StringBuilder());

            Map<String, Object> expectedRow = new HashMap<>();
            List<String> columnNames = new ArrayList<>();

            RecordBatchLoader loader = new RecordBatchLoader(cluster.allocator());
            try {
                loader.load(firstRowBatch.getHeader().getDef(), firstRowBatch.getData());
                for (VectorWrapper<?> column : loader.getContainer()) {
                    String columnName = column.getField().getName();
                    Object value = column.getValueVector().getAccessor().getObject(0);
                    // Text values are converted to String before being used as baseline values.
                    if (value instanceof Text) {
                        value = value.toString();
                    }
                    expectedRow.put(String.format("`%s`", columnName), value);
                    for (Map.Entry<String, StringBuilder> function : selectLists.entrySet()) {
                        function.getValue()
                            .append(function.getKey())
                            .append("(")
                            .append(columnName)
                            .append(") ")
                            .append(columnName)
                            .append(",");
                    }
                    columnNames.add(columnName);
                }
            } finally {
                loader.clear();
                firstRowBatch.release();
            }

            String innerColumns = String.join(", ", columnNames);
            List<Map<String, Object>> expectedRecords = new ArrayList<>();
            expectedRecords.add(expectedRow);

            for (StringBuilder selectList : selectLists.values()) {
                // Drop the trailing comma left by the last appended column.
                selectList.setLength(selectList.length() - 1);
                testBuilder()
                    .sqlQuery("select %s from (select %s from %s limit 1)", selectList.toString(), innerColumns, tableName)
                    .unOrdered()
                    .baselineRecords(expectedRecords)
                    .go();
            }
        }
    }

    /**
     * For both all-types parquet tables, checks {@code single_value} with a GROUP BY on each
     * column individually, first with stream aggregation enabled and then with it disabled.
     *
     * <p>The stream aggregation session option is reset in a {@code finally} block so it is
     * restored regardless of the test outcome.
     */
    @Test
    public void testHashAggSingleValueFunction() throws Exception {
        List<String> tableNames = Arrays.asList(
            "cp.`parquet/alltypes_required.parquet`",
            "cp.`parquet/alltypes_optional.parquet`");
        String streamAggOption = PlannerSettings.STREAMAGG.getOptionName();
        String query = "select single_value(t.%1$s) as %1$s\n"
            + "from (select %1$s from %2$s limit 1) t group by t.%1$s";

        for (String tableName : tableNames) {
            Map<String, Object> firstRow = getBaselineRecords(tableName);
            try {
                for (Boolean streamAggEnabled : Arrays.asList(true, false)) {
                    client.alterSession(streamAggOption, streamAggEnabled);
                    for (Map.Entry<String, Object> column : firstRow.entrySet()) {
                        String columnName = String.format("`%s`", column.getKey());
                        // NOTE(review): interval columns are only exercised with stream aggregation enabled,
                        // presumably because the other aggregation path does not handle them - confirm.
                        if (!streamAggEnabled && columnName.startsWith("`col_intrvl")) {
                            continue;
                        }
                        testBuilder()
                            .sqlQuery(query, columnName, tableName)
                            .ordered()
                            .baselineRecords(Collections.singletonList(ImmutableMap.of(columnName, column.getValue())))
                            .go();
                    }
                }
            } finally {
                client.resetSession(streamAggOption);
            }
        }
    }

    /**
     * Reads the first row of the given table and returns it as a map from column name to
     * value. {@link Text} values are converted to {@link String}.
     *
     * @param str table reference substituted into {@code select * from %s limit 1}
     * @return column name to first-row value
     * @throws Exception if the query or the batch loading fails
     */
    private Map<String, Object> getBaselineRecords(String str) throws Exception {
        QueryDataBatch firstRowBatch = queryBuilder().sql("select * from %s limit 1", str).results().get(0);
        Map<String, Object> firstRow = new HashMap<>();
        RecordBatchLoader loader = new RecordBatchLoader(cluster.allocator());
        try {
            loader.load(firstRowBatch.getHeader().getDef(), firstRowBatch.getData());
            for (VectorWrapper<?> column : loader.getContainer()) {
                Object value = column.getValueVector().getAccessor().getObject(0);
                if (value instanceof Text) {
                    value = value.toString();
                }
                firstRow.put(column.getField().getName(), value);
            }
        } finally {
            // Free the loader and the batch even when loading or reading fails.
            loader.clear();
            firstRowBatch.release();
        }
        return firstRow;
    }

    /**
     * Checks {@code single_value} over list and map columns of the first record of
     * {@code test_anyvalue.json}.
     */
    @Test
    public void testSingleValueWithComplexInput() throws Exception {
        String query = "select single_value(a) as any_a, single_value(f) as any_f, single_value(m) as any_m,"
            + "single_value(p) as any_p from (select * from cp.`store/json/test_anyvalue.json` limit 1)";

        Object expectedA = TestBuilder.listOf(
            TestBuilder.mapOf("b", 10L, "c", 15L),
            TestBuilder.mapOf("b", 20L, "c", 45L));
        Object expectedF = TestBuilder.listOf(
            TestBuilder.mapOf("g",
                TestBuilder.mapOf("h",
                    TestBuilder.listOf(TestBuilder.mapOf("k", 10L), TestBuilder.mapOf("k", 20L)))));
        Object expectedM = TestBuilder.listOf(
            TestBuilder.mapOf("n", TestBuilder.listOf(1L, 2L, 3L)));
        Object expectedP = TestBuilder.mapOf("q", TestBuilder.listOf(27L, 28L, 29L));

        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("any_a", "any_f", "any_m", "any_p")
            .baselineValues(expectedA, expectedF, expectedM, expectedP)
            .go();
    }

    /**
     * For every column of both all-types parquet tables, runs {@code single_value} over the
     * whole table and, when the query fails with a {@link UserRemoteException}, checks that
     * the message reports a function error about more than one input row.
     *
     * <p>NOTE(review): a query that completes without an exception is not flagged as a
     * failure here - confirm whether that is intended.
     */
    @Test
    public void testSingleValueWithMultipleValuesInputsAllTypes() throws Exception {
        List<String> tableNames = Arrays.asList(
            "cp.`parquet/alltypes_required.parquet`",
            "cp.`parquet/alltypes_optional.parquet`");

        for (String tableName : tableNames) {
            QueryDataBatch firstRowBatch = queryBuilder().sql("select * from %s limit 1", tableName).results().get(0);
            RecordBatchLoader loader = new RecordBatchLoader(cluster.allocator());
            List<String> columnNames;
            try {
                loader.load(firstRowBatch.getHeader().getDef(), firstRowBatch.getData());
                columnNames = StreamSupport.stream(loader.getContainer().spliterator(), false)
                    .map(column -> column.getField().getName())
                    .collect(Collectors.toList());
            } finally {
                // Free the loader and the batch even when loading fails.
                loader.clear();
                firstRowBatch.release();
            }

            for (String columnName : columnNames) {
                try {
                    run("select single_value(t.%1$s) as %1$s from %2$s t", columnName, tableName);
                } catch (UserRemoteException e) {
                    Assert.assertTrue(
                        "No expected current \"FUNCTION ERROR\" and/or \"Input for single_value function has more than one row\"",
                        e.getMessage().matches("^FUNCTION ERROR(.|\\n)*Input for single_value function has more than one row(.|\\n)*"));
                }
            }
        }
    }

    /**
     * Expects {@code single_value} over a complex column of a multi-row input to fail with a
     * {@link UserRemoteException} whose message reports the function error.
     */
    @Test
    public void testSingleValueWithMultipleComplexInputs() throws Exception {
        thrown.expect(UserRemoteException.class);
        thrown.expectMessage(CoreMatchers.containsString("FUNCTION ERROR"));
        thrown.expectMessage(CoreMatchers.containsString("Input for single_value function has more than one row"));

        String query = "select single_value(t1.a) from cp.`store/json/test_anyvalue.json` t1";
        run(query);
    }

    /**
     * Checks MAX(foo) over the {@code drill3069} data set when filtered to a single value of
     * {@code foo}, for the values 2, 4 and 6.
     */
    @Test
    public void drill3069() throws Exception {
        String query = "select max(foo) col1 from dfs.`agg/bugs/drill3069` where foo = %d";
        for (int foo = 2; foo <= 6; foo += 2) {
            testBuilder()
                .sqlQuery(query, foo)
                .unOrdered()
                .baselineColumns("col1")
                .baselineValues((long) foo)
                .go();
        }
    }

    /**
     * Checks that a filter on a grouping key is pushed below the aggregation, both when it is
     * written as a WHERE over a grouped sub-query and when it is written as a HAVING clause.
     * For each form the plan is verified first and then the result (5 rows counted for
     * region key 2).
     *
     * <p>The result check for the HAVING form runs the HAVING query itself, so the query whose
     * plan was verified is also the one whose output is validated.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testPushFilterPastAgg() throws Exception {
        String whereQuery = " select cnt "
            + " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) "
            + " where n_regionkey = 2 ";
        String havingQuery = " select count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey "
            + " having n_regionkey = 2 ";

        // In the plan text the aggregation must appear before the Filter, never after it.
        String expectedPlan = "(?s)(StreamAgg|HashAgg).*Filter";
        String excludedPlan = "(?s)Filter.*(StreamAgg|HashAgg)";

        for (String query : new String[] {whereQuery, havingQuery}) {
            queryBuilder()
                .sql(query)
                .planMatcher()
                .include(expectedPlan)
                .exclude(excludedPlan)
                .match();

            testBuilder()
                .sqlQuery(query)
                .unOrdered()
                .baselineColumns("cnt")
                .baselineValues(5L)
                .build()
                .run();
        }
    }

    /**
     * Checks that a filter whose expression only involves a grouping key is pushed below the
     * aggregation, and that the query still returns the expected count.
     */
    @Test
    public void testPushFilterInExprPastAgg() throws Exception {
        String query = " select cnt "
            + " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) "
            + " where n_regionkey + 100 - 100 = 2 ";

        queryBuilder()
            .sql(query)
            .planMatcher()
            .include("(?s)(StreamAgg|HashAgg).*Filter")
            .exclude("(?s)Filter.*(StreamAgg|HashAgg)")
            .match();

        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("cnt")
            .baselineValues(5L)
            .build()
            .run();
    }

    /**
     * Negative case: a filter that references the aggregate output {@code cnt} must not be
     * pushed below the aggregation. Checked for a filter on {@code cnt} alone and for one
     * mixing {@code cnt} with the grouping key.
     */
    @Test
    public void testNegPushFilterInExprPastAgg() throws Exception {
        String groupedSubQuery = " select cnt "
            + " from (select n_regionkey, count(*) cnt from cp.`tpch/nation.parquet` group by n_regionkey) ";
        String[] queries = {
            groupedSubQuery + " where cnt + 100 - 100 = 5 ",
            groupedSubQuery + " where cnt + n_regionkey = 5 "
        };

        for (String query : queries) {
            queryBuilder()
                .sql(query)
                .planMatcher()
                .include("(?s)Filter(?!StreamAgg|!HashAgg)")
                .exclude("(?s)(StreamAgg|HashAgg).*Filter")
                .match();
        }
    }

    /** Checks that grouping {@code sys.version} by CURRENT_DATE plans with an aggregation operator. */
    @Test
    @Category({UnlikelyTest.class})
    public void testGroupBySystemFuncSchemaTable() throws Exception {
        String query = "select count(*) as cnt from sys.version group by CURRENT_DATE";
        queryBuilder()
            .sql(query)
            .planMatcher()
            .include("(?s)(StreamAgg|HashAgg)")
            .match();
    }

    /**
     * Checks COUNT(*) grouped by CURRENT_DATE over a text, a parquet and a JSON table; each
     * query must return the table's total row count.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testGroupBySystemFuncFileSystemTable() throws Exception {
        String[] queries = {
            "select count(*) as cnt from cp.`nation/nation.tbl` group by CURRENT_DATE",
            "select count(*) as cnt from cp.`tpch/nation.parquet` group by CURRENT_DATE",
            "select count(*) as cnt from cp.`employee.json` group by CURRENT_DATE"
        };
        long[] expectedCounts = {25L, 25L, 1155L};

        for (int i = 0; i < queries.length; i++) {
            testBuilder()
                .sqlQuery(queries[i])
                .unOrdered()
                .baselineColumns("cnt")
                .baselineValues(expectedCounts[i])
                .build()
                .run();
        }
    }

    /**
     * Smoke test: runs {@code MIN(columns[1]) ... GROUP BY columns[0]} over
     * {@code agg/4443.csv}. No result is asserted; the test passes if the query
     * completes without throwing.
     */
    @Test
    public void test4443() throws Exception {
        final String query = "SELECT MIN(columns[1]) FROM cp.`agg/4443.csv` GROUP BY columns[0]";
        // No format arguments are supplied.
        run(query);
    }

    /**
     * Verifies {@code count(*)} over {@code tpch/region.parquet} in two steps:
     * first that the output column {@code col} has type BIGINT with mode REQUIRED,
     * then that its value is 5.
     */
    @Test
    public void testCountStarRequired() throws Exception {
        final String query = "select count(*) as col from cp.`tpch/region.parquet`";

        // Typed list (instead of a raw ArrayList) so the schemaBaseLine call is checked by the compiler.
        List<Pair<SchemaPath, TypeProtos.MajorType>> expectedSchema = new ArrayList<>();
        TypeProtos.MajorType requiredBigInt = TypeProtos.MajorType.newBuilder()
            .setMinorType(TypeProtos.MinorType.BIGINT)
            .setMode(TypeProtos.DataMode.REQUIRED)
            .build();
        expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), requiredBigInt));

        // Schema check.
        testBuilder()
            .sqlQuery(query)
            .schemaBaseLine(expectedSchema)
            .build()
            .run();

        // Value check.
        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns("col")
            .baselineValues(5L)
            .build()
            .run();
    }

    /**
     * Plan-only check for a LEFT JOIN between a customer sub-query and a DISTINCT
     * lineitem sub-query, with an outer {@code provider IS NOT NULL} filter on the
     * right side. The plan text must match {@code (Join).*inner} and must not match
     * {@code (Join).*(left)}.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testPushFilterDown() throws Exception {
        final String query =
            "SELECT  cust.custAddress, \n"
            + "       lineitem.provider \n"
            + "FROM ( \n"
            + "      SELECT cast(c_custkey AS bigint) AS custkey, \n"
            + "             c_address                 AS custAddress \n"
            + "      FROM   cp.`tpch/customer.parquet` ) cust \n"
            + "LEFT JOIN \n"
            + "  ( \n"
            + "    SELECT DISTINCT l_linenumber, \n"
            + "           CASE \n"
            + "             WHEN l_partkey IN (1, 2) THEN 'Store1'\n"
            + "             WHEN l_partkey IN (5, 6) THEN 'Store2'\n"
            + "           END AS provider \n"
            + "    FROM  cp.`tpch/lineitem.parquet` \n"
            + "    WHERE ( l_orderkey >=20160101 AND l_partkey <=20160301) \n"
            + "      AND   l_partkey IN (1,2, 5, 6) ) lineitem\n"
            + "ON        cust.custkey = lineitem.l_linenumber \n"
            + "WHERE     provider IS NOT NULL \n"
            + "GROUP BY  cust.custAddress, \n"
            + "          lineitem.provider \n"
            + "ORDER BY  cust.custAddress, \n"
            + "          lineitem.provider";
        final String expectedPlanPattern = "(?s)(Join).*inner";
        final String excludedPlanPattern = "(?s)(Join).*(left)";

        queryBuilder()
            .sql(query)
            .planMatcher()
            .include(expectedPlanPattern)
            .exclude(excludedPlanPattern)
            .match();
    }

    /**
     * For each (result alias, field) pair below, runs {@code count(t.<field>)} over
     * {@code complex/json/complex.json} and expects the value 3.
     * The {@code COUNT_LIST_OPTIONAL} case additionally turns on
     * {@code exec.enable_union_type} for its query; the option is reset after every
     * iteration, whether the query succeeded or not.
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testCountComplexObjects() throws Exception {
        // Result column alias -> field of complex.json that is counted.
        Map<String, String> fieldByAlias = new HashMap<>();
        fieldByAlias.put("COUNT_BIG_INT_REPEATED", "sia");
        fieldByAlias.put("COUNT_FLOAT_REPEATED", "sfa");
        fieldByAlias.put("COUNT_MAP_REPEATED", "soa");
        fieldByAlias.put("COUNT_MAP_REQUIRED", "oooi");
        fieldByAlias.put("COUNT_LIST_REPEATED", "odd");
        fieldByAlias.put("COUNT_LIST_OPTIONAL", "sia");

        for (Map.Entry<String, String> entry : fieldByAlias.entrySet()) {
            String alias = entry.getKey();
            String field = entry.getValue();
            // Only the COUNT_LIST_OPTIONAL case issues an option-setting query.
            String optionSetting = "COUNT_LIST_OPTIONAL".equals(alias)
                ? "alter session set `exec.enable_union_type`=true"
                : "";
            try {
                testBuilder()
                    .sqlQuery("select count(t.%s) %s from cp.`complex/json/complex.json` t", field, alias)
                    .optionSettingQueriesForTestQuery(optionSetting)
                    .unOrdered()
                    .baselineColumns(alias)
                    .baselineValues(3L)
                    .go();
            } finally {
                // Single cleanup point: runs exactly once per iteration.
                client.resetSession("exec.enable_union_type");
            }
        }
    }

    /**
     * Writes a one-record JSON file whose field names contain dots into the test
     * root directory, then counts several dotted / quoted / bracketed references to
     * those fields and checks the expected counts (1, 1, 1, 0, 1).
     */
    @Test
    @Category({UnlikelyTest.class})
    public void testCountOnFieldWithDots() throws Exception {
        final String fileName = "table.json";
        File tableFile = new File(dirTestWatcher.getRootDir(), fileName);

        // try-with-resources closes (and thereby flushes) the writer before the query reads the file.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(tableFile))) {
            writer.write("{\"rk.q\": \"a\", \"m\": {\"a.b\":\"1\", \"a\":{\"b\":\"2\"}, \"c\":\"3\"}}");
        }

        testBuilder()
            .sqlQuery("select count(t.m.`a.b`) as a,\n"
                + "count(t.m.a.b) as b,\n"
                + "count(t.m['a.b']) as c,\n"
                + "count(t.rk.q) as d,\n"
                + "count(t.`rk.q`) as e\n"
                + "from dfs.`%s` t", fileName)
            .unOrdered()
            .baselineColumns("a", "b", "c", "d", "e")
            .baselineValues(1L, 1L, 1L, 0L, 1L)
            .go();
    }

    /**
     * Expects {@code select * ... group by n_regionkey} to fail with a
     * {@link UserRemoteException} whose message matches the "is not being grouped"
     * pattern below. The test fails if no exception is thrown.
     */
    @Test
    public void testGroupByWithoutAggregate() throws Exception {
        final String query = "select * from cp.`tpch/nation.parquet` group by n_regionkey";
        final String expectedMessagePattern = ".*Expression 'tpch/nation\\.parquet\\.\\*\\*' is not being grouped(.*\\n*.*)";
        final String assertionMessage = "No expected current \"Expression 'tpch/nation.parquet.**' is not being grouped\"";

        try {
            run(query);
            Assert.fail("Exception was not thrown");
        } catch (UserRemoteException e) {
            boolean messageMatches = e.getMessage().matches(expectedMessagePattern);
            Assert.assertTrue(assertionMessage, messageMatches);
        }
    }

    /**
     * Runs {@code collect_list} over the first two rows of {@code tpch/nation.parquet}
     * (no GROUP BY) with the {@code PlannerSettings.HASHAGG} session option set to
     * {@code false}, and expects a single list holding one map per row.
     * The option is reset in {@code finally}.
     */
    @Test
    public void testCollectListStreamAgg() throws Exception {
        final String hashAggOption = PlannerSettings.HASHAGG.getOptionName();
        try {
            client.alterSession(hashAggOption, false);
            testBuilder()
                .sqlQuery("select collect_list('n_nationkey', n_nationkey, 'n_name', n_name, 'n_regionkey', n_regionkey, 'n_comment', n_comment) as l from (select * from cp.`tpch/nation.parquet` limit 2)")
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(TestBuilder.listOf(
                    TestBuilder.mapOf("n_nationkey", 0, "n_name", "ALGERIA", "n_regionkey", 0, "n_comment", " haggle. carefully final deposits detect slyly agai"),
                    TestBuilder.mapOf("n_nationkey", 1, "n_name", "ARGENTINA", "n_regionkey", 1, "n_comment", "al foxes promise slyly according to the regular accounts. bold requests alon")))
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(hashAggOption);
        }
    }

    /**
     * Runs {@code collect_list} over the first two rows of {@code tpch/nation.parquet}
     * grouped by the constant {@code 'a'}, with the {@code PlannerSettings.STREAMAGG}
     * session option set to {@code false}, and expects a single list holding one map
     * per row. The option is reset in {@code finally}.
     */
    @Test
    public void testCollectListHashAgg() throws Exception {
        final String streamAggOption = PlannerSettings.STREAMAGG.getOptionName();
        try {
            client.alterSession(streamAggOption, false);
            testBuilder()
                .sqlQuery("select collect_list('n_nationkey', n_nationkey, 'n_name', n_name, 'n_regionkey', n_regionkey, 'n_comment', n_comment) as l from (select * from cp.`tpch/nation.parquet` limit 2) group by 'a'")
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(TestBuilder.listOf(
                    TestBuilder.mapOf("n_nationkey", 0, "n_name", "ALGERIA", "n_regionkey", 0, "n_comment", " haggle. carefully final deposits detect slyly agai"),
                    TestBuilder.mapOf("n_nationkey", 1, "n_name", "ARGENTINA", "n_regionkey", 1, "n_comment", "al foxes promise slyly according to the regular accounts. bold requests alon")))
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(streamAggOption);
        }
    }

    /**
     * Runs {@code collect_to_list_varchar(`date`)} over the first two rows of
     * {@code store/json/clicks.json} (no GROUP BY) with the
     * {@code PlannerSettings.HASHAGG} session option set to {@code false}, and expects
     * the list {@code ["2014-04-26", "2014-04-20"]}. The option is reset in
     * {@code finally}.
     */
    @Test
    public void testCollectToListVarcharStreamAgg() throws Exception {
        final String hashAggOption = PlannerSettings.HASHAGG.getOptionName();
        try {
            client.alterSession(hashAggOption, false);
            testBuilder()
                .sqlQuery("select collect_to_list_varchar(`date`) as l from (select * from cp.`store/json/clicks.json` limit 2)")
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(TestBuilder.listOf("2014-04-26", "2014-04-20"))
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(hashAggOption);
        }
    }

    /**
     * Runs {@code collect_to_list_varchar(`date`)} over the first two rows of
     * {@code store/json/clicks.json} grouped by the constant {@code 'a'}, with the
     * {@code PlannerSettings.STREAMAGG} session option set to {@code false}, and
     * expects the list {@code ["2014-04-26", "2014-04-20"]}. The option is reset in
     * {@code finally}.
     */
    @Test
    public void testCollectToListVarcharHashAgg() throws Exception {
        final String streamAggOption = PlannerSettings.STREAMAGG.getOptionName();
        try {
            client.alterSession(streamAggOption, false);
            testBuilder()
                .sqlQuery("select collect_to_list_varchar(`date`) as l from (select * from cp.`store/json/clicks.json` limit 2) group by 'a'")
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(TestBuilder.listOf("2014-04-26", "2014-04-20"))
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(streamAggOption);
        }
    }

    /**
     * Runs the {@code schema(...)} aggregate over the first two rows of
     * {@code tpch/nation.parquet} (no GROUP BY) with the
     * {@code PlannerSettings.HASHAGG} session option set to {@code false}, and expects
     * the JSON string of a schema with columns n_nationkey (INT), n_name (VARCHAR),
     * n_regionkey (INT) and n_comment (VARCHAR). The option is reset in
     * {@code finally}.
     */
    @Test
    public void testSchemaFunctionStreamAgg() throws Exception {
        final String hashAggOption = PlannerSettings.HASHAGG.getOptionName();
        try {
            client.alterSession(hashAggOption, false);
            String expectedSchemaJson = new SchemaBuilder()
                .add("n_nationkey", TypeProtos.MinorType.INT)
                .add("n_name", TypeProtos.MinorType.VARCHAR)
                .add("n_regionkey", TypeProtos.MinorType.INT)
                .add("n_comment", TypeProtos.MinorType.VARCHAR)
                .build()
                .jsonString();
            testBuilder()
                .sqlQuery("select schema('n_nationkey', n_nationkey, 'n_name', n_name, 'n_regionkey', n_regionkey, 'n_comment', n_comment) as l from (select * from cp.`tpch/nation.parquet` limit 2)")
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(expectedSchemaJson)
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(hashAggOption);
        }
    }

    /**
     * Runs the {@code schema(...)} aggregate over the first two rows of
     * {@code tpch/nation.parquet} grouped by the constant {@code 'a'}, with the
     * {@code PlannerSettings.STREAMAGG} session option set to {@code false}, and
     * expects the JSON string of a schema with columns n_nationkey (INT),
     * n_name (VARCHAR), n_regionkey (INT) and n_comment (VARCHAR). The option is
     * reset in {@code finally}.
     */
    @Test
    public void testSchemaFunctionHashAgg() throws Exception {
        final String streamAggOption = PlannerSettings.STREAMAGG.getOptionName();
        try {
            client.alterSession(streamAggOption, false);
            String expectedSchemaJson = new SchemaBuilder()
                .add("n_nationkey", TypeProtos.MinorType.INT)
                .add("n_name", TypeProtos.MinorType.VARCHAR)
                .add("n_regionkey", TypeProtos.MinorType.INT)
                .add("n_comment", TypeProtos.MinorType.VARCHAR)
                .build()
                .jsonString();
            testBuilder()
                .sqlQuery("select schema('n_nationkey', n_nationkey, 'n_name', n_name, 'n_regionkey', n_regionkey, 'n_comment', n_comment) as l from (select * from cp.`tpch/nation.parquet` limit 2) group by 'a'")
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(expectedSchemaJson)
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(streamAggOption);
        }
    }

    /**
     * Passes a schema JSON string as a literal to {@code merge_schema} over the first
     * two rows of {@code tpch/nation.parquet} (no GROUP BY), with the
     * {@code PlannerSettings.HASHAGG} session option set to {@code false}, and expects
     * the same JSON string back. The option is reset in {@code finally}.
     */
    @Test
    public void testMergeSchemaFunctionStreamAgg() throws Exception {
        final String hashAggOption = PlannerSettings.HASHAGG.getOptionName();
        try {
            client.alterSession(hashAggOption, false);
            // The same string is used both as the query argument and as the expected result.
            String schemaJson = new SchemaBuilder()
                .add("n_nationkey", TypeProtos.MinorType.INT)
                .add("n_name", TypeProtos.MinorType.VARCHAR)
                .add("n_regionkey", TypeProtos.MinorType.INT)
                .add("n_comment", TypeProtos.MinorType.VARCHAR)
                .build()
                .jsonString();
            testBuilder()
                .sqlQuery("select merge_schema('%s') as l from (select * from cp.`tpch/nation.parquet` limit 2)", schemaJson)
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(schemaJson)
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(hashAggOption);
        }
    }

    /**
     * Passes a schema JSON string as a literal to {@code merge_schema} over the first
     * two rows of {@code tpch/nation.parquet} grouped by the constant {@code 'a'},
     * with the {@code PlannerSettings.STREAMAGG} session option set to {@code false},
     * and expects the same JSON string back. The option is reset in {@code finally}.
     */
    @Test
    public void testMergeSchemaFunctionHashAgg() throws Exception {
        final String streamAggOption = PlannerSettings.STREAMAGG.getOptionName();
        try {
            client.alterSession(streamAggOption, false);
            // The same string is used both as the query argument and as the expected result.
            String schemaJson = new SchemaBuilder()
                .add("n_nationkey", TypeProtos.MinorType.INT)
                .add("n_name", TypeProtos.MinorType.VARCHAR)
                .add("n_regionkey", TypeProtos.MinorType.INT)
                .add("n_comment", TypeProtos.MinorType.VARCHAR)
                .build()
                .jsonString();
            testBuilder()
                .sqlQuery("select merge_schema('%s') as l from (select * from cp.`tpch/nation.parquet` limit 2) group by 'a'", schemaJson)
                .unOrdered()
                .baselineColumns("l")
                .baselineValues(schemaJson)
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(streamAggOption);
        }
    }

    /**
     * Runs {@code tdigest(p.col_int)} grouped by {@code p.col_flt} over
     * {@code parquet/alltypes_required.parquet} with the
     * {@code PlannerSettings.STREAMAGG} session option set to {@code false}, and
     * expects 4 result records. The option is reset in {@code finally}.
     */
    @Test
    public void testInjectVariablesHashAgg() throws Exception {
        final String streamAggOption = PlannerSettings.STREAMAGG.getOptionName();
        try {
            client.alterSession(streamAggOption, false);
            testBuilder()
                .sqlQuery("select tdigest(p.col_int) from cp.`parquet/alltypes_required.parquet` p group by p.col_flt")
                .unOrdered()
                .expectsNumRecords(4)
                .go();
        } finally {
            // Single cleanup point, executed on both success and failure.
            client.resetSession(streamAggOption);
        }
    }

    /**
     * Computes {@code stddev(col2)} and {@code SUM(col2)} per {@code col1} over an
     * inline VALUES list and compares against the expected rows for "UA" and "USA",
     * using approximate equality with tolerance 1e-6.
     */
    @Test
    public void testRowTypeMissMatch() throws Exception {
        final String query = "select col1, stddev(col2) as g1, SUM(col2) as g2 FROM (values ('UA', 3), ('USA', 2), ('UA', 3), ('USA', 5), ('USA', 1), ('UA', 9)) t(col1, col2) GROUP BY col1 order by col1";
        final double tolerance = 1.0E-6d;
        final double expectedStddevUa = 3.4641016151377544d;
        final double expectedStddevUsa = 2.0816659994661326d;

        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .approximateEquality(tolerance)
            .baselineColumns("col1", "g1", "g2")
            .baselineValues("UA", expectedStddevUa, 15L)
            .baselineValues("USA", expectedStddevUsa, 8L)
            .go();
    }

    /**
     * Runs five {@code count(n_name) FILTER(WHERE n_regionkey = k)} aggregates
     * (k = 1, 2, 3, 4, 0) over {@code tpch/nation.parquet} in a single query and
     * expects each of them to return 5.
     */
    @Test
    public void testAggregateWithFilterCall() throws Exception {
        final String query =
            "SELECT count(n_name) FILTER(WHERE n_regionkey = 1) AS nations_count_in_1_region,"
            + "count(n_name) FILTER(WHERE n_regionkey = 2) AS nations_count_in_2_region,"
            + "count(n_name) FILTER(WHERE n_regionkey = 3) AS nations_count_in_3_region,"
            + "count(n_name) FILTER(WHERE n_regionkey = 4) AS nations_count_in_4_region,"
            + "count(n_name) FILTER(WHERE n_regionkey = 0) AS nations_count_in_0_region\n"
            + "FROM cp.`tpch/nation.parquet`";

        testBuilder()
            .sqlQuery(query)
            .unOrdered()
            .baselineColumns(
                "nations_count_in_1_region",
                "nations_count_in_2_region",
                "nations_count_in_3_region",
                "nations_count_in_4_region",
                "nations_count_in_0_region")
            .baselineValues(5L, 5L, 5L, 5L, 5L)
            .go();
    }
}
