package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;
import org.apache.kylin.job.shaded.com.google.common.base.Function;
import org.apache.kylin.job.shaded.com.google.common.collect.ImmutableList;
import org.apache.kylin.job.shaded.com.google.common.collect.UnmodifiableIterator;
import org.apache.kylin.job.shaded.org.apache.calcite.linq4j.Ord;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.JoinRelType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlOperator;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.job.shaded.org.apache.calcite.util.mapping.Mapping;
import org.apache.kylin.job.shaded.org.apache.calcite.util.mapping.Mappings;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapJoinRel;
import org.apache.kylin.query.util.RuleUtils;

/* loaded from: input_file:io/kyligence/kap/query/optrule/KapAggJoinTransposeRule.class */
public class KapAggJoinTransposeRule extends RelOptRule {
    /** The literal text "*"; isMultiplicationRexCall compares operator names against this same text. */
    private static final String STAR_TOKEN = "*";
    /**
     * Shared rule instance whose operand tree is a {@link KapAggregateRel} with a
     * {@link KapJoinRel} input (any join children), built with
     * {@code RelFactories.LOGICAL_BUILDER} and described as
     * "KapAggJoinTransposeRule:agg-join-rightAgg".
     */
    public static final KapAggJoinTransposeRule INSTANCE_JOIN_RIGHT_AGG = new KapAggJoinTransposeRule(operand(KapAggregateRel.class, operand(KapJoinRel.class, any()), new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "KapAggJoinTransposeRule:agg-join-rightAgg");

    /**
     * Mutable per-join-input bookkeeping used while the aggregate is pushed below the join.
     *
     * <p>Decompiler note: the access modifier of this class was changed from private.
     */
    public static class Side {
        /** Maps the ordinal of an aggregate call in the original aggregate to a field index in {@link #newInput}. */
        final Map<Integer, Integer> split = new HashMap<>();
        /** Replacement for this join input (a projection or an aggregate over the original input). */
        RelNode newInput;
        /** {@code true} when {@link #newInput} was built as an aggregate, {@code false} when it is a projection. */
        boolean aggregate;

        private Side() {
        }
    }

    /**
     * Creates the rule for the given operand tree, delegating to {@code RelOptRule(RelOptRuleOperand)}.
     *
     * @param relOptRuleOperand root operand describing the plan shape this rule matches
     */
    public KapAggJoinTransposeRule(RelOptRuleOperand relOptRuleOperand) {
        super(relOptRuleOperand);
    }

    /**
     * Creates the rule for the given operand tree with an explicit description.
     *
     * @param relOptRuleOperand root operand describing the plan shape this rule matches
     * @param str               rule description passed through to the super constructor
     */
    public KapAggJoinTransposeRule(RelOptRuleOperand relOptRuleOperand, String str) {
        super(relOptRuleOperand, str);
    }

    /**
     * Creates the rule for the given operand tree, builder factory and description.
     *
     * @param relOptRuleOperand root operand describing the plan shape this rule matches
     * @param relBuilderFactory factory passed through to the super constructor
     * @param str               rule description passed through to the super constructor
     */
    public KapAggJoinTransposeRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Accepts the match only when the aggregate reports no count-distinct and
     * {@code RuleUtils.isJoinOnlyOneAggChild} holds for the join below it.
     *
     * @param relOptRuleCall the rule call; rel(0) is the aggregate, rel(1) the join
     * @return {@code true} if the rule should fire
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        KapAggregateRel aggregate = (KapAggregateRel) relOptRuleCall.rel(0);
        if (aggregate.isContainCountDistinct()) {
            return false;
        }
        KapJoinRel join = (KapJoinRel) relOptRuleCall.rel(1);
        return RuleUtils.isJoinOnlyOneAggChild(join);
    }

