package io.kyligence.kap.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.BasicSqlType;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.metadata.datatype.DataType;
import org.apache.kylin.query.calcite.KylinRelDataTypeSystem;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.apache.kylin.query.util.RuleUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/kyligence/kap/query/optrule/KapSumCastTransposeRule.class */
/**
 * Planner rule that moves a {@code CAST} out of the argument of a {@code SUM} aggregate call.
 *
 * <p>The rule matches a {@link KapAggregateRel} whose input is a {@link KapProjectRel}. It rewrites
 * every {@code SUM(CAST(x AS t))} whose operand {@code x} has a numeric SQL type into
 * {@code CAST(SUM(x) AS t)}. The replacement plan is built with the rule's {@link RelBuilder}:
 * <ol>
 *   <li>a bottom project holding all original project expressions, followed by one extra column
 *       per rewritten SUM that contains the cast operand {@code x};</li>
 *   <li>an aggregate with the same group keys. Each rewritten SUM reads its extra column
 *       (TINYINT/SMALLINT/INTEGER operands are summed as BIGINT) and keeps the original
 *       function, DISTINCT flag, approximate flag, filter argument and name. All other aggregate
 *       calls are reused unchanged;</li>
 *   <li>a top project that re-applies the original cast to each rewritten SUM result.</li>
 * </ol>
 *
 * <p>Aggregates with more than one grouping set are not rewritten, because the replacement
 * aggregate is built from the plain group set only.
 *
 * <p>NOTE(review): {@code SUM(CAST(x AS t))} equals {@code CAST(SUM(x) AS t)} only when the cast
 * loses nothing per row. Narrowing casts (for example DOUBLE to INTEGER) are not excluded here.
 * Confirm that this is acceptable for the queries this rule targets.
 */
