/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlCaseOperator;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Planner rule that removes the outer {@code CAST} from projected
 * {@code CAST(CASE WHEN ... THEN ... ELSE ... END AS T)} expressions that feed a sum aggregate.
 *
 * <p>The rule matches a {@link KapAggregateRel} directly on top of a {@link KapProjectRel} that
 * contains at least one such cast-over-case expression. It only fires when some aggregate call
 * recognised by {@code AggExpressionUtil.isSum} takes one of those expressions as its argument.
 * Every cast-over-case expression of the project is then rewritten:
 * <ul>
 *   <li>when its CASE value operands include {@link RexInputRef}s of exactly one numeric type, the
 *       CASE is re-typed to that column type and each sum over it is re-created with that type;</li>
 *   <li>when its CASE value operands include no {@link RexInputRef} at all, every literal value
 *       operand is cast to {@code T} instead, and the aggregate calls over it are kept as they
 *       are.</li>
 * </ul>
 * A projection is then added on top of the new aggregate. It casts every output field whose type
 * changed back to the type of the same field of the original aggregate, so the replacement has
 * the original row type.
 *
 * <p>The rule does nothing in three cases: a CASE value operand references a non-numeric column;
 * the value operands reference numeric columns of more than one type; or
 * {@code SqlTypeUtil.canCastFrom} rejects the column type / target type pair. Exceptions thrown
 * during the rewrite are logged and the rule gives up.
 */
