/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;

public class SumCaseWhenFunctionRule
extends AbstractAggCaseWhenFunctionRule {
    public static final SumCaseWhenFunctionRule INSTANCE = new SumCaseWhenFunctionRule(SumCaseWhenFunctionRule.operand(KapAggregateRel.class, (RelOptRuleOperand)SumCaseWhenFunctionRule.operand(KapProjectRel.class, null, input -> !AggExpressionUtil.hasAggInput((RelNode)input), (RelOptRuleOperandChildren)RelOptRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "SumCaseWhenFunctionRule");

    public SumCaseWhenFunctionRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    private boolean isSumCaseExpr(AggregateCall aggregateCall, Project inputProject) {
        if (aggregateCall.getArgList().size() != 1) {
            return false;
        }
        int input = (Integer)aggregateCall.getArgList().get(0);
        RexNode expression = (RexNode)inputProject.getChildExps().get(input);
        return AggExpressionUtil.hasSumCaseWhen(aggregateCall, expression);
    }

    @Override
    protected boolean checkAggCaseExpression(Aggregate oldAgg, Project oldProject) {
        for (AggregateCall call : oldAgg.getAggCallList()) {
            if (!this.isSumCaseExpr(call, oldProject)) continue;
            return true;
        }
        return false;
    }

    @Override
    protected boolean isApplicableWithSumCaseRule(AggregateCall aggregateCall, Project project) {
        SqlKind aggFunction = aggregateCall.getAggregation().getKind();
        return aggFunction == SqlKind.SUM || aggFunction == SqlKind.SUM0 || aggFunction == SqlKind.MAX || aggFunction == SqlKind.MIN || aggFunction == SqlKind.COUNT && !aggregateCall.isDistinct() || "BITMAP_UUID".equalsIgnoreCase(aggregateCall.getName());
    }

    @Override
    protected boolean isApplicableAggExpression(AggExpressionUtil.AggExpression aggExpr) {
        return aggExpr.isSumCase();
    }

    @Override
    protected SqlAggFunction getBottomAggFunc(AggregateCall aggCall) {
        return SqlStdOperatorTable.SUM;
    }

    @Override
    protected SqlAggFunction getTopAggFunc(AggregateCall aggCall) {
        return SqlKind.COUNT == aggCall.getAggregation().getKind() ? SqlStdOperatorTable.SUM0 : aggCall.getAggregation();
    }

    @Override
    protected String getBottomAggPrefix() {
        return "SUM_CASE$";
    }
}