    /**
     * Attempts to push the matched aggregate below the matched join.
     *
     * <p>The rule bails out without transforming when any aggregate call is not a
     * {@link SqlSplittableAggFunction} or carries a filter argument, when the join is neither
     * INNER nor LEFT, when the join condition has a non-equi remainder, or when
     * {@link #aggPushDown} reports that it produced no plan.
     *
     * @param relOptRuleCall the rule call; rel(0) is the aggregate, rel(1) the join
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        KapAggregateRel kapAggregateRel = (KapAggregateRel) relOptRuleCall.rel(0);
        KapJoinRel kapJoinRel = (KapJoinRel) relOptRuleCall.rel(1);
        RelBuilder builder = relOptRuleCall.builder();
        for (AggregateCall aggregateCall : kapAggregateRel.getAggCallList()) {
            if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null || aggregateCall.filterArg >= 0) {
                return;
            }
        }
        if (kapJoinRel.getJoinType() != JoinRelType.INNER && kapJoinRel.getJoinType() != JoinRelType.LEFT) {
            return;
        }
        ImmutableBitSet groupSet = kapAggregateRel.getGroupSet();
        RelMetadataQuery metadataQuery = relOptRuleCall.getMetadataQuery();
        // Group columns extended with columns that pulled-up predicates declare equal to them.
        ImmutableBitSet keyColumns = keyColumns(groupSet, metadataQuery.getPulledUpPredicates(kapJoinRel).pulledUpPredicates);
        ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(kapJoinRel.getCondition());
        boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
        ImmutableBitSet belowAggregateColumns = groupSet.union(joinColumns);
        // Only pure equi-joins are handled: the non-equi remainder must be trivially true.
        if (!RelOptUtil.splitJoinCondition(kapJoinRel.getLeft(), kapJoinRel.getRight(), kapJoinRel.getCondition(), Lists.newArrayList(), Lists.newArrayList(), Lists.newArrayList()).isAlwaysTrue()) {
            return;
        }
        // Build only when a plan was actually pushed onto the builder; otherwise the builder
        // stack is empty and build() would fail.
        if (aggPushDown(kapAggregateRel, kapJoinRel, belowAggregateColumns, metadataQuery, builder, allColumnsInAggregate)) {
            relOptRuleCall.transformTo(builder.build());
        }
    }

    /**
     * Builds a replacement input for each side of the join and, unless both sides are unique on
     * their pushed-down columns, leaves the rewritten plan on top of {@code relBuilder}.
     *
     * @param kapAggregateRel  the aggregate being pushed down
     * @param kapJoinRel       the join below the aggregate
     * @param immutableBitSet  join-output columns needed below the aggregate (group set plus join-condition columns)
     * @param relMetadataQuery metadata query used for the column-uniqueness check
     * @param relBuilder       builder that receives the rewritten plan
     * @param z                flag computed by the caller and forwarded to {@link #updateCondition}
     * @return {@code true} if a rewritten plan was left on {@code relBuilder}; {@code false} if both
     *         inputs were unique and nothing was produced
     */
    private boolean aggPushDown(KapAggregateRel kapAggregateRel, KapJoinRel kapJoinRel, ImmutableBitSet immutableBitSet, RelMetadataQuery relMetadataQuery, RelBuilder relBuilder, boolean z) {
        Map<Integer, Integer> fieldMap = new HashMap<>();
        List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; s++) {
            Side side = new Side();
            RelNode input = kapJoinRel.getInput(s);
            int fieldCount = input.getRowType().getFieldCount();
            ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            ImmutableBitSet belowAggregateKeyNotShifted = immutableBitSet.intersect(fieldSet);
            // Old join-output ordinal -> ordinal in the join over the replacement inputs.
            for (Ord<Integer> ord : Ord.zip(belowAggregateKeyNotShifted)) {
                fieldMap.put(ord.e, Integer.valueOf(belowOffset + ord.i));
            }
            Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
            Boolean unique = relMetadataQuery.areColumnsUnique(input, belowAggregateKey);
            if (unique != null && unique.booleanValue()) {
                uniqueCount++;
                processUnique(side, relBuilder, input, kapAggregateRel, fieldSet, mapping, belowAggregateKey);
            } else {
                processUnUnique(side, kapAggregateRel, relBuilder, input, fieldSet, mapping, belowAggregateKey);
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }
        if (uniqueCount == 2) {
            // Both inputs are already unique on the pushed-down columns: no plan is produced.
            return false;
        }
        updateCondition(sides, fieldMap, kapAggregateRel, kapJoinRel, belowOffset, relBuilder, z);
        return true;
    }