public class KapSumCastTransposeRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(KapSumCastTransposeRule.class);

    /** Sentinel in the per-call argument table for aggregate calls the rewrite leaves unchanged. */
    private static final int NOT_REWRITTEN = -1;

    /**
     * Shared instance. It matches an aggregate on top of a project accepted by
     * {@link #needSumCastTranspose(Project)}. The description now matches the class name; it used
     * to be {@code "KapSumTransCastToThenRule"}.
     */
    public static final KapSumCastTransposeRule INSTANCE = new KapSumCastTransposeRule(
            operand(KapAggregateRel.class,
                    operand(KapProjectRel.class, null, project -> needSumCastTranspose(project), any())),
            RelFactories.LOGICAL_BUILDER, "KapSumCastTransposeRule");

    /**
     * Creates the rule.
     *
     * @param relOptRuleOperand root operand of the rule
     * @param relBuilderFactory factory for the builder used to create the replacement plan
     * @param str rule description
     */
    public KapSumCastTransposeRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Operand predicate for the project below the aggregate.
     *
     * @param project candidate project
     * @return {@code false} if the project's input is a {@link HepRelVertex} currently wrapping a
     *         {@link KapAggregateRel}; otherwise whether any project expression satisfies
     *         {@code RuleUtils.containCast}
     */
    public static boolean needSumCastTranspose(Project project) {
        RelNode input = project.getInput();
        // NOTE(review): presumably avoids firing on a project that sits directly on another
        // aggregate; confirm the intent.
        if (input instanceof HepRelVertex && ((HepRelVertex) input).getCurrentRel() instanceof KapAggregateRel) {
            return false;
        }
        return project.getProjects().stream().anyMatch(expr -> RuleUtils.containCast(expr));
    }

    /**
     * Fires only for a single-grouping-set aggregate with at least one SUM call that
     * {@link #isTransposableSum} accepts. The same predicate later decides which calls are
     * rewritten, so the match decision and the transformation stay consistent.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate) call.rel(0);
        Project project = (Project) call.rel(1);
        // The replacement aggregate is built from getGroupSet() only, so rewriting GROUPING SETS /
        // ROLLUP / CUBE would silently drop the extra grouping sets.
        if (aggregate.getGroupSets().size() != 1) {
            return false;
        }
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (isTransposableSum(aggCall, project)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the replacement plan and registers it with the planner. Any exception raised while
     * doing so is logged and not propagated.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        try {
            RelNode newRel = transposeSumCast(call.builder(), (Aggregate) call.rel(0), (Project) call.rel(1));
            ContextUtil.dumpCalcitePlan("new plan", newRel, logger);
            call.transformTo(newRel);
        } catch (Exception e) {
            logger.error("sql cannot apply sum cast transpose rule ", e);
        }
    }

    /**
     * Decides whether an aggregate call is a SUM that this rule rewrites.
     *
     * @return {@code true} when the call is a SUM (per {@code AggExpressionUtil.isSum}), its
     *         argument in {@code project} satisfies {@code RuleUtils.containCast} and is a
     *         {@link RexCall}, and that call's first operand has a numeric SQL type
     */
    private static boolean isTransposableSum(AggregateCall aggCall, Project project) {
        if (!AggExpressionUtil.isSum(aggCall.getAggregation().kind) || aggCall.getArgList().isEmpty()) {
            return false;
        }
        RexNode arg = project.getProjects().get(aggCall.getArgList().get(0));
        if (!RuleUtils.containCast(arg) || !(arg instanceof RexCall)) {
            return false;
        }
        // NOTE(review): assumes containCast means "arg itself is a CAST", so operand 0 is the
        // value being cast; confirm against RuleUtils.
        RelDataType operandType = ((RexCall) arg).getOperands().get(0).getType();
        // Checked against Calcite's own numeric type list rather than round-tripping the type
        // name through Kylin's DataType parser.
        return SqlTypeName.NUMERIC_TYPES.contains(operandType.getSqlTypeName());
    }

    /** Builds bottom project, then aggregate, then top project on top of the project's input. */
    private RelNode transposeSumCast(RelBuilder relBuilder, Aggregate aggregate, Project project) {
        List<AggExpressionUtil.AggExpression> aggExpressions = aggregate.getAggCallList().stream()
                .map(AggExpressionUtil.AggExpression::new)
                .collect(Collectors.toList());
        // strippedArgs[i] is the bottom-project column read by the i-th rewritten SUM, or
        // NOT_REWRITTEN when the i-th aggregate call is kept as-is.
        int[] strippedArgs = new int[aggExpressions.size()];

        relBuilder.push(project.getInput());
        relBuilder.project(buildBottomProject(project, aggExpressions, strippedArgs));
        ImmutableBitSet groupSet = aggregate.getGroupSet();
        relBuilder.aggregate(relBuilder.groupKey(groupSet, null),
                buildBottomAggregate(relBuilder, aggExpressions, strippedArgs, groupSet.cardinality()));
        relBuilder.project(buildTopProject(relBuilder, project, aggregate, aggExpressions, strippedArgs));
        return relBuilder.build();
    }

    /**
     * Returns the original project expressions followed by the cast operand of each rewritten
     * SUM. Fills {@code strippedArgs} and records each rewritten SUM's new type on its
     * AggExpression.
     */
    private List<RexNode> buildBottomProject(Project project, List<AggExpressionUtil.AggExpression> aggExpressions,
            int[] strippedArgs) {
        List<RexNode> bottomProjects = Lists.newArrayList(project.getProjects());
        SqlTypeFactoryImpl typeFactory = new SqlTypeFactoryImpl(new KylinRelDataTypeSystem());
        for (int i = 0; i < aggExpressions.size(); i++) {
            AggExpressionUtil.AggExpression aggExpression = aggExpressions.get(i);
            AggregateCall aggCall = aggExpression.getAggCall();
            if (!isTransposableSum(aggCall, project)) {
                strippedArgs[i] = NOT_REWRITTEN;
                continue;
            }
            RexCall cast = (RexCall) project.getProjects().get(aggCall.getArgList().get(0));
            RexNode operand = cast.getOperands().get(0);
            // Append rather than overwrite the cast column: group keys and other aggregate calls
            // may reference the same column and must keep seeing the cast value.
            strippedArgs[i] = bottomProjects.size();
            bottomProjects.add(operand);
            aggExpression.setType(sumTypeOf(typeFactory, operand.getType()));
        }
        return bottomProjects;
    }

    /**
     * Type for the intermediate SUM over an un-cast operand. Small integer types are widened to
     * BIGINT (keeping nullability) to reduce overflow risk; every other type is used unchanged.
     */
    private static RelDataType sumTypeOf(SqlTypeFactoryImpl typeFactory, RelDataType operandType) {
        if (!(operandType instanceof BasicSqlType)) {
            return operandType;
        }
        switch (operandType.getSqlTypeName()) {
            case TINYINT:
            case SMALLINT:
            case INTEGER:
                return typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT),
                        operandType.isNullable());
            default:
                return operandType;
        }
    }

    /**
     * Aggregate calls for the new aggregate. Rewritten SUMs point at their appended operand
     * column and carry the type recorded by {@link #buildBottomProject}. All other calls are
     * reused unchanged, which is valid because the original columns keep their positions.
     */
    private List<AggregateCall> buildBottomAggregate(RelBuilder relBuilder,
            List<AggExpressionUtil.AggExpression> aggExpressions, int[] strippedArgs, int groupCount) {
        RelNode input = relBuilder.peek();
        List<AggregateCall> aggCalls = new ArrayList<>(aggExpressions.size());
        for (int i = 0; i < aggExpressions.size(); i++) {
            AggExpressionUtil.AggExpression aggExpression = aggExpressions.get(i);
            AggregateCall aggCall = aggExpression.getAggCall();
            if (strippedArgs[i] == NOT_REWRITTEN) {
                aggCalls.add(aggCall);
                continue;
            }
            List<Integer> args = Lists.newArrayList(strippedArgs[i]);
            // Keep the original function (e.g. SUM vs SUM0), DISTINCT, approximate and FILTER.
            aggCalls.add(AggregateCall.create(aggCall.getAggregation(), aggCall.isDistinct(),
                    aggCall.isApproximate(), args, aggCall.filterArg, groupCount, input,
                    aggExpression.getType(), aggCall.name));
        }
        return aggCalls;
    }

    /**
     * Top project over the new aggregate: group keys pass through, and each rewritten SUM is cast
     * back. The cast uses the original cast's type, or the original call's type when that type
     * has a larger precision. A SUM accepted by {@link #isSumReplacedByZero} becomes BIGINT 0.
     * Everything else passes through.
     */
    private List<RexNode> buildTopProject(RelBuilder relBuilder, Project project, Aggregate aggregate,
            List<AggExpressionUtil.AggExpression> aggExpressions, int[] strippedArgs) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        RelNode newAggregate = relBuilder.peek();
        int groupCount = aggregate.getGroupSet().cardinality();
        List<RexNode> topProjects = new ArrayList<>(groupCount + aggExpressions.size());
        for (int field = 0; field < groupCount; field++) {
            topProjects.add(rexBuilder.makeInputRef(newAggregate, field));
        }
        for (int j = 0; j < aggExpressions.size(); j++) {
            AggregateCall aggCall = aggExpressions.get(j).getAggCall();
            RexNode aggResult = rexBuilder.makeInputRef(newAggregate, groupCount + j);
            if (strippedArgs[j] != NOT_REWRITTEN) {
                RexCall cast = (RexCall) project.getProjects().get(aggCall.getArgList().get(0));
                RelDataType castType = cast.getType();
                if (castType instanceof BasicSqlType && castType.getPrecision() < aggCall.getType().getPrecision()) {
                    castType = aggCall.getType();
                }
                topProjects.add(rexBuilder.makeCast(castType, aggResult));
            } else if (isSumReplacedByZero(aggCall, project)) {
                topProjects.add(rexBuilder.makeBigintLiteral(BigDecimal.ZERO));
            } else {
                topProjects.add(aggResult);
            }
        }
        return topProjects;
    }

    /**
     * Preserved special case: a SUM whose argument is neither a cast nor accepted by
     * {@code RuleUtils.isNotNullLiteral} is projected as BIGINT 0.
     * NOTE(review): presumably isNotNullLiteral is true for everything except a NULL literal, so
     * this targets SUM(NULL); confirm.
     */
    private static boolean isSumReplacedByZero(AggregateCall aggCall, Project project) {
        if (!AggExpressionUtil.isSum(aggCall.getAggregation().kind)) {
            return false;
        }
        RexNode arg = project.getProjects().get(aggCall.getArgList().get(0));
        return !RuleUtils.containCast(arg) && !RuleUtils.isNotNullLiteral(arg);
    }
}
