package io.kyligence.kap.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.job.shaded.com.google.common.collect.ImmutableList;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.JoinRelType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlAggFunction;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.metadata.datatype.DataType;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.apache.kylin.query.util.RuleUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.support.PropertiesBeanDefinitionReader;

/* loaded from: input_file:io/kyligence/kap/query/optrule/AbstractAggCaseWhenFunctionRule.class */
public abstract class AbstractAggCaseWhenFunctionRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(AbstractAggCaseWhenFunctionRule.class);
    private static final String BOTTOM_AGG_PREFIX = "SUB_AGG$";

    /**
     * Passes the operand, the builder factory and the string argument straight through to the
     * {@link RelOptRule} constructor; no additional state is initialised here.
     *
     * <p>NOTE(review): the decompiler reported that the access modifier was changed from
     * {@code protected}. The class is abstract, so the constructor is only reachable from subclasses.
     */
    public AbstractAggCaseWhenFunctionRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Decides whether the rule fires by delegating to {@link #checkAggCaseExpression} with the
     * aggregate bound to operand 0 and the project bound to operand 1.
     */
    @Override
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        Aggregate matchedAggregate = (Aggregate) relOptRuleCall.rel(0);
        Project matchedProject = (Project) relOptRuleCall.rel(1);
        return checkAggCaseExpression(matchedAggregate, matchedProject);
    }

    /**
     * Rewrites the matched aggregate/project pair.
     *
     * <p>The aggregate calls are split into those accepted by {@link #isApplicableWithSumCaseRule}
     * and the rest. The accepted calls are isolated into their own aggregate and rewritten by
     * {@code transformSumExprAggregate}. When other calls remain, they are isolated as well and
     * the two results are joined back together before being registered as the replacement.
     *
     * <p>Any failure is logged and swallowed, so the original plan simply stays in place.
     */
    @Override
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        try {
            RelBuilder relBuilder = relOptRuleCall.builder();
            Aggregate originalAgg = (Aggregate) relOptRuleCall.rel(0);
            Project originalProject = (Project) relOptRuleCall.rel(1);

            // Calls this rule knows how to rewrite.
            List<AggregateCall> caseWhenCalls = new ArrayList<>();
            for (AggregateCall candidate : originalAgg.getAggCallList()) {
                if (isApplicableWithSumCaseRule(candidate, originalProject)) {
                    caseWhenCalls.add(candidate);
                }
            }
            // Everything else is evaluated unchanged on a separate branch.
            List<AggregateCall> remainingCalls = new LinkedList<>(originalAgg.getAggCallList());
            remainingCalls.removeAll(caseWhenCalls);

            Aggregate caseWhenAgg = extractPartialAggregateCalls(relBuilder, originalAgg, originalProject, caseWhenCalls);
            RelNode rewrittenCaseWhen = transformSumExprAggregate(relBuilder, caseWhenAgg, (Project) caseWhenAgg.getInput(0));

            RelNode replacement;
            if (remainingCalls.isEmpty()) {
                replacement = rewrittenCaseWhen;
            } else {
                Aggregate remainingAgg = extractPartialAggregateCalls(relBuilder, originalAgg, originalProject, remainingCalls);
                replacement = joinAggCaseWhenAndNonAggCaseWhenRel(relBuilder, rewrittenCaseWhen, remainingAgg, originalAgg);
                ContextUtil.dumpCalcitePlan("new plan", replacement, logger);
            }
            relOptRuleCall.transformTo(replacement);
        } catch (Error | Exception e) {
            // Best effort: a failed rewrite must not fail the query.
            logger.error("sql cannot apply sum case when rule ", e);
        }
    }

    /**
     * Joins the rewritten case-when branch with the branch holding the untouched aggregate calls
     * and projects the columns back into the order of the original aggregate.
     *
     * <p>Both inputs are expected to expose the group columns first, followed by their aggregate
     * columns; the number of group columns is derived from the non-case-when aggregate.
     *
     * <p>NOTE(review): the join keys are compared with plain {@code EQUALS}; confirm how rows
     * with NULL group keys are expected to behave.
     *
     * @param relBuilder         builder used to assemble the join and the final project
     * @param caseWhenRel        rewritten plan of the case-when aggregate calls (right join input)
     * @param nonCaseWhenAgg     aggregate holding the remaining calls (left join input)
     * @param originalAgg        the aggregate being replaced; defines the output column order
     * @return the project on top of the inner join
     */
    private RelNode joinAggCaseWhenAndNonAggCaseWhenRel(RelBuilder relBuilder, RelNode caseWhenRel, Aggregate nonCaseWhenAgg, Aggregate originalAgg) {
        relBuilder.push(nonCaseWhenAgg);
        relBuilder.push(caseWhenRel);

        List<RelDataTypeField> leftFields = nonCaseWhenAgg.getRowType().getFieldList();
        List<RelDataTypeField> rightFields = caseWhenRel.getRowType().getFieldList();
        List<AggregateCall> nonCaseWhenCalls = nonCaseWhenAgg.getAggCallList();
        int groupCount = leftFields.size() - nonCaseWhenCalls.size();
        RexBuilder rexBuilder = relBuilder.getRexBuilder();

        // Equi-join on every group column: left column i against right column i.
        List<RexNode> joinConditions = new LinkedList<>();
        for (int i = 0; i < groupCount; i++) {
            joinConditions.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
                    rexBuilder.makeInputRef(leftFields.get(i).getType(), i),
                    rexBuilder.makeInputRef(leftFields.get(i).getType(), i + leftFields.size())));
        }
        relBuilder.join(JoinRelType.INNER, joinConditions);

        // Group columns are taken from the left side.
        List<RexNode> outputExprs = new LinkedList<>();
        for (int i = 0; i < groupCount; i++) {
            outputExprs.add(rexBuilder.makeInputRef(leftFields.get(i).getType(), i));
        }

        // Walk the original calls in order and pick each one from the side that computed it.
        int nonCaseWhenIdx = 0;
        int caseWhenIdx = 0;
        for (AggregateCall originalCall : originalAgg.getAggCallList()) {
            // Value comparison, consistent with the removeAll-based split performed by the caller;
            // reference identity would misroute a call that was copied along the way.
            boolean fromNonCaseWhenSide = nonCaseWhenIdx < nonCaseWhenCalls.size()
                    && originalCall.equals(nonCaseWhenCalls.get(nonCaseWhenIdx));
            if (fromNonCaseWhenSide) {
                outputExprs.add(rexBuilder.makeInputRef(leftFields.get(groupCount + nonCaseWhenIdx).getType(),
                        groupCount + nonCaseWhenIdx));
                nonCaseWhenIdx++;
            } else {
                // Right-side columns start after all left-side columns in the join row type.
                outputExprs.add(rexBuilder.makeInputRef(rightFields.get(groupCount + caseWhenIdx).getType(),
                        leftFields.size() + groupCount + caseWhenIdx));
                caseWhenIdx++;
            }
        }
        relBuilder.project(outputExprs);
        return relBuilder.build();
    }

    /**
     * Recursively copies a plan subtree. A {@link HepRelVertex} is unwrapped to its current rel
     * first; every other node is copied with its own trait set and recursively cloned inputs.
     */
    private RelNode cloneRelNode(RelNode relNode) {
        if (relNode instanceof HepRelVertex) {
            return cloneRelNode(((HepRelVertex) relNode).getCurrentRel());
        }
        List<RelNode> clonedInputs = new ArrayList<>();
        for (RelNode input : relNode.getInputs()) {
            clonedInputs.add(cloneRelNode(input));
        }
        return relNode.copy(relNode.getTraitSet(), clonedInputs);
    }

    /**
     * Builds a fresh project + aggregate pair over a clone of the project's input that evaluates
     * only the given aggregate calls, using the group set(s) of the original aggregate.
     *
     * <p>The new project keeps the width of the original one so that the argument indexes of the
     * calls stay valid. Expressions that are not referenced by a call argument, a call filter
     * argument or the group set are replaced by a zero literal of the same type.
     *
     * @param relBuilder builder used to assemble the new sub-plan
     * @param aggregate  original aggregate supplying the group set and group sets
     * @param project    original project supplying the expressions
     * @param aggCalls   the subset of aggregate calls to keep
     * @return the newly built aggregate (its input 0 is the new project)
     */
    private Aggregate extractPartialAggregateCalls(RelBuilder relBuilder, Aggregate aggregate, Project project, List<AggregateCall> aggCalls) {
        Set<Integer> usedInputs = aggCalls.stream()
                .flatMap(call -> call.getArgList().stream())
                .collect(Collectors.toSet());
        // A FILTER column is an input of the call too; blanking it out would corrupt the filter.
        for (AggregateCall call : aggCalls) {
            if (call.filterArg >= 0) {
                usedInputs.add(call.filterArg);
            }
        }
        usedInputs.addAll(aggregate.getGroupSet().asList());

        relBuilder.push(cloneRelNode(project.getInput()));

        List<RexNode> originalExprs = project.getChildExps();
        List<RexNode> trimmedExprs = new ArrayList<>(originalExprs.size());
        for (int i = 0; i < originalExprs.size(); i++) {
            RexNode expr = originalExprs.get(i);
            if (usedInputs.contains(i)) {
                trimmedExprs.add(expr);
            } else {
                // Placeholder that keeps the column position without evaluating the expression.
                trimmedExprs.add(relBuilder.getRexBuilder().makeZeroLiteral(expr.getType()));
            }
        }
        relBuilder.project(trimmedExprs);
        relBuilder.aggregate(relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()), aggCalls);
        return (Aggregate) relBuilder.build();
    }

    /**
     * Replaces the given aggregate/project pair with a four-level plan built on the project's
     * input: bottom project, bottom aggregate, top project, top aggregate.
     *
     * <p>The index arrays stored in the group/aggregate expression objects are filled in step by
     * step by the helper methods, so the order of the statements below matters.
     *
     * @return the top aggregate produced by the builder
     */
    private RelNode transformSumExprAggregate(RelBuilder relBuilder, Aggregate oldAgg, Project oldProject) {
        relBuilder.push(oldProject.getInput());

        List<AggExpressionUtil.AggExpression> aggExpressions = AggExpressionUtil.collectSumExpressions(oldAgg, oldProject);
        List<AggExpressionUtil.AggExpression> caseWhenExpressions = new ArrayList<>();
        for (AggExpressionUtil.AggExpression candidate : aggExpressions) {
            if (isApplicableAggExpression(candidate)) {
                caseWhenExpressions.add(candidate);
            }
        }
        Pair<List<AggExpressionUtil.GroupExpression>, ImmutableList<ImmutableBitSet>> groupInfo = AggExpressionUtil.collectGroupExprAndGroup(oldAgg, oldProject);
        List<AggExpressionUtil.GroupExpression> groupExpressions = groupInfo.getFirst();
        ImmutableList<ImmutableBitSet> topGroupSets = groupInfo.getSecond();

        // Level 1: bottom project (also records the bottom-aggregate input positions).
        relBuilder.project(buildBottomProject(relBuilder, oldProject, groupExpressions, aggExpressions));

        // Level 2: bottom aggregate grouped by the group columns plus the case-when condition columns.
        ImmutableBitSet.Builder bottomGroupBuilder = ImmutableBitSet.builder();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExpressions) {
            for (int column : groupExpr.getBottomAggInput()) {
                bottomGroupBuilder.set(column);
            }
        }
        for (AggExpressionUtil.AggExpression caseWhenExpr : caseWhenExpressions) {
            for (int column : caseWhenExpr.getBottomAggConditionsInput()) {
                bottomGroupBuilder.set(column);
            }
        }
        ImmutableBitSet bottomGroupSet = bottomGroupBuilder.build();
        relBuilder.aggregate(relBuilder.groupKey(bottomGroupSet, null), buildBottomAggregate(relBuilder, aggExpressions, bottomGroupSet.cardinality()));

        // Translate bottom-aggregate input positions into positions within the bottom aggregate's output.
        for (AggExpressionUtil.GroupExpression groupExpr : groupExpressions) {
            int[] topInputs = groupExpr.getTopProjInput();
            int[] bottomInputs = groupExpr.getBottomAggInput();
            for (int k = 0; k < topInputs.length; k++) {
                topInputs[k] = bottomGroupSet.indexOf(bottomInputs[k]);
            }
        }
        for (AggExpressionUtil.AggExpression caseWhenExpr : caseWhenExpressions) {
            int[] topConditionInputs = caseWhenExpr.getTopProjConditionsInput();
            int[] bottomConditionInputs = caseWhenExpr.getBottomAggConditionsInput();
            for (int k = 0; k < topConditionInputs.length; k++) {
                topConditionInputs[k] = bottomGroupSet.indexOf(bottomConditionInputs[k]);
            }
        }

        // Level 3: top project.
        relBuilder.project(buildTopProject(relBuilder, oldProject, aggExpressions, groupExpressions));

        // Level 4: top aggregate grouped by the leading group-expression columns.
        ImmutableBitSet.Builder topGroupBuilder = ImmutableBitSet.builder();
        for (int column = 0; column < groupExpressions.size(); column++) {
            topGroupBuilder.set(column);
        }
        ImmutableBitSet topGroupSet = topGroupBuilder.build();
        relBuilder.aggregate(relBuilder.groupKey(topGroupSet, topGroupSets), buildTopAggregate(oldAgg.getAggCallList(), topGroupSet.cardinality(), aggExpressions));

        RelNode rewritten = relBuilder.build();
        ContextUtil.dumpCalcitePlan("new plan", rewritten, logger);
        return rewritten;
    }

    /**
     * Assembles the expression list of the bottom project and records, inside each expression
     * object, the position its columns get in that list.
     *
     * <p>Group expressions contribute one input reference per referenced source column.
     * Applicable aggregate expressions are expanded by {@code buildBottomAggExpression};
     * other aggregate expressions contribute their expression as-is when it is non-null.
     */
    private List<RexNode> buildBottomProject(RelBuilder relBuilder, Project oldProject, List<AggExpressionUtil.GroupExpression> groupExpressions, List<AggExpressionUtil.AggExpression> aggExpressions) {
        List<RexNode> bottomExprs = Lists.newArrayList();
        RexBuilder rexBuilder = relBuilder.getRexBuilder();

        for (AggExpressionUtil.GroupExpression groupExpr : groupExpressions) {
            int[] sourceColumns = groupExpr.getBottomProjInput();
            for (int k = 0; k < sourceColumns.length; k++) {
                // Remember where this column lands before appending it.
                groupExpr.getBottomAggInput()[k] = bottomExprs.size();
                bottomExprs.add(rexBuilder.makeInputRef(oldProject.getInput(), sourceColumns[k]));
            }
        }

        for (AggExpressionUtil.AggExpression aggExpr : aggExpressions) {
            if (isApplicableAggExpression(aggExpr)) {
                buildBottomAggExpression(rexBuilder, oldProject, bottomExprs, aggExpr);
                continue;
            }
            if (aggExpr.getExpression() != null) {
                aggExpr.getBottomAggInput()[0] = bottomExprs.size();
                bottomExprs.add(aggExpr.getExpression());
            }
        }
        return bottomExprs;
    }

    /**
     * Appends the bottom-project columns of one applicable aggregate expression to {@code list}:
     * first an input reference per condition column, then one entry per value of the expression.
     * The position of every appended column is written back into the expression's index arrays.
     *
     * <p>Value handling: a cast is stripped down to its operand unless the aggregate is a sum and
     * the operand type is neither in the number nor in the integer family, in which case the cast
     * is kept; a non-null literal is kept as-is; anything else becomes a bigint zero literal.
     */
    private void buildBottomAggExpression(RexBuilder rexBuilder, Project project, List<RexNode> list, AggExpressionUtil.AggExpression aggExpression) {
        int[] conditionColumns = aggExpression.getBottomProjConditionsInput();
        for (int c = 0; c < conditionColumns.length; c++) {
            aggExpression.getBottomAggConditionsInput()[c] = list.size();
            list.add(rexBuilder.makeInputRef(project.getInput(), conditionColumns[c]));
        }

        List<RexNode> values = aggExpression.getValuesList();
        for (int v = 0; v < values.size(); v++) {
            RexNode value = values.get(v);
            aggExpression.getBottomAggValuesInput()[v] = list.size();

            if (RuleUtils.isCast(value)) {
                RexNode castOperand = ((RexCall) value).operands.get(0);
                DataType operandType = DataType.getType(castOperand.getType().getSqlTypeName().getName());
                boolean keepCast = AggExpressionUtil.isSum(aggExpression.getAggCall().getAggregation().kind)
                        && !operandType.isNumberFamily()
                        && !operandType.isIntegerFamily();
                list.add(keepCast ? value : castOperand);
                continue;
            }
            if (RuleUtils.isNotNullLiteral(value)) {
                list.add(value);
                continue;
            }
            list.add(rexBuilder.makeBigintLiteral(BigDecimal.ZERO));
        }
    }

    /**
     * Creates the aggregate calls of the bottom aggregate.
     *
     * <p>Non-applicable expressions come first: their original call is copied onto the recorded
     * bottom-aggregate input columns (the filter argument is carried over unchanged). Applicable
     * expressions follow: every value accepted by {@link #isValidAggColumnExpr} gets its own call
     * using {@link #getBottomAggFunc}, named {@code <prefix><expressionIndex>$<valueIndex>}.
     * The output position of each created call is written back into the expression.
     *
     * @param relBuilder     builder whose current top node is the bottom project
     * @param aggExpressions all aggregate expressions, in original order
     * @param groupCount     number of group columns preceding the aggregate columns
     * @return the bottom aggregate calls, in output order
     */
    private List<AggregateCall> buildBottomAggregate(RelBuilder relBuilder, List<AggExpressionUtil.AggExpression> aggExpressions, int groupCount) {
        List<AggregateCall> bottomCalls = Lists.newArrayList();
        List<AggExpressionUtil.AggExpression> caseWhenExpressions = Lists.newArrayList();

        for (AggExpressionUtil.AggExpression aggExpr : aggExpressions) {
            if (isApplicableAggExpression(aggExpr)) {
                caseWhenExpressions.add(aggExpr);
            } else {
                aggExpr.getTopProjInput()[0] = groupCount + bottomCalls.size();
                AggregateCall originalCall = aggExpr.getAggCall();
                List<Integer> bottomArgs = Arrays.stream(aggExpr.getBottomAggInput()).boxed().collect(Collectors.toList());
                bottomCalls.add(originalCall.copy(bottomArgs, originalCall.filterArg));
            }
        }

        int expressionIdx = 0;
        for (AggExpressionUtil.AggExpression caseWhenExpr : caseWhenExpressions) {
            for (int valueIdx = 0; valueIdx < caseWhenExpr.getValuesList().size(); valueIdx++) {
                if (isValidAggColumnExpr(caseWhenExpr.getValuesList().get(valueIdx))) {
                    // Plain "$" separator; the name must not depend on an unrelated framework constant.
                    String callName = getBottomAggPrefix() + expressionIdx + "$" + valueIdx;
                    List<Integer> callArgs = Lists.newArrayList(Integer.valueOf(caseWhenExpr.getBottomAggValuesInput()[valueIdx]));
                    caseWhenExpr.getTopProjValuesInput()[valueIdx] = groupCount + bottomCalls.size();
                    bottomCalls.add(AggregateCall.create(getBottomAggFunc(caseWhenExpr.getAggCall()), false, false, callArgs, -1, groupCount, relBuilder.peek(), null, callName));
                }
            }
            expressionIdx++;
        }
        return bottomCalls;
    }

    /**
     * Assembles the expression list of the top project, which sits on the bottom aggregate.
     *
     * <p>Group expressions are re-targeted from the old project's input to the bottom aggregate's
     * output and coerced to their original type. Applicable aggregate expressions are rebuilt as
     * a {@code CASE} call from their re-targeted conditions and rewritten values (including the
     * trailing else value); all other aggregate expressions become a plain input reference to
     * their bottom-aggregate column.
     */
    private List<RexNode> buildTopProject(RelBuilder relBuilder, Project oldProject, List<AggExpressionUtil.AggExpression> aggExpressions, List<AggExpressionUtil.GroupExpression> groupExpressions) {
        List<RexNode> topExprs = Lists.newArrayList();

        for (AggExpressionUtil.GroupExpression groupExpr : groupExpressions) {
            int[] groupAdjustments = AggExpressionUtil.generateAdjustments(groupExpr.getBottomProjInput(), groupExpr.getTopProjInput());
            RexNode retargeted = (RexNode) groupExpr.getExpression().accept(new RelOptUtil.RexInputConverter(
                    relBuilder.getRexBuilder(),
                    oldProject.getInput().getRowType().getFieldList(),
                    relBuilder.peek().getRowType().getFieldList(),
                    groupAdjustments));
            topExprs.add(relBuilder.getRexBuilder().ensureType(groupExpr.getExpression().getType(), retargeted, false));
        }

        for (AggExpressionUtil.AggExpression aggExpr : aggExpressions) {
            if (!isApplicableAggExpression(aggExpr)) {
                topExprs.add(relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), aggExpr.getTopProjInput()[0]));
                continue;
            }

            int[] conditionAdjustments = AggExpressionUtil.generateAdjustments(aggExpr.getBottomProjConditionsInput(), aggExpr.getTopProjConditionsInput());
            List<RexNode> conditions = aggExpr.getConditions();
            List<RexNode> values = aggExpr.getValuesList();
            List<RexNode> caseOperands = Lists.newArrayList();

            // WHEN condition THEN value pairs.
            for (int k = 0; k < conditions.size(); k++) {
                caseOperands.add((RexNode) conditions.get(k).accept(new RelOptUtil.RexInputConverter(
                        relBuilder.getRexBuilder(),
                        oldProject.getInput().getRowType().getFieldList(),
                        relBuilder.peek().getRowType().getFieldList(),
                        conditionAdjustments)));
                caseOperands.add(rewriteCaseValueForTopProject(relBuilder, aggExpr, values.get(k), k));
            }
            // The value following the last condition is the ELSE branch.
            int elseIdx = conditions.size();
            caseOperands.add(rewriteCaseValueForTopProject(relBuilder, aggExpr, values.get(elseIdx), elseIdx));

            topExprs.add(relBuilder.call(SqlStdOperatorTable.CASE, (Iterable<? extends RexNode>) caseOperands));
        }
        return topExprs;
    }

    /**
     * Rewrites one case-when value for the top project: a value accepted by
     * {@link #isNeedTackCast} becomes a cast (to the original call type) of the matching
     * bottom-aggregate column, a non-null literal becomes a plain reference to that column,
     * and any other value is returned untouched.
     */
    private RexNode rewriteCaseValueForTopProject(RelBuilder relBuilder, AggExpressionUtil.AggExpression aggExpr, RexNode value, int valueIdx) {
        if (isNeedTackCast(value)) {
            return relBuilder.getRexBuilder().makeCast(((RexCall) value).type,
                    relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), aggExpr.getTopProjValuesInput()[valueIdx]));
        }
        if (RuleUtils.isNotNullLiteral(value)) {
            return relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), aggExpr.getTopProjValuesInput()[valueIdx]);
        }
        return value;
    }

    /**
     * Creates one top aggregate call per original call. Call {@code n} uses
     * {@link #getTopAggFunc}, reads the single column {@code groupCount + n}, keeps the type of
     * the matching expression's original call and is named {@code "AGG$" + n}.
     */
    private List<AggregateCall> buildTopAggregate(List<AggregateCall> oldCalls, int groupCount, List<AggExpressionUtil.AggExpression> aggExpressions) {
        List<AggregateCall> topCalls = Lists.newArrayList();
        int position = 0;
        while (position < oldCalls.size()) {
            AggregateCall sourceCall = aggExpressions.get(position).getAggCall();
            SqlAggFunction topFunction = getTopAggFunc(sourceCall);
            List<Integer> inputColumn = Lists.newArrayList(Integer.valueOf(groupCount + position));
            topCalls.add(AggregateCall.create(topFunction, false, false, inputColumn, -1, sourceCall.getType(), "AGG$" + position));
            position++;
        }
        return topCalls;
    }

    /**
     * Match predicate of the rule: {@link #matches} returns exactly what this method returns for
     * the matched aggregate (operand 0) and project (operand 1).
     */
    protected abstract boolean checkAggCaseExpression(Aggregate aggregate, Project project);

    /**
     * Tells whether a single aggregate call should go through the rewrite. {@link #onMatch} uses
     * it to split the aggregate's calls into the rewritten group and the untouched group.
     */
    protected abstract boolean isApplicableWithSumCaseRule(AggregateCall aggregateCall, Project project);

    /**
     * Tells whether a collected aggregate expression is expanded into per-value bottom aggregates
     * and a rebuilt CASE expression ({@code true}), or passed through as a single column
     * ({@code false}).
     */
    protected abstract boolean isApplicableAggExpression(AggExpressionUtil.AggExpression aggExpression);

    /** Supplies the aggregate function used for the per-value calls of the bottom aggregate. */
    protected abstract SqlAggFunction getBottomAggFunc(AggregateCall aggregateCall);

    /** Supplies the aggregate function used for the calls of the top aggregate. */
    protected abstract SqlAggFunction getTopAggFunc(AggregateCall aggregateCall);

    /**
     * Hook consulted for every case-when value before a bottom aggregate call is created for it.
     * The default accepts every value.
     */
    protected boolean isValidAggColumnExpr(RexNode rexNode) {
        return true;
    }

    /**
     * Hook deciding whether a case-when value is rebuilt in the top project as a cast of its
     * bottom-aggregate column. The default answers with {@code RuleUtils.isCast(rexNode)}.
     */
    protected boolean isNeedTackCast(RexNode rexNode) {
        return RuleUtils.isCast(rexNode);
    }

    /**
     * Prefix used when naming the per-value bottom aggregate calls. The default is
     * {@code "SUB_AGG$"}.
     */
    protected String getBottomAggPrefix() {
        return BOTTOM_AGG_PREFIX;
    }
}
