package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexLiteral;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlCaseOperator;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Planner rule that matches a {@link KapAggregateRel} on top of a {@link KapProjectRel} whose expressions
 * contain {@code CAST(CASE ... END AS type)}, and fires only when at least one SUM aggregates such an
 * expression.
 *
 * <p>When the CASE values reference numeric columns that all share one type, the outer CAST is removed:
 * the CASE is re-typed to that column type, the SUM is computed in that type, and the SUM result is cast
 * back to the original CAST type by a Project placed above the new Aggregate. When the CASE values
 * reference no column, the CAST is instead pushed down onto each literal THEN/ELSE value and the
 * aggregate calls are kept unchanged.
 *
 * <p>If a CASE value is a non-numeric column, or the column types cannot be unified/cast, the rule does
 * nothing. Any exception during rewriting is logged and no transformation is registered.
 */
public class KapSumTransCastToThenRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(KapSumTransCastToThenRule.class);

    public static final KapSumTransCastToThenRule INSTANCE = new KapSumTransCastToThenRule(
            operand(KapAggregateRel.class,
                    operand(KapProjectRel.class, null, project -> existCastCase(project), any())),
            RelFactories.LOGICAL_BUILDER, "KapSumTransCastToThenRule");

    /**
     * Describes one {@code CAST(CASE ...)} expression of the matched Project.
     */
    public static class CastInfo {
        /** Position of the CAST(CASE ...) expression in the Project's expression list. */
        private final int index;
        /** The single numeric column type found among the CASE values, or {@code null} if there is none. */
        private final RelDataType columnType;
        /** Type of the original CAST expression. */
        private final RelDataType castType;
        /** {@code true} when no single numeric column type was found among the CASE values. */
        private final boolean allConstants;

        public CastInfo(int i, RelDataType relDataType, RelDataType relDataType2, boolean z) {
            this.index = i;
            this.columnType = relDataType;
            this.castType = relDataType2;
            this.allConstants = z;
        }

        public int getIndex() {
            return this.index;
        }

        public RelDataType getColumnType() {
            return this.columnType;
        }

        public RelDataType getCastType() {
            return this.castType;
        }

        public boolean isAllConstants() {
            return this.allConstants;
        }
    }

    /**
     * Classification of a CASE value operand. When several values are combined, the one with the
     * highest weight wins.
     */
    public enum InnerCastType {
        /** A value is a column reference of a non-numeric type; the rule bails out. */
        HAS_COLUMN_NOT_NUMBER(3),
        /** A value is a column reference of a numeric type. */
        HAS_COLUMN_NUMBER(2),
        /** The value is not a column reference (e.g. a literal or another expression). */
        OTHER(1);

        private final int weight;

        InnerCastType(int i) {
            this.weight = i;
        }
    }

    /**
     * Returns whether any expression of {@code project} is a {@code CAST(CASE ...)}; used as the
     * operand predicate of {@link #INSTANCE}.
     */
    public static boolean existCastCase(Project project) {
        for (RexNode expr : project.getChildExps()) {
            if (isCastCase(expr)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code rexNode} is a CAST call whose operand is a CASE expression. */
    private static boolean isCastCase(RexNode rexNode) {
        return (rexNode instanceof RexCall) && SqlKind.CAST == rexNode.getKind()
                && SqlKind.CASE == ((RexCall) rexNode).operands.get(0).getKind();
    }

    public KapSumTransCastToThenRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory,
            String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    @Override
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        return checkSumHasCaseCastInput(relOptRuleCall);
    }

    /** Returns whether some SUM of the matched Aggregate takes a CAST(CASE ...) Project column as input. */
    private boolean checkSumHasCaseCastInput(RelOptRuleCall relOptRuleCall) {
        Aggregate aggregate = (Aggregate) relOptRuleCall.rel(0);
        List<RexNode> childExps = ((Project) relOptRuleCall.rel(1)).getChildExps();
        List<Integer> castCaseIndices = new ArrayList<>();
        for (int i = 0; i < childExps.size(); i++) {
            if (isCastCase(childExps.get(i))) {
                castCaseIndices.add(i);
            }
        }
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (checkAggNeedToRewrite(aggCall, castCaseIndices)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        try {
            Aggregate aggregate = (Aggregate) relOptRuleCall.rel(0);
            Project project = (Project) relOptRuleCall.rel(1);
            // A CASE value that is a non-numeric column cannot be summed in its own type.
            if (getCastType(project) == InnerCastType.HAS_COLUMN_NOT_NUMBER) {
                return;
            }
            if (canCaseType(project)) {
                innerMatchNumericColumn(relOptRuleCall, relOptRuleCall.builder(), aggregate, project);
            }
        } catch (Exception e) {
            // Best effort: a failed rewrite must not break planning, the original plan is kept.
            logger.error("KapSumTransCastToThenRule apply failed", e);
        }
    }

    /** Returns the operands (WHEN/THEN pairs followed by ELSE) of the CASE wrapped by a CAST(CASE ...). */
    private static List<RexNode> getOperandsFromCaseWhen(RexNode castCase) {
        return ((RexCall) ((RexCall) castCase).getOperands().get(0)).getOperands();
    }

    /**
     * Builds Project(rewritten CASEs) -> Aggregate(re-typed SUMs) [-> Project(casts back)] and registers
     * it as the replacement of the matched Aggregate.
     */
    private void innerMatchNumericColumn(RelOptRuleCall relOptRuleCall, RelBuilder relBuilder, Aggregate aggregate,
            Project project) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        relBuilder.push(project.getInput());

        // 1. Rewrite each CAST(CASE ...) into a CASE typed as its column type (or the cast type if none).
        List<RexNode> childExps = project.getChildExps();
        List<RexNode> newProjects = new ArrayList<>(childExps.size());
        List<CastInfo> castInfos = new ArrayList<>();
        for (int i = 0; i < childExps.size(); i++) {
            RexNode expr = childExps.get(i);
            if (!isCastCase(expr)) {
                newProjects.add(expr);
                continue;
            }
            List<RexNode> caseOperands = getOperandsFromCaseWhen(expr);
            Set<RelDataType> columnTypes = getAllColumnType(caseOperands);
            RelDataType columnType = columnTypes.size() == 1 ? columnTypes.iterator().next() : null;
            castInfos.add(new CastInfo(i, columnType, expr.getType(), columnType == null));
            RelDataType caseType = columnType == null ? expr.getType() : columnType;
            newProjects.add(rexBuilder.makeCall(caseType, SqlCaseOperator.INSTANCE,
                    getCastedOperands(caseOperands, getCurCastType(caseOperands), expr.getType(), rexBuilder)));
        }
        relBuilder.project(newProjects);

        // 2. Re-type SUMs over rewritten columns; remember which agg calls need a cast back.
        List<AggregateCall> aggCalls = aggregate.getAggCallList();
        List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size());
        // Parallel to aggCalls: the CastInfo of a rewritten SUM, or null when the call is unchanged.
        List<CastInfo> restoreCasts = new ArrayList<>(aggCalls.size());
        boolean anyRewritten = false;
        for (AggregateCall aggCall : aggCalls) {
            CastInfo castInfo = getCastInfoForSum(aggCall, castInfos);
            if (castInfo == null || castInfo.isAllConstants()) {
                newAggCalls.add(aggCall);
                restoreCasts.add(null);
            } else {
                newAggCalls.add(createAggCall(aggCall, castInfo));
                restoreCasts.add(castInfo);
                anyRewritten = true;
            }
        }
        relBuilder.aggregate(relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()), newAggCalls);

        // 3. Cast rewritten SUM results back to the type of the original CAST.
        if (anyRewritten) {
            relBuilder.project(newProjectRexNodes(aggregate, relBuilder, restoreCasts));
        }
        relOptRuleCall.transformTo(relBuilder.build());
    }

    /** Copies {@code aggregateCall} with its return type replaced by the CASE column type. */
    private AggregateCall createAggCall(AggregateCall aggregateCall, CastInfo castInfo) {
        return AggregateCall.create(aggregateCall.getAggregation(), aggregateCall.isDistinct(),
                aggregateCall.isApproximate(), aggregateCall.getArgList(), aggregateCall.filterArg,
                castInfo.getColumnType(), aggregateCall.name);
    }

    /** Returns the CastInfo of the column a SUM aggregates, or {@code null} if not a SUM over a CAST(CASE). */
    private CastInfo getCastInfoForSum(AggregateCall aggregateCall, List<CastInfo> castInfos) {
        if (!AggExpressionUtil.isSum(aggregateCall.getAggregation().getKind())
                || aggregateCall.getArgList().isEmpty()) {
            return null;
        }
        int argIndex = aggregateCall.getArgList().get(0);
        for (CastInfo castInfo : castInfos) {
            if (castInfo.getIndex() == argIndex) {
                return castInfo;
            }
        }
        return null;
    }

    /** Returns whether {@code aggregateCall} is a SUM whose first argument is one of {@code castCaseIndices}. */
    private boolean checkAggNeedToRewrite(AggregateCall aggregateCall, List<Integer> castCaseIndices) {
        return AggExpressionUtil.isSum(aggregateCall.getAggregation().getKind())
                && !aggregateCall.getArgList().isEmpty()
                && castCaseIndices.contains(aggregateCall.getArgList().get(0));
    }

    /**
     * Builds the Project placed on top of the new Aggregate (the top of {@code relBuilder}'s stack).
     *
     * <p>The new Aggregate emits its group keys at positions {@code 0..groupCount-1} followed by the agg
     * calls, so references are by output position, never by input column index.
     *
     * @param aggregate the original Aggregate; supplies the group count and the group-key types
     * @param restoreCasts per agg call, the CastInfo whose cast type the result is cast back to, or
     *        {@code null} to pass the result through unchanged
     */
    private List<RexNode> newProjectRexNodes(Aggregate aggregate, RelBuilder relBuilder,
            List<CastInfo> restoreCasts) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        RelNode newAggregate = relBuilder.peek();
        int groupCount = aggregate.getGroupSet().cardinality();
        List<RelDataTypeField> originalFields = aggregate.getRowType().getFieldList();
        List<RexNode> projects = new ArrayList<>(groupCount + restoreCasts.size());
        for (int i = 0; i < groupCount; i++) {
            projects.add(rexBuilder.makeInputRef(originalFields.get(i).getType(), i));
        }
        for (int i = 0; i < restoreCasts.size(); i++) {
            RexNode aggRef = rexBuilder.makeInputRef(newAggregate, groupCount + i);
            CastInfo castInfo = restoreCasts.get(i);
            projects.add(castInfo == null
                    ? aggRef
                    : rexBuilder.makeCall(castInfo.getCastType(), new SqlCastFunction(), Lists.newArrayList(aggRef)));
        }
        return projects;
    }

    /** Returns the CASE operands with every THEN/ELSE value passed through {@link #transRexNode}. */
    private List<RexNode> getCastedOperands(List<RexNode> caseOperands, InnerCastType innerCastType,
            RelDataType relDataType, RexBuilder rexBuilder) {
        List<RexNode> result = new ArrayList<>(caseOperands.size());
        for (int i = 0; i < caseOperands.size(); i++) {
            RexNode operand = caseOperands.get(i);
            result.add(isCaseValuePosition(i, caseOperands.size())
                    ? transRexNode(operand, innerCastType, relDataType, rexBuilder)
                    : operand);
        }
        return result;
    }

    /** Wraps a literal value in CAST(... AS relDataType) when the CASE has no column value; else returns it. */
    private RexNode transRexNode(RexNode rexNode, InnerCastType innerCastType, RelDataType relDataType,
            RexBuilder rexBuilder) {
        if (rexNode instanceof RexLiteral && innerCastType == InnerCastType.OTHER) {
            return rexBuilder.makeCall(relDataType, new SqlCastFunction(), Lists.newArrayList(rexNode));
        }
        return rexNode;
    }

    /**
     * Returns whether every CAST(CASE ...) either has no numeric column value, or has numeric column
     * values of exactly one type that can be cast to the CAST's type.
     */
    private boolean canCaseType(Project project) {
        for (RexNode expr : project.getChildExps()) {
            if (!isCastCase(expr)) {
                continue;
            }
            Set<RelDataType> columnTypes = getAllColumnType(getOperandsFromCaseWhen(expr));
            if (columnTypes.isEmpty()) {
                continue;
            }
            if (columnTypes.size() != 1
                    || !SqlTypeUtil.canCastFrom(columnTypes.iterator().next(), expr.getType(), true)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether operand {@code position} of a CASE with {@code operandCount} operands is a value
     * (a THEN at an odd position, or the trailing ELSE) rather than a WHEN condition.
     */
    private static boolean isCaseValuePosition(int position, int operandCount) {
        return position % 2 == 1 || position == operandCount - 1;
    }

    /** Returns the THEN/ELSE value operands of a CASE operand list, in order. */
    private static List<RexNode> getCaseValues(List<RexNode> caseOperands) {
        List<RexNode> values = new ArrayList<>();
        for (int i = 0; i < caseOperands.size(); i++) {
            if (isCaseValuePosition(i, caseOperands.size())) {
                values.add(caseOperands.get(i));
            }
        }
        return values;
    }

    /** Returns the distinct types of the numeric column references among the CASE values. */
    private Set<RelDataType> getAllColumnType(List<RexNode> caseOperands) {
        Set<RelDataType> types = new HashSet<>();
        for (RexNode value : getCaseValues(caseOperands)) {
            if (isNumericColumn(value)) {
                types.add(value.getType());
            }
        }
        return types;
    }

    private boolean isNumericColumn(RexNode rexNode) {
        return rexNode instanceof RexInputRef && SqlTypeUtil.isNumeric(rexNode.getType());
    }

    /** Returns whichever of the two classifications has the higher weight (ties return {@code b}). */
    private static InnerCastType heavier(InnerCastType a, InnerCastType b) {
        return a.weight > b.weight ? a : b;
    }

    /** Returns the heaviest value classification over all CAST(CASE ...) expressions of {@code project}. */
    private InnerCastType getCastType(Project project) {
        InnerCastType result = InnerCastType.OTHER;
        for (RexNode expr : project.getChildExps()) {
            if (isCastCase(expr)) {
                result = heavier(result, getCurCastType(getOperandsFromCaseWhen(expr)));
            }
        }
        return result;
    }

    /** Returns the heaviest classification among the THEN/ELSE values of one CASE. */
    private InnerCastType getCurCastType(List<RexNode> caseOperands) {
        InnerCastType result = InnerCastType.OTHER;
        for (RexNode value : getCaseValues(caseOperands)) {
            result = heavier(result, getValueRexNodeType(value));
        }
        return result;
    }

    /** Classifies a single CASE value operand. */
    private InnerCastType getValueRexNodeType(RexNode rexNode) {
        if (!(rexNode instanceof RexInputRef)) {
            return InnerCastType.OTHER;
        }
        return SqlTypeUtil.isNumeric(rexNode.getType())
                ? InnerCastType.HAS_COLUMN_NUMBER
                : InnerCastType.HAS_COLUMN_NOT_NUMBER;
    }
}
