/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapJoinRel;
import org.apache.kylin.query.util.RuleUtils;

/**
 * Planner rule that pushes a {@link KapAggregateRel} below the {@link KapJoinRel} it sits on.
 *
 * <p>Each join input is handled in one of two ways. If metadata reports its share of the grouping
 * and join columns as unique, it is merely projected. Otherwise it is aggregated on those columns.
 * The new inputs are joined again. The per-side sub-totals are then combined, via
 * {@link SqlSplittableAggFunction#topSplit}, by an aggregate (or, when possible, a plain project)
 * above the new join.
 *
 * <p>The rule only fires when all of the following hold:
 * <ul>
 *   <li>every aggregate call is splittable and has no FILTER argument;</li>
 *   <li>the join is INNER or LEFT;</li>
 *   <li>the join condition is a pure equi-join.</li>
 * </ul>
 * For LEFT joins, the multiplication produced by {@code topSplit} is rewritten so that one of its
 * factors is wrapped in {@code COALESCE(factor, 1)}.
 */
public class KapAggJoinTransposeRule extends RelOptRule {
    /** Operator name used to recognise binary multiplication calls. */
    private static final String STAR_TOKEN = "*";

    public static final KapAggJoinTransposeRule INSTANCE_JOIN_RIGHT_AGG = new KapAggJoinTransposeRule(
            operand(KapAggregateRel.class, operand(KapJoinRel.class, any())),
            RelFactories.LOGICAL_BUILDER, "KapAggJoinTransposeRule:agg-join-rightAgg");

    public KapAggJoinTransposeRule(RelOptRuleOperand operand) {
        super(operand);
    }

    public KapAggJoinTransposeRule(RelOptRuleOperand operand, String description) {
        super(operand, description);
    }

