/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KapAggSumCastRule
extends RelOptRule {
    public static final KapAggSumCastRule INSTANCE = new KapAggSumCastRule(KapAggSumCastRule.operand(KapAggregateRel.class, (RelOptRuleOperand)KapAggSumCastRule.operand(KapProjectRel.class, null, input -> !AggExpressionUtil.hasAggInput((RelNode)input), (RelOptRuleOperandChildren)RelOptRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapAggSumCastRule");
    private static final Logger logger = LoggerFactory.getLogger(KapAggSumCastRule.class);

    public KapAggSumCastRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    public boolean matches(RelOptRuleCall ruleCall) {
        return true;
    }

    public void onMatch(RelOptRuleCall ruleCall) {
        HashMap sumMatchMap = new HashMap();
        HashMap<AggregateCall, AggregateCall> rewriteAggCallMap = new HashMap<AggregateCall, AggregateCall>();
        Aggregate oldAgg = (Aggregate)ruleCall.rel(0);
        Project oldProject = (Project)ruleCall.rel(1);
        List aggCallList = oldAgg.getAggCallList();
        boolean hasAggSum = false;
        for (AggregateCall aggregateCall : aggCallList) {
            if (!SqlKind.SUM.name().equalsIgnoreCase(aggregateCall.getAggregation().getKind().name())) continue;
            hasAggSum = true;
            List argList = aggregateCall.getArgList();
            if (argList.size() != 1) continue;
            sumMatchMap.put(argList.get(0), aggregateCall);
        }
        if (!hasAggSum) {
            return;
        }
        boolean isHasAggSumCastDouble = false;
        SqlTypeFactoryImpl sqlTypeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
        LinkedList<Object> bottomProjectRexNodes = new LinkedList<Object>();
        LinkedList<RexNode> rewriteProjectRexNodes = new LinkedList<RexNode>();
        List exprList = oldProject.getChildExps();
        Set groupBySet = oldAgg.getGroupSet().asSet();
        for (int i = 0; i < exprList.size(); ++i) {
            AggregateCall aggregateCall = (AggregateCall)sumMatchMap.get(i);
            RexNode rexNode = (RexNode)exprList.get(i);
            if (aggregateCall == null) {
                bottomProjectRexNodes.add(rexNode);
                continue;
            }
            RexNode curProjectExp = rexNode;
            if (rexNode instanceof RexCall && ((RexCall)rexNode).op instanceof SqlCastFunction) {
                RexCall rexCall = (RexCall)rexNode;
                List opList = rexCall.getOperands();
                if (opList.size() != 1) {
                    bottomProjectRexNodes.add(rexNode);
                    continue;
                }
                RexNode rexNodeOp = (RexNode)opList.get(0);
                if (SqlTypeName.DOUBLE == rexCall.getType().getSqlTypeName() && SqlTypeFamily.NUMERIC == rexNodeOp.getType().getSqlTypeName().getFamily()) {
                    AggregateCall newAggCall;
                    isHasAggSumCastDouble = true;
                    List operands = ((RexCall)curProjectExp).getOperands();
                    RexNode curRexNode = (RexNode)operands.get(0);
                    RelDataType returnDataType = curRexNode.getType();
                    if (SqlTypeName.INTEGER == curRexNode.getType().getSqlTypeName() || SqlTypeName.SMALLINT == curRexNode.getType().getSqlTypeName() || SqlTypeName.TINYINT == curRexNode.getType().getSqlTypeName()) {
                        returnDataType = new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT);
                        returnDataType = sqlTypeFactory.createTypeWithNullability(returnDataType, true);
                    }
                    if (groupBySet.contains(i)) {
                        newAggCall = new AggregateCall(aggregateCall.getAggregation(), false, Arrays.asList(exprList.size() + rewriteProjectRexNodes.size()), returnDataType, aggregateCall.getName());
                        rewriteProjectRexNodes.add(curRexNode);
                    } else {
                        newAggCall = new AggregateCall(aggregateCall.getAggregation(), false, aggregateCall.getArgList(), returnDataType, aggregateCall.getName());
                        curProjectExp = curRexNode;
                    }
                    rewriteAggCallMap.put(aggregateCall, newAggCall);
                }
            }
            bottomProjectRexNodes.add(curProjectExp);
        }
        if (!isHasAggSumCastDouble) {
            return;
        }
        bottomProjectRexNodes.addAll(rewriteProjectRexNodes);
        RelBuilder relBuilder = ruleCall.builder();
        relBuilder.push(oldProject.getInput());
        relBuilder.project(bottomProjectRexNodes);
        ArrayList newAggregateCallList = new ArrayList(oldAgg.getAggCallList().size());
        oldAgg.getAggCallList().forEach(aggCall -> {
            AggregateCall newAggCall = (AggregateCall)rewriteAggCallMap.get(aggCall);
            if (newAggCall != null) {
                newAggregateCallList.add(newAggCall);
            } else {
                newAggregateCallList.add(aggCall);
            }
        });
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(oldAgg.getGroupSet(), oldAgg.getGroupSets());
        relBuilder.aggregate(groupKey, newAggregateCallList);
        List<RexNode> topProjList = this.buildTopProject(relBuilder, oldAgg, rewriteAggCallMap);
        relBuilder.project(topProjList);
        ruleCall.transformTo(relBuilder.build());
    }

    private List<RexNode> buildTopProject(RelBuilder relBuilder, Aggregate oldAgg, Map<AggregateCall, AggregateCall> rewriteAggCallMap) {
        ArrayList topProjectList = Lists.newArrayList();
        int groupSize = oldAgg.getGroupSet().asSet().size();
        for (int i = 0; i < groupSize; ++i) {
            topProjectList.add(relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), i));
        }
        for (AggregateCall aggCall : oldAgg.getAggCallList()) {
            RexInputRef value;
            AggregateCall rewriteAggCall = rewriteAggCallMap.get(aggCall);
            int projectIndex = topProjectList.size();
            if (rewriteAggCall != null) {
                RelDataType type = aggCall.getType();
                value = relBuilder.getRexBuilder().makeCast(type, (RexNode)relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), projectIndex));
            } else {
                value = relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), projectIndex);
            }
            topProjectList.add(value);
        }
        return topProjectList;
    }
}

