package org.apache.druid.sql.calcite;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.AllGranularity;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.junit.Test;
import org.junit.runner.RunWith;

/**
 * Planning tests for correlated subqueries: a scalar subquery inside {@code ANY_VALUE(...)} whose
 * {@code where} clause references a column of the outer {@code visits} row
 * ({@code f.country = visits.country}).
 *
 * <p>Each test asserts the expected native {@link Query} plan (a LEFT join of the outer rows with a
 * two-level group-by over {@link CalciteTests#USERVISITDATASOURCE}) and the returned rows. Each test
 * runs once per context supplied by {@link BaseCalciteQueryTest.QueryContextForJoinProvider}.
 */
@RunWith(JUnitParamsRunner.class)
public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest {

    /**
     * Shared by the two "left filter" tests: the max daily row count per country, with the outer
     * rows restricted to {@code city = 'B'} and a time range.
     */
    private static final String LEFT_FILTER_SQL =
            "select country, ANY_VALUE(\n"
            + "        select max(\"users\") from (\n"
            + "            select floor(__time to day), count(*) \"users\" from visits f where f.country = visits.country group by 1\n"
            + "        )\n"
            + "     ) as \"dailyVisits\"\n"
            + "from visits \n"
            + " where city = 'B' and __time between '2021-01-01 01:00:00' AND '2021-01-02 23:59:59' group by 1";

    /**
     * Shared by the two "correlated query filter" tests: the subquery additionally filters on
     * {@code f.city = 'A'}, and the outer rows are restricted to {@code city = 'B'}.
     */
    private static final String CORRELATED_QUERY_FILTER_SQL =
            "select country, ANY_VALUE(\n"
            + "        select max(\"users\") from (\n"
            + "            select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n"
            + "        )\n"
            + "     ) as \"dailyVisits\"\n"
            + "from visits \n"
            + " where city = 'B' group by 1";

    /**
     * Per country, averages the daily distinct-user count computed by the correlated subquery.
     */
    @Test
    @Parameters(source = BaseCalciteQueryTest.QueryContextForJoinProvider.class)
    public void testCorrelatedSubquery(Map<String, Object> map) throws Exception {
        cannotVectorize();
        Map<String, Object> queryContext = withLeftDirectAccessEnabled(map);

        // Innermost query: distinct users per (day, country), skipping rows whose country is null.
        GroupByQuery dailyDistinctUsers = dailyQueryBuilder()
                .setDimFilter(not(selector("country", null, null)))
                .setDimensions(new DimensionSpec[]{
                        new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
                        new DefaultDimensionSpec("country", "d1")
                })
                .setAggregatorSpecs(new AggregatorFactory[]{
                        new CardinalityAggregatorFactory(
                                "a0:a",
                                null,
                                Collections.singletonList(new DefaultDimensionSpec("user", "user")),
                                false,
                                true)
                })
                .setPostAggregatorSpecs(Collections.singletonList(new HyperUniqueFinalizingPostAggregator("a0", "a0:a")))
                .setContext(withTimestampResultContext(queryContext, "d0", Granularities.DAY))
                .setGranularity(new AllGranularity())
                .build();

        // Denominator of the average. When useDefault is false, only rows with a non-null a0 are counted.
        AggregatorFactory dayCount = useDefault
                ? new CountAggregatorFactory("_a0:count")
                : new FilteredAggregatorFactory(new CountAggregatorFactory("_a0:count"), not(selector("a0", null, null)));

        // Middle query: per-country average computed as sum / count via a quotient post-aggregator.
        GroupByQuery averagePerCountry = GroupByQuery.builder()
                .setDataSource(dailyDistinctUsers)
                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
                .setDimensions(new DimensionSpec[]{new DefaultDimensionSpec("d1", "_d0")})
                .setAggregatorSpecs(new AggregatorFactory[]{
                        new LongSumAggregatorFactory("_a0:sum", "a0"),
                        dayCount
                })
                .setPostAggregatorSpecs(Collections.singletonList(
                        new ArithmeticPostAggregator("_a0", "quotient", Arrays.asList(
                                new FieldAccessPostAggregator(null, "_a0:sum"),
                                new FieldAccessPostAggregator(null, "_a0:count")))))
                .setGranularity(new AllGranularity())
                .setContext(queryContext)
                .build();

        GroupByQuery expected = GroupByQuery.builder()
                .setDataSource(join(
                        new TableDataSource(CalciteTests.USERVISITDATASOURCE),
                        new QueryDataSource(averagePerCountry),
                        "j0.",
                        equalsCondition(makeColumnExpression("country"), makeColumnExpression("j0._d0")),
                        JoinType.LEFT))
                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
                .setDimensions(new DimensionSpec[]{new DefaultDimensionSpec("country", "d0")})
                .setAggregatorSpecs(new AggregatorFactory[]{new LongAnyAggregatorFactory("a0", "j0._a0")})
                .setGranularity(new AllGranularity())
                .setContext(queryContext)
                .build();

        List<Object[]> expectedRows = ImmutableList.of(
                new Object[]{"India", 2L},
                new Object[]{"USA", 1L},
                new Object[]{"canada", 3L});

        testQuery(
                "select country, ANY_VALUE(\n"
                + "        select avg(\"users\") from (\n"
                + "            select floor(__time to day), count(distinct user) \"users\" from visits f where f.country = visits.country group by 1\n"
                + "        )\n"
                + "     ) as \"DAU\"\n"
                + "from visits \n"
                + "group by 1",
                queryContext,
                ImmutableList.of(expected),
                expectedRows);
    }

    /**
     * Max daily row count per country, with a city/time filter on the outer rows and left direct
     * access enabled.
     */
    @Test
    @Parameters(source = BaseCalciteQueryTest.QueryContextForJoinProvider.class)
    public void testCorrelatedSubqueryWithLeftFilter(Map<String, Object> map) throws Exception {
        cannotVectorize();
        Map<String, Object> queryContext = withLeftDirectAccessEnabled(map);

        GroupByQuery dailyCounts = dailyQueryBuilder()
                .setDimFilter(not(selector("country", null, null)))
                .setDimensions(new DimensionSpec[]{
                        new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
                        new DefaultDimensionSpec("country", "d1")
                })
                .setAggregatorSpecs(new AggregatorFactory[]{new CountAggregatorFactory("a0")})
                .setContext(withTimestampResultContext(queryContext, "d0", Granularities.DAY))
                .setGranularity(new AllGranularity())
                .build();

        // The outer table is joined directly. The city = 'B' filter is passed as the extra join(...)
        // argument, and the time bound becomes the outer query's interval.
        GroupByQuery expected = GroupByQuery.builder()
                .setDataSource(join(
                        new TableDataSource(CalciteTests.USERVISITDATASOURCE),
                        new QueryDataSource(maxDailyCountPerCountry(dailyCounts, queryContext)),
                        "j0.",
                        equalsCondition(makeColumnExpression("country"), makeColumnExpression("j0._d0")),
                        JoinType.LEFT,
                        selector("city", "B", null)))
                .setQuerySegmentSpec(querySegmentSpec(Intervals.of("2021-01-01T01:00:00.000Z/2021-01-02T23:59:59.001Z")))
                .setDimensions(new DimensionSpec[]{new DefaultDimensionSpec("country", "d0")})
                .setAggregatorSpecs(new AggregatorFactory[]{new LongAnyAggregatorFactory("a0", "j0._a0")})
                .setGranularity(new AllGranularity())
                .setContext(queryContext)
                .build();

        List<Object[]> expectedRows = ImmutableList.of(new Object[]{"canada", 4L});
        testQuery(LEFT_FILTER_SQL, queryContext, ImmutableList.of(expected), expectedRows);
    }

    /**
     * Same SQL as {@link #testCorrelatedSubqueryWithLeftFilter}, but with the provider's context used
     * as-is, i.e. without {@code withLeftDirectAccessEnabled}.
     */
    @Test
    @Parameters(source = BaseCalciteQueryTest.QueryContextForJoinProvider.class)
    public void testCorrelatedSubqueryWithLeftFilter_leftDirectAccessDisabled(Map<String, Object> map) throws Exception {
        cannotVectorize();

        GroupByQuery dailyCounts = dailyQueryBuilder()
                .setDimFilter(not(selector("country", null, null)))
                .setDimensions(new DimensionSpec[]{
                        new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
                        new DefaultDimensionSpec("country", "d1")
                })
                .setAggregatorSpecs(new AggregatorFactory[]{new CountAggregatorFactory("a0")})
                .setContext(withTimestampResultContext(map, "d0", Granularities.DAY))
                .setGranularity(new AllGranularity())
                .build();

        // Here the outer rows come from a scan subquery that applies the interval and the city filter,
        // so the outer group-by runs over all time.
        GroupByQuery expected = GroupByQuery.builder()
                .setDataSource(join(
                        new QueryDataSource(
                                newScanQueryBuilder()
                                        .dataSource(CalciteTests.USERVISITDATASOURCE)
                                        .intervals(querySegmentSpec(Intervals.of("2021-01-01T01:00:00.000Z/2021-01-02T23:59:59.001Z")))
                                        .filters(selector("city", "B", null))
                                        .columns(new String[]{"__time", "city", "country"})
                                        .build()),
                        new QueryDataSource(maxDailyCountPerCountry(dailyCounts, map)),
                        "j0.",
                        equalsCondition(makeColumnExpression("country"), makeColumnExpression("j0._d0")),
                        JoinType.LEFT))
                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
                .setDimensions(new DimensionSpec[]{new DefaultDimensionSpec("country", "d0")})
                .setAggregatorSpecs(new AggregatorFactory[]{new LongAnyAggregatorFactory("a0", "j0._a0")})
                .setGranularity(new AllGranularity())
                .setContext(map)
                .build();

        List<Object[]> expectedRows = ImmutableList.of(new Object[]{"canada", 4L});
        testQuery(LEFT_FILTER_SQL, map, ImmutableList.of(expected), expectedRows);
    }

    /**
     * Correlated subquery with its own extra filter ({@code f.city = 'A'}) in addition to the
     * correlation condition.
     */
    @Test
    @Parameters(source = BaseCalciteQueryTest.QueryContextForJoinProvider.class)
    public void testCorrelatedSubqueryWithCorrelatedQueryFilter(Map<String, Object> map) throws Exception {
        runCorrelatedQueryFilterTest(map);
    }

    /**
     * Currently identical to {@link #testCorrelatedSubqueryWithCorrelatedQueryFilter}: same SQL,
     * context, expected plan, and expected rows.
     *
     * <p>NOTE(review): the "_Scan" suffix suggests a different (scan-based) plan was meant to be
     * covered here. TODO confirm the intended variant; as written, this test adds no coverage.
     */
    @Test
    @Parameters(source = BaseCalciteQueryTest.QueryContextForJoinProvider.class)
    public void testCorrelatedSubqueryWithCorrelatedQueryFilter_Scan(Map<String, Object> map) throws Exception {
        runCorrelatedQueryFilterTest(map);
    }

    /**
     * Runs the correlated-query-filter scenario: plans {@link #CORRELATED_QUERY_FILTER_SQL} with left
     * direct access enabled and asserts the expected plan and the single result row.
     *
     * @param map join query context from the parameter provider
     */
    private void runCorrelatedQueryFilterTest(Map<String, Object> map) throws Exception {
        cannotVectorize();
        Map<String, Object> queryContext = withLeftDirectAccessEnabled(map);

        // count(user) is planned as a count filtered to non-null user. The subquery's city = 'A'
        // predicate is combined with the non-null-country filter.
        GroupByQuery dailyUserCounts = dailyQueryBuilder()
                .setDimensions(new DimensionSpec[]{
                        new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
                        new DefaultDimensionSpec("country", "d1")
                })
                .setAggregatorSpecs(new AggregatorFactory[]{
                        new FilteredAggregatorFactory(new CountAggregatorFactory("a0"), not(selector("user", null, null)))
                })
                .setDimFilter(and(selector("city", "A", null), not(selector("country", null, null))))
                .setContext(withTimestampResultContext(queryContext, "d0", Granularities.DAY))
                .setGranularity(new AllGranularity())
                .build();

        GroupByQuery expected = GroupByQuery.builder()
                .setDataSource(join(
                        new TableDataSource(CalciteTests.USERVISITDATASOURCE),
                        new QueryDataSource(maxDailyCountPerCountry(dailyUserCounts, queryContext)),
                        "j0.",
                        equalsCondition(makeColumnExpression("country"), makeColumnExpression("j0._d0")),
                        JoinType.LEFT,
                        selector("city", "B", null)))
                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
                .setDimensions(new DimensionSpec[]{new DefaultDimensionSpec("country", "d0")})
                .setAggregatorSpecs(new AggregatorFactory[]{new LongAnyAggregatorFactory("a0", "j0._a0")})
                .setGranularity(new AllGranularity())
                .setContext(queryContext)
                .build();

        List<Object[]> expectedRows = ImmutableList.of(new Object[]{"canada", 2L});
        testQuery(CORRELATED_QUERY_FILTER_SQL, queryContext, ImmutableList.of(expected), expectedRows);
    }

    /**
     * Starts the innermost daily group-by shared by every test. It reads
     * {@link CalciteTests#USERVISITDATASOURCE} over {@code Intervals.ETERNITY} and defines virtual
     * column {@code v0} as {@code __time} floored to one day in UTC. Callers add dimensions,
     * aggregators, filter, context, and granularity.
     */
    private GroupByQuery.Builder dailyQueryBuilder() {
        return GroupByQuery.builder()
                .setDataSource(CalciteTests.USERVISITDATASOURCE)
                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
                .setVirtualColumns(new VirtualColumn[]{
                        new ExpressionVirtualColumn(
                                "v0",
                                "timestamp_floor(\"__time\",'P1D',null,'UTC')",
                                ColumnType.LONG,
                                TestExprMacroTable.INSTANCE)
                });
    }

    /**
     * Builds the middle group-by used by the "max" tests. Per country ({@code d1}, output as
     * {@code _d0}), it keeps the maximum of the daily {@code a0} value as {@code _a0}.
     *
     * @param dailyCounts innermost per-(day, country) query to read from
     * @param context     query context to attach
     * @return the built query
     */
    private GroupByQuery maxDailyCountPerCountry(GroupByQuery dailyCounts, Map<String, Object> context) {
        return GroupByQuery.builder()
                .setDataSource(dailyCounts)
                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
                .setDimensions(new DimensionSpec[]{new DefaultDimensionSpec("d1", "_d0")})
                .setAggregatorSpecs(new AggregatorFactory[]{new LongMaxAggregatorFactory("_a0", "a0")})
                .setGranularity(new AllGranularity())
                .setContext(context)
                .build();
    }

    /**
     * Returns a copy of {@code map} with the three timestamp-result keys set. The input map is not
     * modified.
     *
     * @param map         base query context to copy
     * @param str         value for {@code timestampResultField}; callers pass the output name of the
     *                    day-floored dimension ({@code "d0"})
     * @param granularity value for {@code timestampResultFieldGranularity}
     * @return a new mutable map holding the base entries plus the three timestamp-result keys
     */
    private Map<String, Object> withTimestampResultContext(Map<String, Object> map, String str, Granularity granularity) {
        Map<String, Object> context = new HashMap<>(map);
        context.put("timestampResultField", str);
        context.put("timestampResultFieldGranularity", granularity);
        // 0 matches the position of "d0" in the callers' dimension lists. Presumably that is what
        // this key encodes; TODO confirm.
        context.put("timestampResultFieldInOriginalDimensions", 0);
        return context;
    }
}