    public KapAggJoinTransposeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Accepts the match only when the aggregate has no COUNT DISTINCT and
     * {@link RuleUtils#isJoinOnlyOneAggChild} holds for the join.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        KapAggregateRel aggregate = (KapAggregateRel) call.rel(0);
        KapJoinRel joinRel = (KapJoinRel) call.rel(1);
        return !aggregate.isContainCountDistinct() && RuleUtils.isJoinOnlyOneAggChild(joinRel);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        KapAggregateRel aggregate = (KapAggregateRel) call.rel(0);
        KapJoinRel join = (KapJoinRel) call.rel(1);
        RelBuilder relBuilder = call.builder();

        // Sub-totals can only be recombined when every call is splittable and unfiltered.
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            boolean splittable = aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) != null;
            if (!splittable || aggregateCall.filterArg >= 0) {
                return;
            }
        }
        if (join.getJoinType() != JoinRelType.INNER && join.getJoinType() != JoinRelType.LEFT) {
            return;
        }

        ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        RelMetadataQuery mq = call.getMetadataQuery();
        // Grouping columns widened by columns known (via pulled-up equality predicates) to be equal to them.
        ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
        ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
        ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);

        // Only pure equi-joins are rewritten: any non-equi remainder aborts the rule.
        List<Integer> leftKeys = Lists.newArrayList();
        List<Integer> rightKeys = Lists.newArrayList();
        List<Boolean> filterNulls = Lists.newArrayList();
        RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(),
                leftKeys, rightKeys, filterNulls);
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }

        // aggPushDown leaves nothing on the builder stack when it declines to rewrite; calling
        // build() in that case would pop an empty stack.
        if (aggPushDown(aggregate, join, belowAggregateColumns, mq, relBuilder, allColumnsInAggregate)) {
            call.transformTo(relBuilder.build());
        }
    }

    /**
     * Builds the new input for each side of the join and, unless both sides were already unique on
     * their keys, the new join plus the top-level recombination.
     *
     * @return {@code true} if a replacement expression was left on {@code relBuilder}'s stack;
     *         {@code false} if both inputs were unique and nothing was built
     */
    private boolean aggPushDown(KapAggregateRel aggregate, KapJoinRel join, ImmutableBitSet belowAggregateColumns,
            RelMetadataQuery mq, RelBuilder relBuilder, boolean allColumnsInAggregate) {
        // Original join field index -> field index in the join of the new inputs.
        Map<Integer, Integer> map = new HashMap<>();
        List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; ++s) {
            Side side = new Side();
            RelNode joinInput = join.getInput(s);
            int fieldCount = joinInput.getRowType().getFieldCount();
            ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
            for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
                map.put(c.e, belowOffset + c.i);
            }
            // Maps join-level field indexes onto this input's own field indexes.
            Mappings.TargetMapping mapping = s == 0
                    ? Mappings.createIdentity(fieldCount)
                    : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
            Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            boolean unique = unique0 != null && unique0;
            if (unique) {
                ++uniqueCount;
                processUnique(side, relBuilder, joinInput, aggregate, fieldSet, mapping, belowAggregateKey);
            } else {
                processUnUnique(side, aggregate, relBuilder, joinInput, fieldSet, mapping, belowAggregateKey);
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }
        if (uniqueCount == 2) {
            // Both inputs are already unique on their keys; nothing to push down.
            return false;
        }
        updateCondition(sides, map, aggregate, join, belowOffset, relBuilder, allColumnsInAggregate);
        return true;
    }

    /**
     * Handles a join input whose key columns are unique: no aggregate is added.
     *
     * <p>The input is projected to its key columns, followed by a singleton expression for every
     * aggregate call whose (non-empty) arguments all come from this input. Calls whose singleton is
     * a plain input reference reuse that reference's index instead of adding a projection.
     */
    private void processUnique(Side side, RelBuilder relBuilder, RelNode joinInput, KapAggregateRel aggregate,
            ImmutableBitSet fieldSet, Mappings.TargetMapping mapping, ImmutableBitSet belowAggregateKey) {
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        side.aggregate = false;
        relBuilder.push(joinInput);
        List<RexNode> projects = new ArrayList<>();
        for (int i : belowAggregateKey) {
            projects.add(relBuilder.field(i));
        }
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            SqlAggFunction aggregation = aggCall.e.getAggregation();
            SqlSplittableAggFunction splitter = Preconditions
                    .checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
            List<Integer> argList = aggCall.e.getArgList();
            if (argList.isEmpty() || !fieldSet.contains(ImmutableBitSet.of(argList))) {
                continue;
            }
            RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggCall.e.transform(mapping));
            if (singleton instanceof RexInputRef) {
                side.split.put(aggCall.i, ((RexInputRef) singleton).getIndex());
            } else {
                projects.add(singleton);
                side.split.put(aggCall.i, projects.size() - 1);
            }
        }
        relBuilder.project(projects);
        side.newInput = relBuilder.build();
    }

    /**
     * Handles a join input whose key columns are not unique.
     *
     * <p>The input is aggregated on its key columns. Each aggregate call contributes either its
     * {@code split} form (when all of its arguments come from this input) or the splitter's
     * {@code other} call (otherwise). Identical calls are de-duplicated by the registry.
     */
    private void processUnUnique(Side side, KapAggregateRel aggregate, RelBuilder relBuilder, RelNode joinInput,
            ImmutableBitSet fieldSet, Mappings.TargetMapping mapping, ImmutableBitSet belowAggregateKey) {
        side.aggregate = true;
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        List<AggregateCall> belowAggCalls = new ArrayList<>();
        SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
        int oldGroupKeyCount = aggregate.getGroupCount();
        int newGroupKeyCount = belowAggregateKey.cardinality();
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            SqlAggFunction aggregation = aggCall.e.getAggregation();
            SqlSplittableAggFunction splitter = Preconditions
                    .checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
            AggregateCall call1;
            if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                AggregateCall splitCall = splitter.split(aggCall.e, mapping);
                call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount,
                        newGroupKeyCount);
            } else {
                call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
            }
            if (call1 == null) {
                continue;
            }
            // Output position = group keys first, then the registered aggregate calls.
            side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
        }
        side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls)
                .build();
    }

    /**
     * Joins the two new inputs on the remapped condition. It then combines the per-side sub-totals
     * into the original aggregate calls, adding either a final project (when the grouping covers
     * all join columns and every call has a singleton form) or a final aggregate.
     */
    private static void updateCondition(List<Side> sides, Map<Integer, Integer> map, KapAggregateRel aggregate,
            KapJoinRel join, int belowOffset, RelBuilder relBuilder, boolean allColumnsInAggregate) {
        Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        RexNode newCondition = RexUtil.apply((Mappings.TargetMapping) mapping, join.getCondition());
        relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);

        List<AggregateCall> newAggCalls = new ArrayList<>();
        int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
        List<RexNode> projects = new ArrayList<RexNode>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        // Project index -> which factor of a LEFT-join product to COALESCE (see isLeftAgg).
        Map<Integer, Boolean> leftJoinAggCallMap = new HashMap<>();

        // The registry de-duplicates, so topSplit may reuse an existing projection instead of
        // appending one. Track the index it actually registered rather than assuming size() - 1;
        // otherwise one call's flag could overwrite another call's product entry.
        SqlSplittableAggFunction.Registry<RexNode> projectRegistry = registry(projects);
        final int[] lastRegistered = {-1};
        SqlSplittableAggFunction.Registry<RexNode> trackingRegistry = e -> {
            lastRegistered[0] = projectRegistry.register(e);
            return lastRegistered[0];
        };

        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            SqlAggFunction aggregation = aggCall.e.getAggregation();
            SqlSplittableAggFunction splitter = Preconditions
                    .checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
            Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
            Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
            lastRegistered[0] = -1;
            newAggCalls.add(splitter.topSplit(rexBuilder, trackingRegistry, groupIndicatorCount,
                    relBuilder.peek().getRowType(), aggCall.e,
                    leftSubTotal == null ? -1 : leftSubTotal,
                    rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
            if (join.getJoinType() == JoinRelType.LEFT && lastRegistered[0] >= 0) {
                leftJoinAggCallMap.put(lastRegistered[0], isLeftAgg(aggCall.e, join));
            }
        }
        relBuilder.project(createNewProjects(rexBuilder, projects, leftJoinAggCallMap));

        boolean aggConvertedToProjects = false;
        if (allColumnsInAggregate) {
            // Try to replace the top aggregate with a project of group keys plus singleton forms.
            List<RexNode> projects2 = new ArrayList<>();
            for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
                projects2.add(relBuilder.field(key));
            }
            for (AggregateCall newAggCall : newAggCalls) {
                SqlSplittableAggFunction splitter = newAggCall.getAggregation()
                        .unwrap(SqlSplittableAggFunction.class);
                if (splitter != null) {
                    RelDataType rowType = relBuilder.peek().getRowType();
                    projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
                }
            }
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                relBuilder.project(projects2);
                aggConvertedToProjects = true;
            }
        }
        if (!aggConvertedToProjects) {
            relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()),
                    Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
        }
    }

    /**
     * Returns a copy of {@code projects} in which every entry present in
     * {@code leftJoinAggCallMap} is passed through {@link #rewriteRexNode}.
     *
     * <p>The element types of the identity input references in {@code projects} are collected first,
     * so that references inside the rewritten products can be re-typed to match them.
     */
    private static List<RexNode> createNewProjects(RexBuilder rexBuilder, List<RexNode> projects,
            Map<Integer, Boolean> leftJoinAggCallMap) {
        Map<Integer, RelDataType> projectsMap = Maps.newHashMap();
        for (RexNode rexNode : projects) {
            if (rexNode instanceof RexInputRef) {
                RexInputRef rexInputRef = (RexInputRef) rexNode;
                projectsMap.put(rexInputRef.getIndex(), rexInputRef.getType());
            }
        }
        List<RexNode> newProjects = new ArrayList<>(projects.size());
        for (int i = 0; i < projects.size(); ++i) {
            RexNode rexNode = projects.get(i);
            Boolean isLeft = leftJoinAggCallMap.get(i);
            newProjects.add(isLeft == null ? rexNode : rewriteRexNode(rexNode, rexBuilder, isLeft, projectsMap));
        }
        return newProjects;
    }

    /**
     * Rewrites a binary {@code *} call, or a {@code CAST} wrapping one, so that one factor is wrapped
     * in {@code COALESCE(factor, 1)}. Any other expression is returned unchanged.
     */
    private static RexNode rewriteRexNode(RexNode rexNode, RexBuilder rexBuilder, boolean isLeft,
            Map<Integer, RelDataType> projectsRelDataTypeMap) {
        if (!(rexNode instanceof RexCall)) {
            return rexNode;
        }
        RexCall rexCall = (RexCall) rexNode;
        SqlOperator sqlOperator = rexCall.getOperator();
        if (isMultiplicationRexCall(rexCall)) {
            // Use the cluster's type factory so the literal's type is created by the same factory as
            // every other type in the plan.
            RelDataType dataType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
            List<RexNode> rewriteRexNodeList = rewriteRexNodeList(rexCall, rexBuilder, dataType, isLeft,
                    projectsRelDataTypeMap);
            return rexBuilder.makeCall(rexCall.type, SqlStdOperatorTable.MULTIPLY, rewriteRexNodeList);
        }
        if (sqlOperator.getKind() == SqlKind.CAST && rexCall.getOperands().size() == 1
                && rexCall.getOperands().get(0) instanceof RexCall
                && isMultiplicationRexCall((RexCall) rexCall.getOperands().get(0))) {
            RexCall innerRexCall = (RexCall) rexCall.getOperands().get(0);
            RelDataType dataType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
            List<RexNode> rewriteRexNodeList = rewriteRexNodeList(innerRexCall, rexBuilder, dataType, isLeft,
                    projectsRelDataTypeMap);
            RexNode rewriteInnerRexCall = rexBuilder.makeCall(innerRexCall.type, SqlStdOperatorTable.MULTIPLY,
                    rewriteRexNodeList);
            return rexBuilder.makeCast(rexCall.type, rewriteInnerRexCall);
        }
        return rexNode;
    }

    /**
     * Returns the two operands of a binary multiplication, with one of them wrapped in
     * {@code COALESCE(operand, 1)}.
     *
     * <p>If {@code isLeft} is true, operand 0 is wrapped; otherwise operand 1 is wrapped. Input
     * references are first re-typed via {@link #rewriteRefInputList}.
     */
    private static List<RexNode> rewriteRexNodeList(RexCall rexCall, RexBuilder rexBuilder, RelDataType dataType,
            boolean isLeft, Map<Integer, RelDataType> projectsRelDataTypeMap) {
        List<RexNode> rexNodeList = rewriteRefInputList(rexCall.getOperands(), projectsRelDataTypeMap, rexBuilder);
        RexNode one = rexBuilder.makeLiteral(1, dataType, false);
        List<RexNode> rewriteRexNodeList = new ArrayList<>(2);
        if (isLeft) {
            rewriteRexNodeList.add(rexBuilder.makeCall(SqlStdOperatorTable.COALESCE, rexNodeList.get(0), one));
            rewriteRexNodeList.add(rexNodeList.get(1));
        } else {
            rewriteRexNodeList.add(rexNodeList.get(0));
            rewriteRexNodeList.add(rexBuilder.makeCall(SqlStdOperatorTable.COALESCE, rexNodeList.get(1), one));
        }
        return rewriteRexNodeList;
    }

    /**
     * Replaces each input reference whose type differs from the type recorded for its index in
     * {@code projectsRelDataTypeMap} with a reference carrying the recorded type. Other nodes are
     * kept as is.
     */
    private static List<RexNode> rewriteRefInputList(List<RexNode> rexNodeList,
            Map<Integer, RelDataType> projectsRelDataTypeMap, RexBuilder rexBuilder) {
        List<RexNode> rewriteRexNodeList = new ArrayList<>(rexNodeList.size());
        for (RexNode rexNode : rexNodeList) {
            if (rexNode instanceof RexInputRef) {
                RexInputRef rexInputRef = (RexInputRef) rexNode;
                RelDataType originalRelDataType = projectsRelDataTypeMap.get(rexInputRef.getIndex());
                if (originalRelDataType != null && !originalRelDataType.equals(rexInputRef.getType())) {
                    rexNode = rexBuilder.makeInputRef(originalRelDataType, rexInputRef.getIndex());
                }
            }
            rewriteRexNodeList.add(rexNode);
        }
        return rewriteRexNodeList;
    }

    /** Returns true for a call whose operator is named {@code *} and which has exactly two operands. */
    private static boolean isMultiplicationRexCall(RexCall rexCall) {
        return rexCall.getOperator().getName().equals(STAR_TOKEN) && rexCall.getOperands().size() == 2;
    }

    /**
     * Decides which factor {@link #rewriteRexNodeList} wraps in COALESCE for a LEFT join.
     *
     * <p>Despite the name, this returns {@code true} (coalesce the LEFT factor) when the call has no
     * arguments, or when its highest argument index refers to the join's RIGHT input. It returns
     * {@code false} (coalesce the RIGHT factor) when all arguments come from the left input.
     *
     * <p>NOTE(review): for an argument-less call such as COUNT(*) this coalesces the left factor.
     * Under a LEFT join it is the right-side factor that can be NULL for unmatched rows, so confirm
     * that this is the intended behavior.
     */
    private static boolean isLeftAgg(AggregateCall aggregateCall, KapJoinRel joinRel) {
        List<Integer> argList = aggregateCall.getArgList();
        if (argList.isEmpty()) {
            return true;
        }
        int maxIndex = argList.get(0);
        for (int index : argList) {
            maxIndex = Math.max(maxIndex, index);
        }
        return maxIndex >= joinRel.getLeft().getRowType().getFieldCount();
    }

    /**
     * Extends {@code aggregateColumns} with every column that an equality predicate between two
     * input references declares equal to one of them.
     */
    private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, List<RexNode> predicates) {
        Map<Integer, BitSet> equivalence = new TreeMap<>();
        for (RexNode predicate : predicates) {
            populateEquivalences(equivalence, predicate);
        }
        ImmutableBitSet keyColumns = aggregateColumns;
        for (Integer aggregateColumn : aggregateColumns) {
            BitSet bitSet = equivalence.get(aggregateColumn);
            if (bitSet != null) {
                keyColumns = keyColumns.union(bitSet);
            }
        }
        return keyColumns;
    }

    /** Records both directions of an {@code $a = $b} predicate; other predicates are ignored. */
    private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
        if (predicate.getKind() != SqlKind.EQUALS) {
            return;
        }
        List<RexNode> operands = ((RexCall) predicate).getOperands();
        if (operands.get(0) instanceof RexInputRef && operands.get(1) instanceof RexInputRef) {
            RexInputRef ref0 = (RexInputRef) operands.get(0);
            RexInputRef ref1 = (RexInputRef) operands.get(1);
            populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex());
            populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex());
        }
    }

    private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
        equivalence.computeIfAbsent(i0, k -> new BitSet()).set(i1);
    }

    /**
     * Returns a registry backed by {@code list}. Registering an element returns its existing index,
     * or appends it and returns the new index.
     */
    private static <E> SqlSplittableAggFunction.Registry<E> registry(List<E> list) {
        return e -> {
            int i = list.indexOf(e);
            if (i < 0) {
                i = list.size();
                list.add(e);
            }
            return i;
        };
    }

    /** Per-join-input state built while pushing the aggregate down. */
    private static class Side {
        /** Original aggregate-call ordinal -> field index of its sub-total in {@link #newInput}. */
        final Map<Integer, Integer> split = new HashMap<Integer, Integer>();
        /** Replacement for this join input (a project or an aggregate over it). */
        RelNode newInput;
        /** Whether {@link #newInput} is an aggregate (true) or a project (false). */
        boolean aggregate;

        private Side() {
        }
    }
}

