package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.job.shaded.com.google.common.collect.ImmutableList;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.query.exception.SumExprUnSupportException;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/kyligence/kap/query/optrule/SumConstantConvertRule.class */
/**
 * Planner rule that rewrites {@code SUM(<constant expression>)} aggregate calls into
 * {@code <constant expression> * COUNT(*)}.
 *
 * <p>The rule matches a {@code KapAggregateRel} whose input is a {@code KapProjectRel} for which
 * {@code AggExpressionUtil.hasAggInput} returns {@code false}. It fires only when at least one of
 * the sum expressions collected by {@code AggExpressionUtil.collectSumExpressions} is a constant
 * sum. On a match, {@code Aggregate(Project(input))} is replaced by:
 * <ol>
 *   <li>a bottom project over {@code input} exposing the columns referenced by the group
 *       expressions, followed by the expression of every non-constant aggregate expression that
 *       has one;</li>
 *   <li>a bottom aggregate grouped by those group-expression columns, in which each constant sum
 *       becomes {@code COUNT(*)} and every other aggregate call is copied with remapped
 *       arguments;</li>
 *   <li>a top project that re-evaluates the group expressions, multiplies each constant by its
 *       {@code COUNT(*)} column and passes the other aggregate columns through;</li>
 *   <li>a top aggregate using the grouping sets from
 *       {@code AggExpressionUtil.collectGroupExprAndGroup}, re-aggregating {@code COUNT} calls with
 *       {@code SUM0} and every other call with its own aggregate function.</li>
 * </ol>
 */
