package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.BasicSqlType;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/kyligence/kap/query/optrule/KapAggSumCastRule.class */
/**
 * Planner rule that rewrites {@code SUM(CAST(x AS DOUBLE))} into {@code CAST(SUM(x) AS <original type>)}
 * when {@code x} has a numeric type.
 *
 * <p>The rule matches a {@link KapAggregateRel} whose input is a {@link KapProjectRel} for which
 * {@link AggExpressionUtil#hasAggInput} returns {@code false}. Only plain SUM calls qualify for the
 * rewrite: single-argument, non-DISTINCT and without a FILTER clause. For each qualifying call whose
 * argument column is a cast of a numeric expression to DOUBLE, the aggregate is rebuilt to sum the
 * un-cast operand instead. A project on top then casts the new SUM back to the original call's type,
 * so the output columns keep their original types. SUMs over INTEGER / SMALLINT / TINYINT operands
 * are declared as nullable BIGINT.
 */
public class KapAggSumCastRule extends RelOptRule {
    public static final KapAggSumCastRule INSTANCE = new KapAggSumCastRule(operand(KapAggregateRel.class, operand(KapProjectRel.class, null, kapProjectRel -> {
        return !AggExpressionUtil.hasAggInput(kapProjectRel);
    }, RelOptRule.any()), new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapAggSumCastRule");
    private static final Logger logger = LoggerFactory.getLogger(KapAggSumCastRule.class);

    public KapAggSumCastRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        return true;
    }

    /**
     * Rebuilds {@code Project -> Aggregate} so that eligible {@code SUM(CAST(numeric AS DOUBLE))} calls
     * sum the un-cast operand, then restores the original types with a cast in a top project.
     * Does nothing when no call is eligible.
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Aggregate aggregate = (Aggregate) relOptRuleCall.rel(0);
        Project project = (Project) relOptRuleCall.rel(1);
        List<AggregateCall> aggCalls = aggregate.getAggCallList();
        List<RexNode> projects = project.getChildExps();

        // Project ordinal -> positions (within aggCalls) of the eligible SUM calls reading that ordinal.
        Map<Integer, List<Integer>> sumCallsByOrdinal = new HashMap<>();
        // Ordinals read by aggregate calls this rule leaves untouched (as argument or filter);
        // such a column must keep its original expression.
        boolean[] readByOtherCalls = new boolean[projects.size()];
        for (int pos = 0; pos < aggCalls.size(); pos++) {
            AggregateCall aggCall = aggCalls.get(pos);
            if (isRewritableSum(aggCall)) {
                sumCallsByOrdinal.computeIfAbsent(aggCall.getArgList().get(0), k -> new ArrayList<>()).add(pos);
            } else {
                for (int arg : aggCall.getArgList()) {
                    readByOtherCalls[arg] = true;
                }
                if (aggCall.filterArg >= 0) {
                    readByOtherCalls[aggCall.filterArg] = true;
                }
            }
        }
        if (sumCallsByOrdinal.isEmpty()) {
            return;
        }

        SqlTypeFactoryImpl typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
        Set<Integer> groupOrdinals = aggregate.getGroupSet().asSet();
        List<RexNode> newProjects = new ArrayList<>(projects);
        // Un-cast operands that cannot replace their column in place; appended after the original columns.
        List<RexNode> appendedProjects = new ArrayList<>();
        // rewritten[pos] is the replacement for aggCalls.get(pos), or null if that call is unchanged.
        AggregateCall[] rewritten = new AggregateCall[aggCalls.size()];
        boolean changed = false;

        // Iterate ordinals in ascending order so appended columns are produced deterministically.
        for (int ordinal = 0; ordinal < projects.size(); ordinal++) {
            List<Integer> positions = sumCallsByOrdinal.get(ordinal);
            if (positions == null) {
                continue;
            }
            RexNode operand = numericOperandOfDoubleCast(projects.get(ordinal));
            if (operand == null) {
                continue;
            }
            changed = true;
            RelDataType sumType = sumTypeOf(operand, typeFactory);

            int newArg;
            if (groupOrdinals.contains(ordinal) || readByOtherCalls[ordinal]) {
                // The cast column is still needed (as a group key or by another call): keep it and
                // sum a freshly appended un-cast copy instead.
                newArg = projects.size() + appendedProjects.size();
                appendedProjects.add(operand);
            } else {
                newArg = ordinal;
                newProjects.set(ordinal, operand);
            }
            for (int pos : positions) {
                AggregateCall original = aggCalls.get(pos);
                // distinct=false is safe: isRewritableSum excluded DISTINCT and filtered calls.
                rewritten[pos] = new AggregateCall(original.getAggregation(), false,
                        Arrays.asList(newArg), sumType, original.getName());
            }
        }
        if (!changed) {
            return;
        }
        newProjects.addAll(appendedProjects);

        List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size());
        for (int pos = 0; pos < aggCalls.size(); pos++) {
            newAggCalls.add(rewritten[pos] != null ? rewritten[pos] : aggCalls.get(pos));
        }

        RelBuilder builder = relOptRuleCall.builder();
        builder.push(project.getInput());
        builder.project(newProjects);
        builder.aggregate(builder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()), newAggCalls);
        builder.project(buildTopProject(builder, aggregate, rewritten));
        relOptRuleCall.transformTo(builder.build());
    }

    /**
     * Returns true for a plain SUM that can be rewritten: single argument, non-DISTINCT and
     * without a FILTER argument. Dropping DISTINCT or FILTER would change query results.
     */
    private static boolean isRewritableSum(AggregateCall aggCall) {
        return aggCall.getAggregation().getKind() == SqlKind.SUM
                && aggCall.getArgList().size() == 1
                && !aggCall.isDistinct()
                && aggCall.filterArg < 0;
    }

    /**
     * If {@code expr} is a single-operand CAST to DOUBLE whose operand has a type in the NUMERIC
     * family, returns that operand; otherwise returns null.
     */
    private static RexNode numericOperandOfDoubleCast(RexNode expr) {
        if (!(expr instanceof RexCall) || !(((RexCall) expr).op instanceof SqlCastFunction)) {
            return null;
        }
        RexCall cast = (RexCall) expr;
        List<RexNode> operands = cast.getOperands();
        if (operands.size() != 1) {
            return null;
        }
        RexNode operand = operands.get(0);
        boolean toDouble = SqlTypeName.DOUBLE == cast.getType().getSqlTypeName();
        boolean fromNumeric = SqlTypeFamily.NUMERIC == operand.getType().getSqlTypeName().getFamily();
        return toDouble && fromNumeric ? operand : null;
    }

    /**
     * Type declared for a SUM over {@code operand}: nullable BIGINT for INTEGER / SMALLINT / TINYINT
     * operands, otherwise the operand's own type.
     */
    private static RelDataType sumTypeOf(RexNode operand, SqlTypeFactoryImpl typeFactory) {
        RelDataType operandType = operand.getType();
        SqlTypeName typeName = operandType.getSqlTypeName();
        if (typeName == SqlTypeName.INTEGER || typeName == SqlTypeName.SMALLINT || typeName == SqlTypeName.TINYINT) {
            return typeFactory.createTypeWithNullability(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), true);
        }
        return operandType;
    }