    /**
     * Builds the replacement for a join input that is unique on the pushed-down columns: a
     * projection of those columns plus, for each aggregate call whose (non-empty) arguments all
     * come from this input, the call's singleton expression.
     *
     * @param side              bookkeeping to fill in ({@code aggregate}, {@code split}, {@code newInput})
     * @param relBuilder        builder used to assemble the projection
     * @param relNode           the original join input
     * @param kapAggregateRel   the aggregate being pushed down
     * @param immutableBitSet   ordinals (in join-output space) of the fields belonging to this input
     * @param targetMapping     mapping applied to aggregate calls before taking their singleton
     * @param immutableBitSet2  pushed-down columns, expressed in this input's own ordinals
     */
    private void processUnique(Side side, RelBuilder relBuilder, RelNode relNode, KapAggregateRel kapAggregateRel, ImmutableBitSet immutableBitSet, Mappings.TargetMapping targetMapping, ImmutableBitSet immutableBitSet2) {
        final RexBuilder rexBuilder = kapAggregateRel.getCluster().getRexBuilder();
        side.aggregate = false;
        relBuilder.push(relNode);
        final List<RexNode> projects = new ArrayList<>();
        for (Integer key : immutableBitSet2) {
            projects.add(relBuilder.field(key.intValue()));
        }
        for (Ord<AggregateCall> aggCall : Ord.zip(kapAggregateRel.getAggCallList())) {
            final AggregateCall call = aggCall.e;
            final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(call.getAggregation().unwrap(SqlSplittableAggFunction.class));
            // Calls without arguments, or with arguments from the other input, are not handled here.
            if (call.getArgList().isEmpty() || !immutableBitSet.contains(ImmutableBitSet.of(call.getArgList()))) {
                continue;
            }
            final RexNode singleton = splitter.singleton(rexBuilder, relNode.getRowType(), call.transform(targetMapping));
            final int position;
            if (singleton instanceof RexInputRef) {
                position = ((RexInputRef) singleton).getIndex();
            } else {
                projects.add(singleton);
                position = projects.size() - 1;
            }
            side.split.put(Integer.valueOf(aggCall.i), Integer.valueOf(position));
        }
        relBuilder.project(projects);
        side.newInput = relBuilder.build();
    }