public class SumConstantConvertRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(SumConstantConvertRule.class);

    /**
     * Shared rule instance: matches a {@code KapAggregateRel} directly over a
     * {@code KapProjectRel} for which {@code AggExpressionUtil.hasAggInput} is {@code false}.
     */
    public static final SumConstantConvertRule INSTANCE = new SumConstantConvertRule(
            operand(KapAggregateRel.class,
                    operand(KapProjectRel.class, null,
                            kapProjectRel -> !AggExpressionUtil.hasAggInput(kapProjectRel),
                            RelOptRule.any()),
                    new RelOptRuleOperand[0]),
            RelFactories.LOGICAL_BUILDER, "SumConstantConvertRule");

    public SumConstantConvertRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Returns {@code true} when at least one collected sum expression is a constant sum.
     *
     * <p>If {@code AggExpressionUtil.collectSumExpressions} rejects the matched rels with
     * {@link SumExprUnSupportException}, the rule does not apply; the exception is logged at
     * trace level and {@code false} is returned.
     */
    @Override
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        try {
            List<AggExpressionUtil.AggExpression> sumExprs = AggExpressionUtil.collectSumExpressions(
                    (Aggregate) relOptRuleCall.rel(0), (Project) relOptRuleCall.rel(1));
            return sumExprs.stream().anyMatch(AggExpressionUtil.AggExpression::isSumConst);
        } catch (SumExprUnSupportException e) {
            logger.trace("Current rel unable to apply SumConstantConvertRule", e);
            return false;
        }
    }

    /**
     * Builds the two-level aggregate described on the class and registers it as an equivalent
     * of the matched aggregate.
     *
     * <p>Any exception raised while rewriting is logged and swallowed, so a failed rewrite leaves
     * the original plan without the new alternative.
     */
    @Override
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        try {
            RelBuilder builder = relOptRuleCall.builder();
            Aggregate aggregate = (Aggregate) relOptRuleCall.rel(0);
            Project project = (Project) relOptRuleCall.rel(1);
            ContextUtil.dumpCalcitePlan("old plan", aggregate, logger);

            List<AggExpressionUtil.AggExpression> sumExprs = AggExpressionUtil.collectSumExpressions(aggregate, project);
            Pair<List<AggExpressionUtil.GroupExpression>, ImmutableList<ImmutableBitSet>> groupInfo =
                    AggExpressionUtil.collectGroupExprAndGroup(aggregate, project);
            List<AggExpressionUtil.GroupExpression> groupExprs = groupInfo.getFirst();
            ImmutableList<ImmutableBitSet> groupSets = groupInfo.getSecond();

            // Bottom project. Building it also fills in the bottom-aggregate input indices of
            // groupExprs and sumExprs, which every following step reads, so it must come first.
            builder.push(project.getInput());
            builder.project(buildBottomProject(builder, project, groupExprs, sumExprs));

            // Bottom aggregate, grouped by every column referenced by a group expression.
            ImmutableBitSet bottomGroupSet = collectBottomGroupSet(groupExprs);
            builder.aggregate(builder.groupKey(bottomGroupSet, null),
                    buildBottomAggCall(builder, sumExprs, bottomGroupSet.cardinality()));

            remapTopProjInputs(groupExprs, bottomGroupSet);
            builder.project(buildTopProject(builder, project, groupExprs, sumExprs));

            // The top project emits the group expressions first, so the top aggregate groups on
            // columns [0, groupExprs.size()).
            ImmutableBitSet topGroupSet = ImmutableBitSet.range(groupExprs.size());
            builder.aggregate(builder.groupKey(topGroupSet, groupSets),
                    buildTopAggregate(aggregate.getAggCallList(), topGroupSet.cardinality(), sumExprs));

            RelNode newRel = builder.build();
            ContextUtil.dumpCalcitePlan("new plan", newRel, logger);
            relOptRuleCall.transformTo(newRel);
        } catch (Exception e) {
            // Best effort: a rewrite failure must not break planning of the query.
            logger.error("sql cannot apply sum constant rule ", e);
        }
    }

    /** Returns the set of bottom-project columns referenced by any group expression. */
    private static ImmutableBitSet collectBottomGroupSet(List<AggExpressionUtil.GroupExpression> groupExprs) {
        ImmutableBitSet.Builder keyBuilder = ImmutableBitSet.builder();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            for (int column : groupExpr.getBottomAggInput()) {
                keyBuilder.set(column);
            }
        }
        return keyBuilder.build();
    }

    /**
     * Overwrites each group expression's {@code getTopProjInput()} entries with the position
     * ({@link ImmutableBitSet#indexOf}) of the matching {@code getBottomAggInput()} column within
     * {@code bottomGroupSet}.
     */
    private static void remapTopProjInputs(List<AggExpressionUtil.GroupExpression> groupExprs,
            ImmutableBitSet bottomGroupSet) {
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            int[] topInputs = groupExpr.getTopProjInput();
            int[] bottomInputs = groupExpr.getBottomAggInput();
            for (int i = 0; i < topInputs.length; i++) {
                topInputs[i] = bottomGroupSet.indexOf(bottomInputs[i]);
            }
        }
    }

    /**
     * Builds the bottom project's expressions and records where each one lands.
     *
     * <p>Group-expression input columns come first, and each output index is written into that
     * group expression's {@code getBottomAggInput()} array. Then, for every aggregate expression
     * that is not a constant sum and has a non-null expression, the expression is appended and
     * its output index is written into slot 0 of its {@code getBottomAggInput()} array.
     *
     * @param relBuilder builder supplying the {@code RexBuilder}
     * @param project    matched project; input refs are built against its input
     * @param groupExprs group expressions (mutated: bottom-aggregate inputs are filled in)
     * @param aggExprs   aggregate expressions (mutated: bottom-aggregate input slot 0)
     * @return the bottom project's expressions, in output order
     */
    private List<RexNode> buildBottomProject(RelBuilder relBuilder, Project project,
            List<AggExpressionUtil.GroupExpression> groupExprs, List<AggExpressionUtil.AggExpression> aggExprs) {
        List<RexNode> projects = new ArrayList<>();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            int[] sourceColumns = groupExpr.getBottomProjInput();
            int[] targetColumns = groupExpr.getBottomAggInput();
            for (int i = 0; i < sourceColumns.length; i++) {
                targetColumns[i] = projects.size();
                projects.add(relBuilder.getRexBuilder().makeInputRef(project.getInput(), sourceColumns[i]));
            }
        }
        for (AggExpressionUtil.AggExpression aggExpr : aggExprs) {
            if (!aggExpr.isSumConst() && aggExpr.getExpression() != null) {
                aggExpr.getBottomAggInput()[0] = projects.size();
                projects.add(aggExpr.getExpression());
            }
        }
        return projects;
    }

    /**
     * Builds the bottom aggregate's calls, one per aggregate expression.
     *
     * <p>A constant sum becomes an argument-less {@code COUNT} named {@code SUM_CONST$<n>}. Any
     * other call is copied with its {@code getBottomAggInput()} columns as arguments and its
     * original filter argument. For every expression, slot 0 of {@code getTopProjInput()} is set
     * to {@code groupCount + n}, where {@code n} is the call's position in the returned list.
     *
     * @param relBuilder builder whose top node (the bottom project) is the COUNT input
     * @param aggExprs   aggregate expressions (mutated: top-project input slot 0)
     * @param groupCount number of group columns in the bottom aggregate
     * @return the bottom aggregate calls
     */
    private List<AggregateCall> buildBottomAggCall(RelBuilder relBuilder,
            List<AggExpressionUtil.AggExpression> aggExprs, int groupCount) {
        List<AggregateCall> calls = new ArrayList<>();
        for (AggExpressionUtil.AggExpression aggExpr : aggExprs) {
            int ordinal = calls.size();
            AggregateCall bottomCall;
            if (aggExpr.isSumConst()) {
                bottomCall = AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, Lists.newArrayList(), -1,
                        groupCount, relBuilder.peek(), null, "SUM_CONST$" + ordinal);
            } else {
                AggregateCall originalCall = aggExpr.getAggCall();
                List<Integer> args = Arrays.stream(aggExpr.getBottomAggInput()).boxed().collect(Collectors.toList());
                bottomCall = originalCall.copy(args, originalCall.filterArg);
            }
            aggExpr.getTopProjInput()[0] = groupCount + ordinal;
            calls.add(bottomCall);
        }
        return calls;
    }

    /**
     * Builds the top project over the bottom aggregate.
     *
     * <p>Each group expression is re-expressed against the bottom aggregate's row type through a
     * {@code RexInputConverter}, using adjustments from {@code AggExpressionUtil.generateAdjustments},
     * and cast back to its original type. Each aggregate expression becomes a reference to its
     * bottom-aggregate column. For constant sums, that reference is multiplied by the constant
     * expression and cast to the original call's type.
     *
     * @param relBuilder builder whose top node is the bottom aggregate
     * @param project    matched project; its input's row type is the source of the group expressions
     * @param groupExprs group expressions with remapped top-project inputs
     * @param aggExprs   aggregate expressions with top-project input slot 0 set
     * @return the top project's expressions: group expressions first, then aggregate columns
     */
    private List<RexNode> buildTopProject(RelBuilder relBuilder, Project project,
            List<AggExpressionUtil.GroupExpression> groupExprs, List<AggExpressionUtil.AggExpression> aggExprs) {
        List<RexNode> projects = new ArrayList<>();
        RelNode bottomAggregate = relBuilder.peek();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            RelOptUtil.RexInputConverter converter = new RelOptUtil.RexInputConverter(relBuilder.getRexBuilder(),
                    project.getInput().getRowType().getFieldList(), bottomAggregate.getRowType().getFieldList(),
                    AggExpressionUtil.generateAdjustments(groupExpr.getBottomProjInput(), groupExpr.getTopProjInput()));
            RexNode rewritten = groupExpr.getExpression().accept(converter);
            projects.add(relBuilder.getRexBuilder().ensureType(groupExpr.getExpression().getType(), rewritten, false));
        }
        for (AggExpressionUtil.AggExpression aggExpr : aggExprs) {
            // Declared as RexNode (not RexInputRef) because ensureType returns a general RexNode.
            RexNode column = relBuilder.getRexBuilder().makeInputRef(bottomAggregate, aggExpr.getTopProjInput()[0]);
            if (aggExpr.isSumConst()) {
                RexNode product = relBuilder.call(SqlStdOperatorTable.MULTIPLY, aggExpr.getExpression(), column);
                column = relBuilder.getRexBuilder().ensureType(aggExpr.getAggCall().getType(), product, false);
            }
            projects.add(column);
        }
        return projects;
    }

    /**
     * Builds the top aggregate's calls. Call {@code i} aggregates top-project column
     * {@code groupCount + i}, named {@code TOP_AGG$<i>}. It uses {@code SUM0} when the
     * collected call is a {@code COUNT}, otherwise the collected call's own function, and keeps
     * that call's type.
     *
     * <p>NOTE(review): iterates {@code originalCalls.size()} times but reads {@code aggExprs};
     * this assumes {@code collectSumExpressions} yields one expression per aggregate call. TODO
     * confirm against {@code AggExpressionUtil}.
     *
     * @param originalCalls the matched aggregate's calls (only their count is used)
     * @param groupCount    number of group columns emitted by the top project
     * @param aggExprs      collected aggregate expressions
     * @return the top aggregate calls
     */
    private List<AggregateCall> buildTopAggregate(List<AggregateCall> originalCalls, int groupCount,
            List<AggExpressionUtil.AggExpression> aggExprs) {
        List<AggregateCall> calls = new ArrayList<>();
        for (int i = 0; i < originalCalls.size(); i++) {
            AggregateCall aggCall = aggExprs.get(i).getAggCall();
            boolean isCount = SqlKind.COUNT == aggCall.getAggregation().getKind();
            List<Integer> args = Lists.newArrayList(groupCount + i);
            calls.add(AggregateCall.create(isCount ? SqlStdOperatorTable.SUM0 : aggCall.getAggregation(), false,
                    false, args, -1, aggCall.getType(), "TOP_AGG$" + i));
        }
        return calls;
    }
}
