package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.job.shaded.com.google.common.collect.ImmutableList;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexLiteral;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.query.exception.SumExprUnSupportException;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/kyligence/kap/query/optrule/SumBasicOperatorRule.class */
/**
 * Planner rule that rewrites {@code SUM} over a basic arithmetic expression of a single source
 * column (for example {@code SUM(col * 3)} or {@code SUM(col / 2)}) sitting on a
 * {@link KapProjectRel}.
 *
 * <p>The rewrite produces two levels of aggregation:
 * <ol>
 *   <li>a bottom project/aggregate that groups by the raw group columns and sums the raw source
 *       column;</li>
 *   <li>a top project that re-applies the original expression to those partial sums, followed by a
 *       top aggregate that applies the original aggregation. A {@code COUNT} becomes {@code SUM} at
 *       the top.</li>
 * </ol>
 *
 * <p>Multiplication is accepted only when one side contains no column reference. Division is
 * accepted only when the divisor contains no column reference. {@code PLUS}/{@code MINUS} are
 * rejected.
 */
public class SumBasicOperatorRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(SumBasicOperatorRule.class);

    /**
     * Shared instance. It matches a {@link KapAggregateRel} directly on top of a
     * {@link KapProjectRel} for which {@code AggExpressionUtil.hasAggInput} is false.
     */
    public static final SumBasicOperatorRule INSTANCE = new SumBasicOperatorRule(
            operand(KapAggregateRel.class,
                    operand(KapProjectRel.class, null, kapProjectRel -> {
                        return !AggExpressionUtil.hasAggInput(kapProjectRel);
                    }, RelOptRule.any()),
                    new RelOptRuleOperand[0]),
            RelFactories.LOGICAL_BUILDER, "SumBasicOperatorRule");

    public SumBasicOperatorRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Returns true when at least one SUM expression collected from the matched aggregate/project
     * pair is a supported basic-operator expression.
     *
     * <p>Every collected expression is checked, with no early exit, because
     * {@link #checkExpressionSupported} throws {@link SumExprUnSupportException} for shapes this
     * rule cannot handle. Any such expression vetoes the whole match.
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        try {
            List<AggExpressionUtil.AggExpression> sumExprs = AggExpressionUtil.collectSumExpressions(
                    (Aggregate) relOptRuleCall.rel(0), (Project) relOptRuleCall.rel(1));
            boolean anySupported = false;
            for (AggExpressionUtil.AggExpression sumExpr : sumExprs) {
                // Non-short-circuit on purpose: later expressions must still get the chance to throw.
                anySupported |= checkExpressionSupported(sumExpr);
            }
            return anySupported;
        } catch (SumExprUnSupportException e) {
            logger.trace("Current rel unable to apply SumBasicOperatorRule", e);
            return false;
        }
    }

    /**
     * Replaces the matched aggregate/project pair with this stack:
     * bottom project, bottom aggregate, top project, top aggregate.
     *
     * <p>Failures are logged and swallowed, so the original plan stays in place. This rule is
     * best-effort.
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        try {
            Aggregate aggregate = (Aggregate) relOptRuleCall.rel(0);
            Project project = (Project) relOptRuleCall.rel(1);
            RelBuilder builder = relOptRuleCall.builder();
            builder.push(project.getInput());
            ContextUtil.dumpCalcitePlan("old plan", aggregate, logger);

            List<AggExpressionUtil.AggExpression> sumExprs = AggExpressionUtil.collectSumExpressions(aggregate, project);
            Pair<List<AggExpressionUtil.GroupExpression>, ImmutableList<ImmutableBitSet>> groupInfo =
                    AggExpressionUtil.collectGroupExprAndGroup(aggregate, project);
            List<AggExpressionUtil.GroupExpression> groupExprs = groupInfo.getFirst();
            ImmutableList<ImmutableBitSet> groupSets = groupInfo.getSecond();

            // Bottom project. This also fills in each expression's bottomAggInput ordinals.
            builder.project(buildBottomProject(builder, project, groupExprs, sumExprs));

            // Bottom aggregate, grouped by every column the group expressions read.
            ImmutableBitSet bottomGroupSet = collectBottomGroupSet(groupExprs);
            builder.aggregate(builder.groupKey(bottomGroupSet, null),
                    buildBottomAggregate(builder, sumExprs, bottomGroupSet.cardinality()));
            remapGroupInputsToBottomAggregate(groupExprs, bottomGroupSet);

            // Top project re-applies the expressions; the top aggregate groups on its leading
            // group-expression columns.
            builder.project(buildTopProjectList(builder, project, sumExprs, groupExprs));
            ImmutableBitSet topGroupSet = ImmutableBitSet.range(groupExprs.size());
            builder.aggregate(builder.groupKey(topGroupSet, groupSets),
                    buildTopAggregate(aggregate.getAggCallList(), topGroupSet.cardinality(), sumExprs));

            RelNode rewritten = builder.build();
            ContextUtil.dumpCalcitePlan("new plan", rewritten, logger);
            relOptRuleCall.transformTo(rewritten);
        } catch (Exception e) {
            logger.error("sql cannot apply sum multiply rule ", e);
        }
    }

    /** Returns the set of bottom-project ordinals referenced by any group expression. */
    private static ImmutableBitSet collectBottomGroupSet(List<AggExpressionUtil.GroupExpression> groupExprs) {
        ImmutableBitSet.Builder bits = ImmutableBitSet.builder();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            for (int ordinal : groupExpr.getBottomAggInput()) {
                bits.set(ordinal);
            }
        }
        return bits.build();
    }

    /**
     * Points each group expression's top-project inputs at the corresponding position inside
     * {@code bottomGroupSet}, i.e. the group-key column of the bottom aggregate's output.
     */
    private static void remapGroupInputsToBottomAggregate(List<AggExpressionUtil.GroupExpression> groupExprs,
            ImmutableBitSet bottomGroupSet) {
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            for (int i = 0; i < groupExpr.getTopProjInput().length; i++) {
                groupExpr.getTopProjInput()[i] = bottomGroupSet.indexOf(groupExpr.getBottomAggInput()[i]);
            }
        }
    }

    /**
     * Builds the bottom projection over {@code project}'s input.
     *
     * <p>It emits these columns, in order:
     * <ul>
     *   <li>an input ref for every source column of every group expression;</li>
     *   <li>for each SUM expression, either its raw source-column refs (supported expressions) or
     *       the original expression itself (other expressions with a non-null expression).</li>
     * </ul>
     *
     * <p>Side effect: records the emitted positions in the expressions' {@code bottomAggInput}
     * arrays.
     */
    private List<RexNode> buildBottomProject(RelBuilder relBuilder, Project project,
            List<AggExpressionUtil.GroupExpression> groupExprs, List<AggExpressionUtil.AggExpression> sumExprs) {
        RelNode input = project.getInput();
        List<RexNode> projects = new ArrayList<>();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            int[] sourceCols = groupExpr.getBottomProjInput();
            for (int i = 0; i < sourceCols.length; i++) {
                groupExpr.getBottomAggInput()[i] = projects.size();
                projects.add(relBuilder.getRexBuilder().makeInputRef(input, sourceCols[i]));
            }
        }
        for (AggExpressionUtil.AggExpression sumExpr : sumExprs) {
            if (checkExpressionSupported(sumExpr)) {
                // The first raw source column becomes the argument of the bottom SUM.
                if (sumExpr.getBottomAggInput().length != 0) {
                    sumExpr.getBottomAggInput()[0] = projects.size();
                }
                for (int sourceCol : sumExpr.getBottomProjInput()) {
                    projects.add(relBuilder.getRexBuilder().makeInputRef(input, sourceCol));
                }
            } else if (sumExpr.getExpression() != null) {
                sumExpr.getBottomAggInput()[0] = projects.size();
                projects.add(sumExpr.getExpression());
            }
        }
        return projects;
    }

    /**
     * Builds the bottom aggregate calls, one per SUM expression.
     *
     * <p>Supported expressions become {@code SUM} of their single raw source column, named
     * {@code SUM_OP$n}. Other expressions keep their own aggregate call, re-pointed at their bottom
     * inputs.
     *
     * <p>Side effect: records each call's output ordinal (after the {@code groupCount} group
     * columns) in {@code topProjInput[0]}.
     */
    private List<AggregateCall> buildBottomAggregate(RelBuilder relBuilder,
            List<AggExpressionUtil.AggExpression> sumExprs, int groupCount) {
        List<AggregateCall> calls = new ArrayList<>();
        int rewrittenSums = 0;
        for (AggExpressionUtil.AggExpression sumExpr : sumExprs) {
            AggregateCall call;
            if (checkExpressionSupported(sumExpr)) {
                AggExpressionUtil.assertCondition(sumExpr.getBottomProjInput().length == 1,
                        "SumBasicOperatorRule only handles aggregation of single source column");
                String name = "SUM_OP$" + rewrittenSums;
                rewrittenSums++;
                List<Integer> argList = Lists.newArrayList(Integer.valueOf(sumExpr.getBottomAggInput()[0]));
                call = AggregateCall.create(SqlStdOperatorTable.SUM, false, false, argList, -1, groupCount,
                        relBuilder.peek(), null, name);
            } else {
                AggregateCall originalCall = sumExpr.getAggCall();
                List<Integer> argList = Arrays.stream(sumExpr.getBottomAggInput()).boxed()
                        .collect(Collectors.toList());
                call = originalCall.copy(argList, originalCall.filterArg);
            }
            sumExpr.getTopProjInput()[0] = groupCount + calls.size();
            calls.add(call);
        }
        return calls;
    }

    /**
     * Builds the top projection over the bottom aggregate, which is currently on top of
     * {@code relBuilder}.
     *
     * <p>Group expressions are re-evaluated on the bottom aggregate's group columns. Supported SUM
     * expressions are re-evaluated on their partial sums and cast to the original aggregate type.
     * Other expressions are passed through as an input ref.
     */
    private List<RexNode> buildTopProjectList(RelBuilder relBuilder, Project project,
            List<AggExpressionUtil.AggExpression> sumExprs, List<AggExpressionUtil.GroupExpression> groupExprs) {
        List<RexNode> projects = new ArrayList<>();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExprs) {
            RexNode expr = groupExpr.getExpression();
            RexNode shifted = shiftOntoTop(relBuilder, project, expr,
                    AggExpressionUtil.generateAdjustments(groupExpr.getBottomProjInput(), groupExpr.getTopProjInput()));
            projects.add(relBuilder.getRexBuilder().ensureType(expr.getType(), shifted, false));
        }
        for (AggExpressionUtil.AggExpression sumExpr : sumExprs) {
            // Declared as RexNode: ensureType returns a RexNode, not a RexInputRef.
            RexNode topExpr;
            if (checkExpressionSupported(sumExpr)) {
                RexNode shifted = shiftOntoTop(relBuilder, project, sumExpr.getExpression(),
                        AggExpressionUtil.generateAdjustments(sumExpr.getBottomProjInput(), sumExpr.getTopProjInput()));
                topExpr = relBuilder.getRexBuilder().ensureType(sumExpr.getAggCall().getType(), shifted, false);
            } else {
                topExpr = relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), sumExpr.getTopProjInput()[0]);
            }
            projects.add(topExpr);
        }
        return projects;
    }

    /**
     * Rewrites the input refs of {@code expr} from {@code project}'s input row type onto the row
     * type of the node currently on top of {@code relBuilder}, shifting each ref by
     * {@code adjustments}.
     */
    private static RexNode shiftOntoTop(RelBuilder relBuilder, Project project, RexNode expr, int[] adjustments) {
        return expr.accept(new RelOptUtil.RexInputConverter(relBuilder.getRexBuilder(),
                project.getInput().getRowType().getFieldList(), relBuilder.peek().getRowType().getFieldList(),
                adjustments));
    }

    /**
     * Builds the top aggregate calls, one per original aggregate call. Call {@code i} reads
     * top-project column {@code groupCount + i} and is named {@code AGG$i}.
     *
     * <p>A {@code COUNT} is turned into {@code SUM} so the top level adds up the partial counts
     * from the bottom aggregate.
     *
     * <p>NOTE(review): this assumes {@code sumExprs} lines up 1:1 with {@code aggCalls}.
     */
    private List<AggregateCall> buildTopAggregate(List<AggregateCall> aggCalls, int groupCount,
            List<AggExpressionUtil.AggExpression> sumExprs) {
        List<AggregateCall> calls = new ArrayList<>(aggCalls.size());
        for (int i = 0; i < aggCalls.size(); i++) {
            AggregateCall sourceCall = sumExprs.get(i).getAggCall();
            List<Integer> argList = Lists.newArrayList(Integer.valueOf(groupCount + i));
            calls.add(AggregateCall.create(
                    SqlKind.COUNT == sourceCall.getAggregation().getKind() ? SqlStdOperatorTable.SUM : sourceCall.getAggregation(),
                    false, false, argList, -1, sourceCall.getType(), "AGG$" + i));
        }
        return calls;
    }

    /**
     * Returns true if {@code aggExpression} is a SUM over a {@link RexCall} built only from
     * literals, input refs and basic operators, and it passes {@link #checkUnSupportOperands}.
     *
     * @throws SumExprUnSupportException if the expression is a sum-case expression, or its operator
     *     tree contains PLUS/MINUS, a product of two column-bearing operands, or a division by a
     *     column-bearing operand
     */
    private boolean checkExpressionSupported(AggExpressionUtil.AggExpression aggExpression) {
        if (aggExpression.isSumCase()) {
            throw new SumExprUnSupportException("SumBasicOperatorRule is unable to handle sum case expression.");
        }
        RexNode expression = aggExpression.getExpression();
        boolean candidate = AggExpressionUtil.isSum(aggExpression.getAggCall().getAggregation().getKind())
                && expression instanceof RexCall
                && isBasicOperand(expression);
        if (!candidate) {
            return false;
        }
        checkUnSupportOperands(expression);
        return true;
    }

    /**
     * Returns true if {@code rexNode} is one of:
     * <ul>
     *   <li>a literal;</li>
     *   <li>an input ref;</li>
     *   <li>a PLUS/MINUS/TIMES/divide call whose first two operands are themselves basic.</li>
     * </ul>
     * Returns false for a call with fewer than two operands.
     */
    private boolean isBasicOperand(RexNode rexNode) {
        if (rexNode instanceof RexLiteral || rexNode instanceof RexInputRef) {
            return true;
        }
        if (!(rexNode instanceof RexCall) || !isBasicOperator(rexNode)) {
            return false;
        }
        List<RexNode> operands = ((RexCall) rexNode).getOperands();
        return operands.size() >= 2 && isBasicOperand(operands.get(0)) && isBasicOperand(operands.get(1));
    }

    /** Returns true for PLUS, MINUS, TIMES, or a call recognised by {@code KapRuleUtils.isDivide}. */
    private static boolean isBasicOperator(RexNode rexNode) {
        SqlKind kind = rexNode.getKind();
        return kind == SqlKind.PLUS || kind == SqlKind.MINUS || kind == SqlKind.TIMES
                || KapRuleUtils.isDivide(rexNode);
    }

    /** Runs {@link #verify} on every {@link RexCall} in the tree rooted at {@code rexNode}. */
    private void checkUnSupportOperands(RexNode rexNode) {
        if (!(rexNode instanceof RexCall)) {
            return;
        }
        RexCall rexCall = (RexCall) rexNode;
        verify(rexCall);
        for (RexNode operand : rexCall.getOperands()) {
            checkUnSupportOperands(operand);
        }
    }

    /**
     * Throws {@link SumExprUnSupportException} if this single call is not rewritable.
     * A divide is checked once, whether it is {@code SqlKind.DIVIDE} or matched by
     * {@code KapRuleUtils.isDivide}.
     */
    private void verify(RexCall rexCall) {
        SqlKind kind = rexCall.getKind();
        if (kind == SqlKind.PLUS || kind == SqlKind.MINUS) {
            verifyPlusOrMinus(rexCall);
        } else if (kind == SqlKind.TIMES) {
            verifyMultiply(rexCall);
        }
        if (kind == SqlKind.DIVIDE || KapRuleUtils.isDivide(rexCall)) {
            verifyDivide(rexCall);
        }
    }

    /** Always throws: PLUS/MINUS cannot be pushed through SUM by this rule. */
    private void verifyPlusOrMinus(RexCall rexCall) {
        throw new SumExprUnSupportException("That PLUS/MINUS of the columns is not supported for sum expression");
    }

    /** Throws unless at least one side of the multiplication references no column. */
    private void verifyMultiply(RexCall rexCall) {
        RexNode left = rexCall.getOperands().get(0);
        RexNode right = rexCall.getOperands().get(1);
        if (!isConstant(left) && !isConstant(right)) {
            throw new SumExprUnSupportException("That both of the two sides of the columns is not supported for " + rexCall.getKind().toString());
        }
    }

    /** Throws unless the divisor (second operand) references no column. */
    private void verifyDivide(RexCall rexCall) {
        if (!isConstant(rexCall.getOperands().get(1))) {
            throw new SumExprUnSupportException("That the right side of the columns is not supported for " + rexCall.getKind().toString());
        }
    }

    /** Returns true if {@code rexNode} contains no {@link RexInputRef}. */
    private boolean isConstant(RexNode rexNode) {
        return extractColumn(rexNode).isEmpty();
    }

    /** Collects every {@link RexInputRef} in the tree rooted at {@code rexNode}, depth-first. */
    private List<RexNode> extractColumn(RexNode rexNode) {
        List<RexNode> columns = new ArrayList<>();
        if (rexNode instanceof RexInputRef) {
            columns.add(rexNode);
        } else if (rexNode instanceof RexCall) {
            for (RexNode operand : ((RexCall) rexNode).getOperands()) {
                columns.addAll(extractColumn(operand));
            }
        }
        return columns;
    }
}