    /**
     * Builds the top projection over the new aggregate. Group columns pass through unchanged.
     * Each rewritten SUM (non-null {@code rewritten[pos]}) is cast back to the original call's type,
     * and every other aggregate column passes through.
     *
     * @param relBuilder builder whose top of stack is the new aggregate
     * @param aggregate  the original aggregate, used for group count and original call types
     * @param rewritten  replacement calls indexed like {@code aggregate.getAggCallList()}, null where unchanged
     */
    private List<RexNode> buildTopProject(RelBuilder relBuilder, Aggregate aggregate, AggregateCall[] rewritten) {
        List<AggregateCall> aggCalls = aggregate.getAggCallList();
        int groupCount = aggregate.getGroupSet().asSet().size();
        List<RexNode> topProjects = new ArrayList<>(groupCount + aggCalls.size());
        for (int i = 0; i < groupCount; i++) {
            topProjects.add(relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), i));
        }
        for (int pos = 0; pos < aggCalls.size(); pos++) {
            RexNode ref = relBuilder.getRexBuilder().makeInputRef(relBuilder.peek(), groupCount + pos);
            topProjects.add(rewritten[pos] == null
                    ? ref
                    : relBuilder.getRexBuilder().makeCast(aggCalls.get(pos).getType(), ref));
        }
        return topProjects;
    }
}