    /**
     * Builds the replacement for a join input that is not known to be unique on the pushed-down
     * columns: an aggregate over the input, grouped by those columns, holding the split form of
     * each aggregate call (or the splitter's "other" call when the arguments are not all from
     * this input).
     *
     * @param side              bookkeeping to fill in ({@code aggregate}, {@code split}, {@code newInput})
     * @param kapAggregateRel   the aggregate being pushed down
     * @param relBuilder        builder used to assemble the new aggregate
     * @param relNode           the original join input
     * @param immutableBitSet   ordinals (in join-output space) of the fields belonging to this input
     * @param targetMapping     mapping handed to the splitter when splitting a call
     * @param immutableBitSet2  pushed-down columns, expressed in this input's own ordinals
     */
    private void processUnUnique(Side side, KapAggregateRel kapAggregateRel, RelBuilder relBuilder, RelNode relNode, ImmutableBitSet immutableBitSet, Mappings.TargetMapping targetMapping, ImmutableBitSet immutableBitSet2) {
        side.aggregate = true;
        final RexBuilder rexBuilder = kapAggregateRel.getCluster().getRexBuilder();
        final List<AggregateCall> belowCalls = new ArrayList<>();
        final SqlSplittableAggFunction.Registry<AggregateCall> belowRegistry = registry(belowCalls);
        final int oldGroupKeyCount = kapAggregateRel.getGroupCount();
        final int newGroupKeyCount = immutableBitSet2.cardinality();
        for (Ord<AggregateCall> aggCall : Ord.zip(kapAggregateRel.getAggCallList())) {
            final AggregateCall call = aggCall.e;
            final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(call.getAggregation().unwrap(SqlSplittableAggFunction.class));
            final AggregateCall pushed;
            if (immutableBitSet.contains(ImmutableBitSet.of(call.getArgList()))) {
                final AggregateCall splitCall = splitter.split(call, targetMapping);
                pushed = splitCall.adaptTo(relNode, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
            } else {
                pushed = splitter.other(rexBuilder.getTypeFactory(), call);
            }
            if (pushed == null) {
                continue;
            }
            // Aggregate outputs follow the group keys in the new aggregate's row type.
            final int position = newGroupKeyCount + belowRegistry.register(pushed);
            side.split.put(Integer.valueOf(aggCall.i), Integer.valueOf(position));
        }
        side.newInput = relBuilder.push(relNode).aggregate(relBuilder.groupKey(immutableBitSet2, null), belowCalls).build();
    }

    /**
     * Joins the two replacement inputs with the remapped join condition and puts the top-level
     * projection/aggregate on {@code relBuilder}.
     *
     * <p>After the join, every original aggregate call is replaced by its {@code topSplit} form,
     * which may register extra expressions in the projection. For LEFT joins those expressions are
     * post-processed by {@link #createNewProjects}. When {@code z} is set and every top call can be
     * expressed as a singleton, a plain projection is emitted instead of a final aggregate.
     *
     * @param list            the two sides (index 0 = left, 1 = right) with their replacement inputs
     * @param map             old join-output ordinal to new join-output ordinal; must not be null
     * @param kapAggregateRel the original aggregate
     * @param kapJoinRel      the original join
     * @param i               field count of the new join output (target count of the mapping)
     * @param relBuilder      builder that receives the rewritten plan
     * @param z               flag computed by the caller; enables the projection-only shortcut
     */
    private static void updateCondition(List<Side> list, Map<Integer, Integer> map, KapAggregateRel kapAggregateRel, KapJoinRel kapJoinRel, int i, RelBuilder relBuilder, boolean z) {
        // map::get null-checks the receiver when the reference is created, which replaces the
        // decompiler's map.getClass() stub and its unresolved "r0" capture.
        Mapping mapping = (Mapping) Mappings.target((Function<Integer, Integer>) map::get, kapJoinRel.getRowType().getFieldCount(), i);
        RexBuilder rexBuilder = kapAggregateRel.getCluster().getRexBuilder();
        relBuilder.push(list.get(0).newInput).push(list.get(1).newInput).join(kapJoinRel.getJoinType(), RexUtil.apply(mapping, kapJoinRel.getCondition()));

        List<AggregateCall> topCalls = new ArrayList<>();
        int groupCount = kapAggregateRel.getGroupCount() + kapAggregateRel.getIndicatorCount();
        int leftFieldCount = list.get(0).newInput.getRowType().getFieldCount();
        List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        // Projection index -> isLeftAgg(...) result, recorded only for LEFT joins.
        Map<Integer, Boolean> leftJoinFlags = new HashMap<>();
        for (Ord<AggregateCall> ord : Ord.zip(kapAggregateRel.getAggCallList())) {
            SqlSplittableAggFunction splitter = Preconditions.checkNotNull(ord.e.getAggregation().unwrap(SqlSplittableAggFunction.class));
            Integer leftSubTotal = list.get(0).split.get(Integer.valueOf(ord.i));
            Integer rightSubTotal = list.get(1).split.get(Integer.valueOf(ord.i));
            topCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupCount, relBuilder.peek().getRowType(), ord.e, leftSubTotal == null ? -1 : leftSubTotal.intValue(), rightSubTotal == null ? -1 : rightSubTotal.intValue() + leftFieldCount));
            if (kapJoinRel.getJoinType() == JoinRelType.LEFT) {
                leftJoinFlags.put(Integer.valueOf(projects.size() - 1), Boolean.valueOf(isLeftAgg(ord.e, kapJoinRel)));
            }
        }
        relBuilder.project(createNewProjects(rexBuilder, projects, leftJoinFlags));

