/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapFilterRel;
import org.apache.kylin.query.relnode.KapJoinRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.RuleUtils;

public class KapAggProjectTransposeRule
extends RelOptRule {
    public static final KapAggProjectTransposeRule AGG_PROJECT_FILTER_JOIN = new KapAggProjectTransposeRule(KapAggProjectTransposeRule.operand(KapAggregateRel.class, (RelOptRuleOperand)KapAggProjectTransposeRule.operand(KapProjectRel.class, (RelOptRuleOperand)KapAggProjectTransposeRule.operand(KapFilterRel.class, (RelOptRuleOperand)KapAggProjectTransposeRule.operand(KapJoinRel.class, (RelOptRuleOperandChildren)KapAggProjectTransposeRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), (RelOptRuleOperand[])new RelOptRuleOperand[0]), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapAggProjectTransposeRule:agg-project-filter-join");
    public static final KapAggProjectTransposeRule AGG_PROJECT_JOIN = new KapAggProjectTransposeRule(KapAggProjectTransposeRule.operand(KapAggregateRel.class, (RelOptRuleOperand)KapAggProjectTransposeRule.operand(KapProjectRel.class, (RelOptRuleOperand)KapAggProjectTransposeRule.operand(KapJoinRel.class, (RelOptRuleOperandChildren)KapAggProjectTransposeRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapAggProjectTransposeRule:agg-project-join");

    public KapAggProjectTransposeRule(RelOptRuleOperand operand) {
        super(operand);
    }

    public KapAggProjectTransposeRule(RelOptRuleOperand operand, String description) {
        super(operand, description);
    }

    public KapAggProjectTransposeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    public boolean matches(RelOptRuleCall call) {
        KapAggregateRel aggregate = (KapAggregateRel)call.rel(0);
        KapProjectRel project = (KapProjectRel)call.rel(1);
        KapJoinRel joinRel = call.rel(2) instanceof KapFilterRel ? (KapJoinRel)call.rel(3) : (KapJoinRel)call.rel(2);
        if (!RuleUtils.isJoinOnlyOneAggChild(joinRel)) {
            return false;
        }
        HashSet argSet = Sets.newHashSet();
        int argCount = 0;
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            List argList = aggregateCall.getArgList();
            argCount += argList.size();
            argSet.addAll(argList);
        }
        if (argSet.size() != argCount) {
            return false;
        }
        for (int i = 0; i < project.getProjects().size(); ++i) {
            RexNode rexNode = (RexNode)project.getProjects().get(i);
            if (!(rexNode instanceof RexCall) || !aggregate.getRewriteGroupKeys().contains((Object)i)) continue;
            return true;
        }
        return false;
    }

    public void onMatch(RelOptRuleCall call) {
        KapAggregateRel aggregate = (KapAggregateRel)call.rel(0);
        KapProjectRel project = (KapProjectRel)call.rel(1);
        ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        Iterator iterator = aggregate.getGroupSet().iterator();
        while (iterator.hasNext()) {
            int key = (Integer)iterator.next();
            RexNode rex = (RexNode)project.getProjects().get(key);
            if (rex instanceof RexInputRef) {
                int newKey = ((RexInputRef)rex).getIndex();
                builder.set(newKey);
                continue;
            }
            if (!(rex instanceof RexCall)) continue;
            this.getColumnsFromExpression((RexCall)rex, builder);
        }
        ImmutableBitSet newGroupSet = builder.build();
        LinkedHashSet<Integer> mappingWithOrder = new LinkedHashSet<Integer>();
        mappingWithOrder.addAll(newGroupSet.asList());
        for (RexNode rexNode : project.getProjects()) {
            if (rexNode instanceof RexInputRef) {
                int index = ((RexInputRef)rexNode).getIndex();
                if (mappingWithOrder.contains(index)) continue;
                mappingWithOrder.add(index);
                continue;
            }
            if (!(rexNode instanceof RexCall)) continue;
            this.getColumnsFromProjects((RexCall)rexNode, mappingWithOrder);
        }
        ArrayList mappingWithOrderList = Lists.newArrayList(mappingWithOrder);
        RelNode projectInput = project.getInput();
        Mappings.TargetMapping mapping = Mappings.target(mappingWithOrderList::indexOf, (int)projectInput.getRowType().getFieldCount(), (int)mappingWithOrder.size());
        ImmutableList.Builder aggCalls = org.apache.kylin.guava30.shaded.common.collect.ImmutableList.builder();
        ImmutableList.Builder topAggCalls = org.apache.kylin.guava30.shaded.common.collect.ImmutableList.builder();
        HashMap<Integer, RelDataType> countArgMap = new HashMap<Integer, RelDataType>();
        int newAggregateGroupSetSize = newGroupSet.asSet().size();
        ArrayList projects = Lists.newArrayList();
        HashSet<Integer> newTopAggregateSet = new HashSet<Integer>();
        int start = 0;
        for (Object index : aggregate.getGroupSet().asSet()) {
            projects.add(project.getProjects().get((Integer)index));
            newTopAggregateSet.add(start);
            ++start;
        }
        ArrayList newProjects = Lists.newArrayList();
        for (RexNode rexNode : RexUtil.apply((Mappings.TargetMapping)mapping, (Iterable)projects)) {
            newProjects.add(rexNode);
        }
        this.processAggCalls(aggregate, project, (ImmutableList.Builder<AggregateCall>)aggCalls, (ImmutableList.Builder<AggregateCall>)topAggCalls, countArgMap, newProjects);
        org.apache.kylin.guava30.shaded.common.collect.ImmutableList aggregateCalls = aggCalls.build();
        Aggregate newAggregate = aggregate.copy(aggregate.getTraitSet(), project.getInput(), aggregate.indicator, newGroupSet, null, (List)aggregateCalls);
        ArrayList newProjectFieldNames = new ArrayList();
        List oldProjectFieldNameList = project.getRowType().getFieldNames();
        for (int i = 0; i < oldProjectFieldNameList.size(); ++i) {
            if (!aggregate.getGroupSet().get(i)) continue;
            newProjectFieldNames.add(oldProjectFieldNameList.get(i));
        }
        for (Map.Entry entry : countArgMap.entrySet()) {
            int originalRefIndex = (Integer)entry.getKey();
            int newRefIndex = newAggregateGroupSetSize + originalRefIndex;
            RelDataType relDataType = (RelDataType)entry.getValue();
            String projectRefName = "$" + newRefIndex;
            newProjects.add(new RexInputRef(projectRefName, newRefIndex, relDataType));
            newProjectFieldNames.add(projectRefName);
        }
        RelDataType newRowType = RexUtil.createStructType((RelDataTypeFactory)newAggregate.getCluster().getTypeFactory(), (List)newProjects, newProjectFieldNames, (SqlValidatorUtil.Suggester)SqlValidatorUtil.F_SUGGESTER);
        Project newProject = project.copy(project.getTraitSet(), (RelNode)newAggregate, (List)newProjects, newRowType);
        ImmutableBitSet.Builder topAggregateGroupSetBuilder = ImmutableBitSet.builder();
        topAggregateGroupSetBuilder.addAll(newTopAggregateSet);
        Aggregate topAggregate = aggregate.copy(aggregate.getTraitSet(), (RelNode)newProject, aggregate.indicator, topAggregateGroupSetBuilder.build(), null, (List)topAggCalls.build());
        call.transformTo((RelNode)topAggregate);
    }

    private void processAggCalls(KapAggregateRel aggregate, KapProjectRel project, ImmutableList.Builder<AggregateCall> aggCalls, ImmutableList.Builder<AggregateCall> topAggCalls, Map<Integer, RelDataType> countArgMap, List<RexNode> newProjects) {
        int startIndex = 0;
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            ImmutableList.Builder newArgs = org.apache.kylin.guava30.shaded.common.collect.ImmutableList.builder();
            Iterator iterator = aggregateCall.getArgList().iterator();
            while (iterator.hasNext()) {
                int arg = (Integer)iterator.next();
                RexNode rex = (RexNode)project.getProjects().get(arg);
                if (rex instanceof RexInputRef) {
                    newArgs.add((Object)((RexInputRef)rex).getIndex());
                    continue;
                }
                return;
            }
            int newFilterArg = -1;
            if (aggregateCall.filterArg >= 0 && project.getProjects().get(aggregateCall.filterArg) instanceof RexInputRef) {
                newFilterArg = ((RexInputRef)project.getProjects().get(aggregateCall.filterArg)).getIndex();
            }
            aggCalls.add((Object)aggregateCall.copy((List)newArgs.build(), newFilterArg));
            countArgMap.put(startIndex, aggregateCall.type);
            ArrayList<Integer> topAggArgList = new ArrayList<Integer>();
            topAggArgList.add(newProjects.size() + startIndex);
            if (!aggregateCall.getAggregation().getName().equals("COUNT")) {
                topAggCalls.add((Object)AggregateCall.create((SqlAggFunction)aggregateCall.getAggregation(), (boolean)false, (boolean)false, topAggArgList, (int)-1, (RelDataType)aggregateCall.type, (String)aggregateCall.name));
            } else {
                topAggCalls.add((Object)AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.SUM0, (boolean)false, (boolean)false, topAggArgList, (int)-1, (RelDataType)aggregateCall.type, (String)aggregateCall.name));
            }
            ++startIndex;
        }
    }

    private void getColumnsFromExpression(RexCall rexCall, ImmutableBitSet.Builder builder) {
        ImmutableList rexNodes = rexCall.operands;
        for (RexNode rexNode : rexNodes) {
            if (rexNode instanceof RexInputRef) {
                builder.set(((RexInputRef)rexNode).getIndex());
                continue;
            }
            if (!(rexNode instanceof RexCall)) continue;
            this.getColumnsFromExpression((RexCall)rexNode, builder);
        }
    }

    private void getColumnsFromProjects(RexCall rexCall, Set<Integer> mapping) {
        ImmutableList rexNodes = rexCall.operands;
        for (RexNode rexNode : rexNodes) {
            if (rexNode instanceof RexInputRef) {
                mapping.add(((RexInputRef)rexNode).getIndex());
                continue;
            }
            if (!(rexNode instanceof RexCall)) continue;
            this.getColumnsFromProjects((RexCall)rexNode, mapping);
        }
    }
}