public class KapSumTransCastToThenRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(KapSumTransCastToThenRule.class);
    public static final KapSumTransCastToThenRule INSTANCE = new KapSumTransCastToThenRule(KapSumTransCastToThenRule.operand(KapAggregateRel.class, (RelOptRuleOperand)KapSumTransCastToThenRule.operand(KapProjectRel.class, null, KapSumTransCastToThenRule::existCastCase, (RelOptRuleOperandChildren)KapSumTransCastToThenRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapSumTransCastToThenRule");

    /**
     * Operand predicate for the project: tells whether it computes any cast-over-case expression.
     *
     * @param logicalProject project to inspect
     * @return {@code true} if at least one child expression is a {@code CAST} whose operand is a
     *     {@code CASE}
     */
    public static boolean existCastCase(Project logicalProject) {
        for (RexNode rexNode : logicalProject.getChildExps()) {
            if (isCastCase(rexNode)) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if {@code rexNode} is a {@code CAST} call directly over a {@code CASE}. */
    private static boolean isCastCase(RexNode rexNode) {
        if (!(rexNode instanceof RexCall) || rexNode.getKind() != SqlKind.CAST) {
            return false;
        }
        RexNode castOperand = ((RexCall) rexNode).getOperands().get(0);
        return castOperand.getKind() == SqlKind.CASE;
    }

    public KapSumTransCastToThenRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        return checkSumHasCaseCastInput(call);
    }

    /**
     * Returns {@code true} when at least one sum call of the matched aggregate takes a
     * cast-over-case column of the matched project as its argument.
     */
    private boolean checkSumHasCaseCastInput(RelOptRuleCall call) {
        Aggregate logicalAggregate = (Aggregate) call.rel(0);
        Project logicalProject = (Project) call.rel(1);
        List<RexNode> projectExps = logicalProject.getChildExps();
        List<Integer> castIndexs = Lists.newArrayList();
        for (int i = 0; i < projectExps.size(); i++) {
            if (isCastCase(projectExps.get(i))) {
                castIndexs.add(i);
            }
        }
        for (AggregateCall aggCall : logicalAggregate.getAggCallList()) {
            if (checkAggNeedToRewrite(aggCall, castIndexs)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        try {
            RelBuilder relBuilder = call.builder();
            Aggregate logicalAggregate = (Aggregate) call.rel(0);
            Project logicalProject = (Project) call.rel(1);
            switch (getCastType(logicalProject)) {
                case HAS_COLUMN_NOT_NUMBER:
                    // Some CASE value references a non-numeric column: this rule does not handle it.
                    return;
                case HAS_COLUMN_NUMBER:
                case OTHER:
                    if (canCaseType(logicalProject)) {
                        innerMatchNumericColumn(call, relBuilder, logicalAggregate, logicalProject);
                    }
                    break;
                default:
                    return;
            }
        } catch (Exception e) {
            // Best effort: a failed rewrite is logged and the rule gives up.
            logger.error("KapSumTransCastToThenRule apply failed", e);
        }
    }

    /** Returns the operands of the {@code CASE} call wrapped by the given cast-over-case expression. */
    private List<RexNode> getOperandsFromCaseWhen(RexNode curProExp) {
        RexNode caseWhenRexNode = ((RexCall) curProExp).getOperands().get(0);
        return ((RexCall) caseWhenRexNode).getOperands();
    }

    /**
     * Performs the rewrite.
     *
     * <p>The new tree is: the rebuilt project, then the rebuilt aggregate, then (only when some
     * output field changed type) a projection restoring the original aggregate's row type. The
     * new tree replaces the matched aggregate through {@code call.transformTo}.
     */
    private void innerMatchNumericColumn(RelOptRuleCall call, RelBuilder relBuilder, Aggregate logicalAggregate,
            Project logicalProject) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        relBuilder.push(logicalProject.getInput());

        // 1. Rebuild the project, replacing each CAST(CASE ...) with a re-typed CASE.
        List<RexNode> projectExps = logicalProject.getChildExps();
        List<RexNode> projectRexNodes = Lists.newArrayList();
        List<CastInfo> castInfos = Lists.newArrayList();
        for (int i = 0; i < projectExps.size(); i++) {
            RexNode curProExp = projectExps.get(i);
            if (!isCastCase(curProExp)) {
                projectRexNodes.add(curProExp);
                continue;
            }
            List<RexNode> operands = getOperandsFromCaseWhen(curProExp);
            Set<RelDataType> allColumnType = getAllColumnType(operands);
            // canCaseType already rejected CASEs with more than one numeric column type, so a null
            // column type here means no value operand is a numeric column reference.
            RelDataType columnType = allColumnType.size() == 1 ? allColumnType.iterator().next() : null;
            castInfos.add(new CastInfo(i, columnType, curProExp.getType(), columnType == null));
            List<RexNode> castedOperands = getCastedOperands(operands, getCurCastType(operands),
                    curProExp.getType(), rexBuilder);
            RelDataType caseType = columnType == null ? curProExp.getType() : columnType;
            projectRexNodes.add(rexBuilder.makeCall(caseType, SqlCaseOperator.INSTANCE, castedOperands));
        }
        relBuilder.project(projectRexNodes);

        // 2. Re-create the sum calls whose argument was re-typed to a numeric column type.
        List<AggregateCall> newAggs = Lists.newArrayList();
        for (AggregateCall curAgg : logicalAggregate.getAggCallList()) {
            CastInfo curCastInfo = getCastInfoForSum(curAgg, castInfos);
            if (curCastInfo != null && !curCastInfo.isAllConstants()) {
                newAggs.add(createAggCall(curAgg, curCastInfo));
            } else {
                newAggs.add(curAgg);
            }
        }
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(logicalAggregate.getGroupSet(), logicalAggregate.getGroupSets());
        relBuilder.aggregate(groupKey, newAggs);

        // 3. Cast every output whose type changed back to the original aggregate's output type.
        List<RexNode> restoreProject = newProjectRexNodes(logicalAggregate, relBuilder);
        if (!restoreProject.isEmpty()) {
            relBuilder.project(restoreProject);
        }
        call.transformTo(relBuilder.build());
    }

    /**
     * Copies {@code curAgg}, declaring the column type recorded in {@code curCastInfo} as its
     * return type.
     *
     * <p>NOTE(review): the return type is the column type itself, not a type derived by the sum
     * type system. Confirm that this cannot overflow for narrow integer or decimal columns.
     */
    private AggregateCall createAggCall(AggregateCall curAgg, CastInfo curCastInfo) {
        return AggregateCall.create(curAgg.getAggregation(), curAgg.isDistinct(), curAgg.isApproximate(),
                curAgg.getArgList(), curAgg.filterArg, curCastInfo.getColumnType(), curAgg.name);
    }

    /**
     * Returns the {@link CastInfo} of the project column summed by {@code call}.
     *
     * @return {@code null} if {@code call} is not a sum or its argument is not a cast-over-case column
     */
    private CastInfo getCastInfoForSum(AggregateCall call, List<CastInfo> castInfos) {
        if (!AggExpressionUtil.isSum(call.getAggregation().getKind())) {
            return null;
        }
        int input = call.getArgList().get(0);
        for (CastInfo castInfo : castInfos) {
            if (castInfo.getIndex() == input) {
                return castInfo;
            }
        }
        return null;
    }

    /** Returns {@code true} if {@code call} is a sum whose first argument is one of {@code castIndexs}. */
    private boolean checkAggNeedToRewrite(AggregateCall call, List<Integer> castIndexs) {
        return AggExpressionUtil.isSum(call.getAggregation().getKind()) && castIndexs.contains(call.getArgList().get(0));
    }

    /**
     * Builds the projection that restores the original aggregate's row type. The projection sits
     * on top of the new aggregate, which is the top of {@code relBuilder}'s stack.
     *
     * <p>Output field {@code i} of the new aggregate is referenced unchanged when its type equals
     * the type of field {@code i} of the original aggregate. Otherwise it is cast to that type.
     * Group keys and aggregate calls are both handled by position, so a rewritten sum or a
     * re-typed group key always lines up with its own original field.
     *
     * @return the projection expressions, or an empty list when no field changed type and no
     *     projection is needed
     * @throws IllegalStateException if the new aggregate does not have as many fields as the
     *     original one
     */
    private List<RexNode> newProjectRexNodes(Aggregate logicalAggregate, RelBuilder relBuilder) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        RelNode newAggregate = relBuilder.peek();
        List<RelDataTypeField> originalFields = logicalAggregate.getRowType().getFieldList();
        int fieldCount = newAggregate.getRowType().getFieldCount();
        if (fieldCount != originalFields.size()) {
            throw new IllegalStateException("Rewritten aggregate has " + fieldCount
                    + " fields but the original aggregate has " + originalFields.size());
        }
        List<RexNode> projectRexNodes = Lists.newArrayListWithCapacity(fieldCount);
        boolean anyCast = false;
        for (int i = 0; i < fieldCount; i++) {
            RexNode ref = rexBuilder.makeInputRef(newAggregate, i);
            RelDataType originalType = originalFields.get(i).getType();
            if (originalType.equals(ref.getType())) {
                projectRexNodes.add(ref);
            } else {
                projectRexNodes.add(rexBuilder.makeCall(originalType, new SqlCastFunction(), Lists.newArrayList(ref)));
                anyCast = true;
            }
        }
        if (!anyCast) {
            return Lists.newArrayList();
        }
        return projectRexNodes;
    }

    /**
     * Rebuilds CASE operands. WHEN conditions are kept unchanged, and each value operand (THEN
     * values and the trailing ELSE) goes through {@link #transRexNode}.
     */
    private List<RexNode> getCastedOperands(List<RexNode> operands, InnerCastType curCastType, RelDataType castType,
            RexBuilder rexBuilder) {
        List<RexNode> castedOperands = Lists.newArrayList();
        for (int i = 0; i < operands.size() - 1; i += 2) {
            castedOperands.add(operands.get(i));
            castedOperands.add(transRexNode(operands.get(i + 1), curCastType, castType, rexBuilder));
        }
        if (operands.size() % 2 == 1) {
            castedOperands.add(transRexNode(operands.get(operands.size() - 1), curCastType, castType, rexBuilder));
        }
        return castedOperands;
    }

    /**
     * Casts a literal value operand to {@code castType} when the CASE has no column-reference
     * values ({@link InnerCastType#OTHER}). Otherwise returns the operand unchanged.
     *
     * <p>NOTE(review): non-literal, non-column values (e.g. arithmetic expressions) are not cast,
     * even though the rebuilt CASE is typed {@code castType}. Confirm that downstream code
     * tolerates the mismatch.
     */
    private RexNode transRexNode(RexNode valueRexNode, InnerCastType curCastType, RelDataType castType,
            RexBuilder rexBuilder) {
        if (valueRexNode instanceof RexLiteral && curCastType == InnerCastType.OTHER) {
            return rexBuilder.makeCall(castType, new SqlCastFunction(), Lists.newArrayList(valueRexNode));
        }
        return valueRexNode;
    }

    /**
     * Checks that every cast-over-case expression of the project can be rewritten.
     *
     * <p>An expression qualifies when its CASE values reference no numeric column, or numeric
     * columns of a single type for which {@code SqlTypeUtil.canCastFrom} accepts the pair.
     * NOTE(review): Calcite's signature is {@code canCastFrom(toType, fromType, coerce)}, so this
     * passes the column type as the target. Confirm that this direction is intended.
     */
    private boolean canCaseType(Project logicalProject) {
        for (RexNode rexNode : logicalProject.getChildExps()) {
            if (!isCastCase(rexNode)) {
                continue;
            }
            Set<RelDataType> columnsDataType = getAllColumnType(getOperandsFromCaseWhen(rexNode));
            if (columnsDataType.isEmpty()) {
                continue;
            }
            if (columnsDataType.size() != 1) {
                return false;
            }
            if (!SqlTypeUtil.canCastFrom(columnsDataType.iterator().next(), rexNode.getType(), true)) {
                return false;
            }
        }
        return true;
    }

    /** Collects the distinct types of the numeric column references among the CASE value operands. */
    private Set<RelDataType> getAllColumnType(List<RexNode> operands) {
        Set<RelDataType> columnsDataType = Sets.newHashSet();
        for (RexNode valueRexNode : getCaseValueOperands(operands)) {
            if (isNumericColumn(valueRexNode)) {
                columnsDataType.add(valueRexNode.getType());
            }
        }
        return columnsDataType;
    }

    /**
     * Returns the value operands of a CASE operand list: the operand after each WHEN condition,
     * plus the trailing ELSE operand when the list has odd length.
     */
    private static List<RexNode> getCaseValueOperands(List<RexNode> operands) {
        List<RexNode> values = Lists.newArrayList();
        for (int i = 1; i < operands.size(); i += 2) {
            values.add(operands.get(i));
        }
        if (operands.size() % 2 == 1) {
            values.add(operands.get(operands.size() - 1));
        }
        return values;
    }

    /** Returns {@code true} if {@code valueRexNode} is a column reference of a numeric type. */
    private boolean isNumericColumn(RexNode valueRexNode) {
        if (valueRexNode == null) {
            return false;
        }
        return valueRexNode instanceof RexInputRef && SqlTypeUtil.isNumeric(valueRexNode.getType());
    }

    /** Returns the heaviest {@link InnerCastType} over all cast-over-case expressions of the project. */
    private InnerCastType getCastType(Project logicalProject) {
        InnerCastType castType = InnerCastType.OTHER;
        for (RexNode rexNode : logicalProject.getChildExps()) {
            if (isCastCase(rexNode)) {
                castType = heavier(castType, getCurCastType(getOperandsFromCaseWhen(rexNode)));
            }
        }
        return castType;
    }

    /** Returns the heaviest {@link InnerCastType} over the value operands of one CASE. */
    private InnerCastType getCurCastType(List<RexNode> operands) {
        InnerCastType castType = InnerCastType.OTHER;
        for (RexNode valueRexNode : getCaseValueOperands(operands)) {
            castType = heavier(castType, getValueRexNodeType(valueRexNode));
        }
        return castType;
    }

    /** Returns whichever type has the higher weight; {@code b} wins ties. */
    private static InnerCastType heavier(InnerCastType a, InnerCastType b) {
        return a.weight > b.weight ? a : b;
    }

    /** Classifies one CASE value operand. */
    private InnerCastType getValueRexNodeType(RexNode valueRexNode) {
        if (!(valueRexNode instanceof RexInputRef)) {
            return InnerCastType.OTHER;
        }
        return SqlTypeUtil.isNumeric(valueRexNode.getType())
                ? InnerCastType.HAS_COLUMN_NUMBER
                : InnerCastType.HAS_COLUMN_NOT_NUMBER;
    }

    /** Rewrite bookkeeping for one cast-over-case column of the matched project. */
    public static class CastInfo {
        /** Position of the expression in the matched project. */
        private final int index;
        /** Single numeric column type referenced by the CASE values, or {@code null} if none. */
        private final RelDataType columnType;
        /** Target type of the removed outer {@code CAST}. */
        private final RelDataType castType;
        /** {@code true} when {@link #columnType} is {@code null}. */
        private final boolean allConstants;

        public CastInfo(int index, RelDataType columnType, RelDataType castType, boolean allConstants) {
            this.index = index;
            this.columnType = columnType;
            this.castType = castType;
            this.allConstants = allConstants;
        }

        public int getIndex() {
            return this.index;
        }

        public RelDataType getColumnType() {
            return this.columnType;
        }

        public RelDataType getCastType() {
            return this.castType;
        }

        public boolean isAllConstants() {
            return this.allConstants;
        }
    }

    /** Classification of CASE value operands; a higher weight dominates when combining. */
    public enum InnerCastType {
        HAS_COLUMN_NOT_NUMBER(3),
        HAS_COLUMN_NUMBER(2),
        OTHER(1);

        private final int weight;

        InnerCastType(int weight) {
            this.weight = weight;
        }
    }
}