        if (z) {
            List<RexNode> singletons = new ArrayList<>();
            for (Integer key : Mappings.apply(mapping, kapAggregateRel.getGroupSet())) {
                singletons.add(relBuilder.field(key.intValue()));
            }
            for (AggregateCall topCall : topCalls) {
                SqlSplittableAggFunction splitter = topCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                if (splitter != null) {
                    singletons.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), topCall));
                }
            }
            // Use the shortcut only if every top call yielded a singleton expression.
            if (singletons.size() == kapAggregateRel.getGroupSet().cardinality() + topCalls.size()) {
                relBuilder.project(singletons);
                return;
            }
        }
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, kapAggregateRel.getGroupSet()), Mappings.apply2(mapping, (Iterable<ImmutableBitSet>) kapAggregateRel.getGroupSets())), topCalls);
    }

    /**
     * Returns a copy of the projection list in which every position present in {@code map} has
     * been passed through {@link #rewriteRexNode}; all other positions are kept unchanged.
     *
     * @param rexBuilder builder used when rewriting expressions
     * @param list       projection expressions
     * @param map        projection index to the boolean flag handed to {@link #rewriteRexNode}
     * @return the new projection list, same length and order as {@code list}
     */
    private static List<RexNode> createNewProjects(RexBuilder rexBuilder, List<RexNode> list, Map<Integer, Boolean> map) {
        // Types of the top-level input references, keyed by input index.
        Map<Integer, RelDataType> inputRefTypes = Maps.newHashMap();
        for (RexNode project : list) {
            if (!(project instanceof RexInputRef)) {
                continue;
            }
            RexInputRef ref = (RexInputRef) project;
            inputRefTypes.put(Integer.valueOf(ref.getIndex()), ref.getType());
        }
        List<RexNode> rewritten = new ArrayList<>();
        int position = 0;
        for (RexNode project : list) {
            Boolean flag = map.get(Integer.valueOf(position));
            if (flag != null) {
                rewritten.add(rewriteRexNode(project, rexBuilder, flag.booleanValue(), inputRefTypes));
            } else {
                rewritten.add(project);
            }
            position++;
        }
        return rewritten;
    }

    /**
     * Rewrites a binary multiplication, or a single-operand CAST directly wrapping one, so that
     * its operands are rebuilt by {@link #rewriteRexNodeList}. Any other expression is returned
     * unchanged.
     *
     * @param rexNode    expression to inspect
     * @param rexBuilder builder used to create the replacement calls
     * @param z          flag forwarded to {@link #rewriteRexNodeList}
     * @param map        input index to type, forwarded to {@link #rewriteRexNodeList}
     * @return the rewritten expression, or {@code rexNode} itself when no pattern applies
     */
    private static RexNode rewriteRexNode(RexNode rexNode, RexBuilder rexBuilder, boolean z, Map<Integer, RelDataType> map) {
        if (!(rexNode instanceof RexCall)) {
            return rexNode;
        }
        RelDataType integerType = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createSqlType(SqlTypeName.INTEGER);
        RexCall call = (RexCall) rexNode;
        SqlOperator operator = call.getOperator();
        if (isMultiplicationRexCall(call)) {
            List<RexNode> operands = rewriteRexNodeList(call, rexBuilder, integerType, z, map);
            return rexBuilder.makeCall(call.type, SqlStdOperatorTable.MULTIPLY, operands);
        }
        boolean castOfProduct = operator.getKind() == SqlKind.CAST
                && call.getOperands().size() == 1
                && (call.getOperands().get(0) instanceof RexCall)
                && isMultiplicationRexCall((RexCall) call.getOperands().get(0));
        if (!castOfProduct) {
            return rexNode;
        }
        RexCall product = (RexCall) call.getOperands().get(0);
        List<RexNode> operands = rewriteRexNodeList(product, rexBuilder, integerType, z, map);
        return rexBuilder.makeCast(call.type, rexBuilder.makeCall(product.type, SqlStdOperatorTable.MULTIPLY, operands));
    }

    /**
     * Produces the two operands for a rebuilt multiplication: one operand is wrapped in
     * {@code COALESCE(operand, 1)}, the other is passed through (after input-ref retyping).
     *
     * @param rexCall     the multiplication whose operands are rewritten
     * @param rexBuilder  builder used to create the COALESCE call and the literal
     * @param relDataType type of the literal {@code 1}
     * @param z           when {@code true} the first operand is wrapped, otherwise the second
     * @param map         input index to type, used by {@link #rewriteRefInputList}
     * @return a new two-element operand list
     */
    private static List<RexNode> rewriteRexNodeList(RexCall rexCall, RexBuilder rexBuilder, RelDataType relDataType, boolean z, Map<Integer, RelDataType> map) {
        List<RexNode> retyped = rewriteRefInputList(rexCall.getOperands(), map, rexBuilder);
        int wrappedIndex = z ? 0 : 1;
        List<RexNode> result = new ArrayList<>();
        for (int idx = 0; idx < 2; idx++) {
            RexNode operand = retyped.get(idx);
            if (idx == wrappedIndex) {
                operand = rexBuilder.makeCall(SqlStdOperatorTable.COALESCE, operand, rexBuilder.makeLiteral((Object) 1, relDataType, false));
            }
            result.add(operand);
        }
        return result;
    }

    /**
     * Copies the operand list, replacing each {@link RexInputRef} whose type is not the very same
     * type instance as the one recorded for its index with a fresh input reference of the
     * recorded type. Operands that are not input references, or have no recorded type, are
     * copied unchanged.
     *
     * @param list       operands to copy
     * @param map        input index to the type that references to it should carry
     * @param rexBuilder builder used to create replacement input references
     * @return a new list, same length and order as {@code list}
     */
    private static List<RexNode> rewriteRefInputList(List<RexNode> list, Map<Integer, RelDataType> map, RexBuilder rexBuilder) {
        List<RexNode> result = new ArrayList<>();
        // Iterate as RexNode: the decompiled loop typed the element as RexInputRef, which does
        // not compile against List<RexNode>.
        for (RexNode operand : list) {
            RexNode replacement = operand;
            if (operand instanceof RexInputRef) {
                RexInputRef ref = (RexInputRef) operand;
                RelDataType recordedType = map.get(Integer.valueOf(ref.getIndex()));
                // Identity comparison is kept on purpose, as in the original code.
                if (recordedType != null && ref.getType() != recordedType) {
                    replacement = rexBuilder.makeInputRef(recordedType, ref.getIndex());
                }
            }
            result.add(replacement);
        }
        return result;
    }

    /**
     * Tells whether the call is a two-operand call whose operator is named {@code "*"}.
     *
     * @param rexCall call to inspect
     * @return {@code true} for a binary call whose operator name equals {@link #STAR_TOKEN}
     */
    private static boolean isMultiplicationRexCall(RexCall rexCall) {
        // Use the class constant instead of repeating the "*" literal.
        return STAR_TOKEN.equals(rexCall.getOperator().getName()) && rexCall.getOperands().size() == 2;
    }

    /**
     * Returns {@code true} when the call has no arguments, or when its largest argument ordinal
     * is at or beyond the field count of the join's left input.
     *
     * <p>NOTE(review): despite the name, a {@code true} result for a call with arguments means
     * the largest argument ordinal lies past the left input's fields — confirm the intended
     * meaning against the caller.
     *
     * @param aggregateCall aggregate call whose argument ordinals are inspected
     * @param kapJoinRel    join whose left input defines the boundary
     * @return see above
     */
    private static boolean isLeftAgg(AggregateCall aggregateCall, KapJoinRel kapJoinRel) {
        List<Integer> args = aggregateCall.getArgList();
        if (args.isEmpty()) {
            return true;
        }
        int maxArg = args.get(0).intValue();
        for (Integer arg : args) {
            maxArg = Math.max(maxArg, arg.intValue());
        }
        int leftFieldCount = kapJoinRel.getLeft().getRowType().getFieldList().size();
        return maxArg >= leftFieldCount;
    }

    /**
     * Extends a column set with every column that the given predicates directly equate
     * (via {@code inputRef = inputRef}) to one of its members. Only one step is taken: columns
     * added here are not themselves expanded further.
     *
     * @param immutableBitSet starting columns
     * @param immutableList   predicates scanned for column equalities
     * @return the starting columns plus their directly-equivalent columns
     */
    private static ImmutableBitSet keyColumns(ImmutableBitSet immutableBitSet, ImmutableList<RexNode> immutableList) {
        Map<Integer, BitSet> equivalences = new TreeMap<>();
        for (RexNode predicate : immutableList) {
            populateEquivalences(equivalences, predicate);
        }
        ImmutableBitSet result = immutableBitSet;
        for (Integer column : immutableBitSet) {
            BitSet peers = equivalences.get(column);
            if (peers == null) {
                continue;
            }
            result = result.union(peers);
        }
        return result;
    }

    /**
     * Records a symmetric equivalence between two columns when the predicate is an
     * {@code EQUALS} call whose first two operands are both input references; does nothing for
     * any other predicate.
     *
     * @param map     column index to the set of columns known to be equal to it
     * @param rexNode predicate to inspect
     */
    private static void populateEquivalences(Map<Integer, BitSet> map, RexNode rexNode) {
        if (rexNode.getKind() != SqlKind.EQUALS) {
            return;
        }
        List<RexNode> operands = ((RexCall) rexNode).getOperands();
        RexNode first = operands.get(0);
        if (!(first instanceof RexInputRef)) {
            return;
        }
        RexNode second = operands.get(1);
        if (!(second instanceof RexInputRef)) {
            return;
        }
        int firstIndex = ((RexInputRef) first).getIndex();
        int secondIndex = ((RexInputRef) second).getIndex();
        populateEquivalence(map, firstIndex, secondIndex);
        populateEquivalence(map, secondIndex, firstIndex);
    }

    /**
     * Adds {@code i2} to the set of columns recorded as equal to column {@code i}, creating that
     * set on first use.
     *
     * @param map column index to the set of columns known to be equal to it
     * @param i   column whose set is updated
     * @param i2  column added to the set
     */
    private static void populateEquivalence(Map<Integer, BitSet> map, int i, int i2) {
        Integer key = Integer.valueOf(i);
        BitSet peers = map.get(key);
        if (peers == null) {
            peers = new BitSet();
            map.put(key, peers);
        }
        peers.set(i2);
    }

    /**
     * Creates a registry backed by the given list: registering an element returns its existing
     * index in the list, or appends it and returns the index of the new last slot.
     *
     * @param list backing list; mutated when a previously unseen element is registered
     * @param <E>  element type
     * @return a registry view over {@code list}
     */
    private static <E> SqlSplittableAggFunction.Registry<E> registry(List<E> list) {
        return element -> {
            int existing = list.indexOf(element);
            if (existing >= 0) {
                return existing;
            }
            list.add(element);
            return list.size() - 1;
        };
    }
}
