/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapJoinRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.relnode.KapRel;
import org.apache.kylin.query.util.RuleUtils;

public class KapCountDistinctJoinRule
extends RelOptRule {
    public static final KapCountDistinctJoinRule INSTANCE_COUNT_DISTINCT_JOIN_ONESIDEAGG = new KapCountDistinctJoinRule(KapCountDistinctJoinRule.operand(KapAggregateRel.class, (RelOptRuleOperand)KapCountDistinctJoinRule.operand(KapJoinRel.class, (RelOptRuleOperandChildren)KapCountDistinctJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapCountDistinctJoinRule:agg(contain-count-distinct)-join-oneSideAgg");
    public static final KapCountDistinctJoinRule INSTANCE_COUNT_DISTINCT_AGG_PROJECT_JOIN = new KapCountDistinctJoinRule(KapCountDistinctJoinRule.operand(KapAggregateRel.class, (RelOptRuleOperand)KapCountDistinctJoinRule.operand(KapProjectRel.class, (RelOptRuleOperand)KapCountDistinctJoinRule.operand(KapJoinRel.class, (RelOptRuleOperandChildren)KapCountDistinctJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapCountDistinctJoinRule:agg(contain-count-distinct)-agg-project-join");

    public KapCountDistinctJoinRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    public boolean matches(RelOptRuleCall call) {
        KapAggregateRel aggregate = (KapAggregateRel)call.rel(0);
        KapJoinRel join = call.rel(1) instanceof KapJoinRel ? (KapJoinRel)call.rel(1) : (KapJoinRel)call.rel(2);
        return aggregate.isContainCountDistinct() && RuleUtils.isJoinOnlyOneAggChild(join);
    }

    public void onMatch(RelOptRuleCall call) {
        KapAggregateRel aggregate = (KapAggregateRel)call.rel(0);
        KapRel inputKapRel = (KapRel)call.rel(1);
        ImmutableList.Builder bottomAggCallsBuilder = ImmutableList.builder();
        ImmutableBitSet.Builder bottomGroupSetBuilder = ImmutableBitSet.builder();
        bottomGroupSetBuilder.addAll(aggregate.getGroupSet());
        for (AggregateCall agg : aggregate.getAggCallList()) {
            if (agg.getAggregation().getKind() == SqlKind.COUNT && agg.isDistinct()) {
                bottomGroupSetBuilder.addAll((Iterable)Lists.newArrayList((Iterable)agg.getArgList()));
                continue;
            }
            bottomAggCallsBuilder.add((Object)agg.copy((List)Lists.newArrayList((Iterable)agg.getArgList()), agg.filterArg));
        }
        ImmutableBitSet bottomGroupSetBuild = bottomGroupSetBuilder.build();
        ImmutableList bottomAggCallsBuild = bottomAggCallsBuilder.build();
        List bottomGroupSets = bottomGroupSetBuild.asList();
        Aggregate bottomAggregate = aggregate.copy(aggregate.getTraitSet(), (RelNode)inputKapRel, aggregate.indicator, bottomGroupSetBuild, null, (List)bottomAggCallsBuild);
        ImmutableBitSet.Builder topGroupSet = ImmutableBitSet.builder();
        ArrayList<Integer> topGroupSetList = new ArrayList<Integer>();
        this.setTopAggregateGroupSet(bottomAggregate, (Aggregate)aggregate, topGroupSetList, topGroupSet);
        int topAggArgsIndex = bottomGroupSets.size();
        ImmutableList.Builder topAggCalls = ImmutableList.builder();
        for (AggregateCall agg : aggregate.getAggCallList()) {
            if (agg.getAggregation().getKind() == SqlKind.COUNT && agg.isDistinct()) {
                ArrayList<Integer> aggArgsList = new ArrayList<Integer>();
                for (Integer arg : agg.getArgList()) {
                    aggArgsList.add(bottomGroupSets.indexOf(arg));
                }
                topAggCalls.add((Object)agg.copy(aggArgsList, agg.filterArg));
                continue;
            }
            if (agg.getAggregation().getKind() == SqlKind.COUNT) {
                topAggCalls.add((Object)AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.SUM0, (boolean)false, (boolean)false, (List)Lists.newArrayList((Object[])new Integer[]{topAggArgsIndex++}), (int)-1, (RelDataType)agg.type, (String)agg.name));
                continue;
            }
            topAggCalls.add((Object)agg.copy((List)Lists.newArrayList((Object[])new Integer[]{topAggArgsIndex++}), agg.filterArg));
        }
        Aggregate topAggregate = aggregate.copy(aggregate.getTraitSet(), (RelNode)bottomAggregate, aggregate.indicator, topGroupSet.build(), null, (List)topAggCalls.build());
        call.transformTo((RelNode)topAggregate);
    }

    private void setTopAggregateGroupSet(Aggregate bottomAggregate, Aggregate aggregate, List<Integer> topGroupSetList, ImmutableBitSet.Builder topGroupSet) {
        List bottomAggregateGroupIndexList = bottomAggregate.getGroupSet().asList();
        List aggregateGroupIndexList = aggregate.getGroupSet().asList();
        for (int i = 0; i < bottomAggregateGroupIndexList.size(); ++i) {
            if (!aggregateGroupIndexList.contains(bottomAggregateGroupIndexList.get(i))) continue;
            topGroupSetList.add(i);
        }
        topGroupSet.addAll(topGroupSetList);
    }
}

