package org.apache.calcite.rel.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

/* loaded from: input_file:WEB-INF/lib/calcite-core-1.16.0-kylin-r3.jar:org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.class */
/**
 * Planner rule that rewrites an {@link Aggregate} containing DISTINCT aggregate
 * calls into a plan whose aggregate calls are all non-distinct.
 *
 * <p>{@link #onMatch} picks one of four rewrites, in this order:
 * <ol>
 *   <li>"monopole": every call is distinct, all share one (argument list,
 *       filter) pair and the group type is {@code SIMPLE};</li>
 *   <li>grouping sets, when {@link #useGroupingSets} is {@code true};</li>
 *   <li>"singleton distinct": exactly one distinct call, no filters, and at
 *       least one non-distinct call, all of them COUNT/SUM/SUM0/MIN/MAX;</li>
 *   <li>otherwise one sub-plan per distinct (argument list, filter) pair,
 *       joined back on the group columns.</li>
 * </ol>
 */
public final class AggregateExpandDistinctAggregatesRule extends RelOptRule {
    /** Instance over {@link LogicalAggregate} that rewrites using grouping sets. */
    public static final AggregateExpandDistinctAggregatesRule INSTANCE =
            new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true, RelFactories.LOGICAL_BUILDER);

    /** Instance over {@link LogicalAggregate} that does not use grouping sets. */
    public static final AggregateExpandDistinctAggregatesRule JOIN =
            new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false, RelFactories.LOGICAL_BUILDER);

    /** Whether {@link #onMatch} may take the grouping-sets rewrite. */
    public final boolean useGroupingSets;

    /**
     * Creates the rule.
     *
     * @param clazz             class of aggregate this rule matches
     * @param useGroupingSets   whether the grouping-sets rewrite may be used
     * @param relBuilderFactory factory handed to the {@link RelOptRule} super-constructor
     */
    public AggregateExpandDistinctAggregatesRule(Class<? extends Aggregate> clazz, boolean useGroupingSets,
            RelBuilderFactory relBuilderFactory) {
        super(operand(clazz, any()), relBuilderFactory, null);
        this.useGroupingSets = useGroupingSets;
    }

    /**
     * @deprecated use
     *     {@link #AggregateExpandDistinctAggregatesRule(Class, boolean, RelBuilderFactory)}
     */
    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, boolean useGroupingSets,
            RelFactories.JoinFactory joinFactory) {
        this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory)));
    }

    /**
     * Same as the three-argument deprecated constructor with grouping sets disabled.
     *
     * @deprecated use
     *     {@link #AggregateExpandDistinctAggregatesRule(Class, boolean, RelBuilderFactory)}
     */
    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz,
            RelFactories.JoinFactory joinFactory) {
        this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory)));
    }

    /**
     * Rewrites the matched aggregate if it contains a distinct call; does
     * nothing otherwise. See the class comment for the order in which the
     * rewrites are tried.
     *
     * @param call rule call whose operand 0 is the {@link Aggregate}
     */
    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = call.rel(0);
        if (!aggregate.containsDistinctCall()) {
            return;
        }

        // Classify the calls. A LinkedHashSet keeps the distinct (args, filter)
        // pairs in first-seen order, so the generated plan is deterministic.
        int nonDistinctAggCallCount = 0;
        int filterCount = 0;
        int unsupportedNonDistinctAggCallCount = 0;
        final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (aggCall.filterArg >= 0) {
                filterCount++;
            }
            if (aggCall.isDistinct()) {
                argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
                continue;
            }
            nonDistinctAggCallCount++;
            switch (aggCall.getAggregation().getKind()) {
                case COUNT:
                case SUM:
                case SUM0:
                case MIN:
                case MAX:
                    // Kinds the singleton-distinct rewrite knows how to roll up.
                    break;
                default:
                    unsupportedNonDistinctAggCallCount++;
                    break;
            }
        }
        final int distinctAggCallCount = aggregate.getAggCallList().size() - nonDistinctAggCallCount;
        Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");

        // 1. Only distinct calls, all over the same arguments and filter.
        if (nonDistinctAggCallCount == 0
                && argLists.size() == 1
                && aggregate.getGroupType() == Aggregate.Group.SIMPLE) {
            final Pair<List<Integer>, Integer> pair = Iterables.getOnlyElement(argLists);
            final RelBuilder relBuilder = call.builder();
            convertMonopole(relBuilder, aggregate, pair.left, pair.right);
            call.transformTo(relBuilder.build());
            return;
        }

        // 2. Grouping sets, if this instance allows them.
        if (useGroupingSets) {
            rewriteUsingGroupingSets(call, aggregate);
            return;
        }

        // 3. One distinct call plus supported, unfiltered non-distinct calls.
        if (distinctAggCallCount == 1
                && filterCount == 0
                && unsupportedNonDistinctAggCallCount == 0
                && nonDistinctAggCallCount > 0) {
            final RelBuilder relBuilder = call.builder();
            convertSingletonDistinct(relBuilder, aggregate, argLists);
            call.transformTo(relBuilder.build());
            return;
        }

        // 4. General case. "refs" ends up holding, for every output field of the
        // original aggregate, the field of the rewritten plan that supplies it.
        final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
        final List<RexInputRef> refs = new ArrayList<>();
        final List<String> fieldNames = aggregate.getRowType().getFieldNames();
        final ImmutableBitSet groupSet = aggregate.getGroupSet();
        final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        for (int i : Util.range(groupAndIndicatorCount)) {
            refs.add(RexInputRef.of(i, aggFields));
        }

        // Non-distinct calls are kept for a single plain aggregate; slots of
        // distinct calls stay null until doRewrite fills them in.
        final List<AggregateCall> newAggCallList = new ArrayList<>();
        int i = -1;
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            ++i;
            if (aggCall.isDistinct()) {
                refs.add(null);
            } else {
                refs.add(new RexInputRef(groupAndIndicatorCount + newAggCallList.size(),
                        aggFields.get(groupAndIndicatorCount + i).getType()));
                newAggCallList.add(aggCall);
            }
        }

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        // n counts the relational expressions already on the builder that the
        // next distinct sub-plan has to be joined to.
        int n = 0;
        if (!newAggCallList.isEmpty()) {
            relBuilder.aggregate(relBuilder.groupKey(groupSet, aggregate.getGroupSets()), newAggCallList);
            n = 1;
        }
        for (Pair<List<Integer>, Integer> argList : argLists) {
            doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
        }
        relBuilder.project(refs, fieldNames);
        call.transformTo(relBuilder.build());
    }

    /**
     * Rewrites an aggregate with exactly one distinct call as two stacked
     * aggregates. The bottom one groups by the original group columns plus the
     * distinct call's arguments and evaluates the non-distinct calls; the top
     * one groups by the original group columns, evaluates the formerly
     * distinct call without DISTINCT and rolls up the bottom results (COUNT
     * becomes {@link SqlSumEmptyIsZeroAggFunction}, other kinds are re-applied).
     *
     * @param relBuilder builder that receives the rewritten plan
     * @param aggregate  original aggregate
     * @param argLists   distinct (argument list, filter) pairs; must have size 1
     * @return {@code relBuilder}, with the top aggregate pushed
     * @throws IllegalArgumentException if {@code argLists} does not have exactly one element
     */
    private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate,
            Set<Pair<List<Integer>, Integer>> argLists) {
        Preconditions.checkArgument(argLists.size() == 1);
        relBuilder.push(aggregate.getInput());
        final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
        final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

        // Sorted, so that headSet(col).size() is the position of col in the
        // bottom aggregate's output.
        final TreeSet<Integer> bottomGroupSet = new TreeSet<>();
        bottomGroupSet.addAll(aggregate.getGroupSet().asList());
        for (AggregateCall aggCall : originalAggCalls) {
            if (aggCall.isDistinct()) {
                bottomGroupSet.addAll(aggCall.getArgList());
                break; // only the first distinct call contributes
            }
        }
        final ImmutableBitSet bottomGroupKey = ImmutableBitSet.of(bottomGroupSet);

        // Bottom aggregate: the non-distinct calls over the finer grouping.
        final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
        for (AggregateCall aggCall : originalAggCalls) {
            if (!aggCall.isDistinct()) {
                bottomAggregateCalls.add(
                        AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(),
                                aggCall.getArgList(), -1, bottomGroupKey.cardinality(),
                                relBuilder.peek(), null, aggCall.name));
            }
        }
        relBuilder.push(
                aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
                        bottomGroupKey, null, bottomAggregateCalls));

        // Top aggregate: one call per original call, in the original order.
        final List<AggregateCall> topAggregateCalls = Lists.newArrayList();
        int nonDistinctAggCallProcessedSoFar = 0;
        for (AggregateCall aggCall : originalAggCalls) {
            final AggregateCall newCall;
            if (aggCall.isDistinct()) {
                final List<Integer> newArgList = new ArrayList<>();
                for (int arg : aggCall.getArgList()) {
                    newArgList.add(bottomGroupSet.headSet(arg).size());
                }
                newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(),
                        newArgList, -1, originalGroupSet.cardinality(), relBuilder.peek(),
                        aggCall.getType(), aggCall.name);
            } else {
                // Bottom aggregate results follow its group columns.
                final List<Integer> newArgs =
                        Lists.newArrayList(bottomGroupSet.size() + nonDistinctAggCallProcessedSoFar);
                if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
                    newCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false,
                            aggCall.isApproximate(), newArgs, -1, originalGroupSet.cardinality(),
                            relBuilder.peek(), aggCall.getType(), aggCall.getName());
                } else {
                    newCall = AggregateCall.create(aggCall.getAggregation(), false,
                            aggCall.isApproximate(), newArgs, -1, originalGroupSet.cardinality(),
                            relBuilder.peek(), aggCall.getType(), aggCall.name);
                }
                nonDistinctAggCallProcessedSoFar++;
            }
            topAggregateCalls.add(newCall);
        }

        // Positions, within the bottom output, of the original group columns.
        final Set<Integer> topGroupSet = new HashSet<>();
        int groupSetToAdd = 0;
        for (int bottomGroup : bottomGroupSet) {
            if (originalGroupSet.get(bottomGroup)) {
                topGroupSet.add(groupSetToAdd);
            }
            groupSetToAdd++;
        }
        relBuilder.push(
                aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator,
                        ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
        return relBuilder;
    }

    /**
     * Rewrites the aggregate as two stacked aggregates using grouping sets and
     * registers the result with {@code call}.
     *
     * <p>The bottom aggregate has one grouping set per distinct call (its
     * arguments, its filter column if any, and the original group columns)
     * plus the original group set if there is a non-distinct call; it also
     * computes a GROUPING call named {@code $g}. A projection turns
     * {@code $g} into one boolean column {@code $g_<value>} per grouping set.
     * The top aggregate evaluates each formerly distinct call without DISTINCT,
     * filtered by the column of its grouping set, and takes MIN of each
     * non-distinct result, filtered by the column of the original group set.
     *
     * @param call      rule call that receives the transformed plan
     * @param aggregate original aggregate
     */
    private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
        final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (aggCall.isDistinct()) {
                groupSetTreeSet.add(
                        ImmutableBitSet.of(aggCall.getArgList())
                                .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
                                .union(aggregate.getGroupSet()));
            } else {
                groupSetTreeSet.add(aggregate.getGroupSet());
            }
        }
        final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
        final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

        // Calls of the bottom aggregate: the non-distinct calls, then GROUPING.
        final List<AggregateCall> bottomAggCalls = new ArrayList<>();
        for (Pair<AggregateCall, String> namedCall : aggregate.getNamedAggCalls()) {
            if (!namedCall.left.isDistinct()) {
                final AggregateCall adapted =
                        namedCall.left.adaptTo(aggregate.getInput(), namedCall.left.getArgList(),
                                namedCall.left.filterArg, aggregate.getGroupCount(),
                                fullGroupSet.cardinality());
                bottomAggCalls.add(adapted.rename(namedCall.right));
            }
        }

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        final int groupCount = fullGroupSet.cardinality();

        // Maps each grouping set to the ordinal of its boolean column in the
        // projection below; the first one replaces the $g column itself.
        final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
        final int groupingOrdinal = groupCount + bottomAggCalls.size();
        bottomAggCalls.add(
                AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false,
                        ImmutableIntList.copyOf(fullGroupSet), -1, groupSets.size(),
                        relBuilder.peek(), null, "$g"));
        for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
            filters.put(groupSet.e, groupingOrdinal + groupSet.i);
        }

        relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), bottomAggCalls);
        final RelNode distinct = relBuilder.peek();

        // Replace the trailing $g column by one "$g = <value>" column per set.
        if (!filters.isEmpty()) {
            final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
            final RexNode groupingNode = nodes.remove(nodes.size() - 1);
            for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
                final long v = groupValue(fullGroupSet, entry.getKey());
                nodes.add(
                        relBuilder.alias(
                                relBuilder.equals(groupingNode, relBuilder.literal(Long.valueOf(v))),
                                "$g_" + v));
            }
            relBuilder.project(nodes);
        }

        // Top aggregate. nextNonDistinct walks the bottom aggregate's
        // non-distinct results, which follow its group columns.
        int nextNonDistinct = groupCount;
        final List<AggregateCall> newCalls = new ArrayList<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            final SqlAggFunction aggregation;
            final List<Integer> newArgList;
            final int newFilterArg;
            if (aggCall.isDistinct()) {
                aggregation = aggCall.getAggregation();
                newArgList = remap(fullGroupSet, aggCall.getArgList());
                newFilterArg =
                        filters.get(
                                ImmutableBitSet.of(aggCall.getArgList())
                                        .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
                                        .union(aggregate.getGroupSet()));
            } else {
                aggregation = SqlStdOperatorTable.MIN;
                newArgList = ImmutableIntList.of(nextNonDistinct++);
                newFilterArg = filters.get(aggregate.getGroupSet());
            }
            newCalls.add(
                    AggregateCall.create(aggregation, false, aggCall.isApproximate(), newArgList,
                            newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name));
        }

        relBuilder.aggregate(
                relBuilder.groupKey(
                        remap(fullGroupSet, aggregate.getGroupSet()),
                        remap(fullGroupSet, (Iterable<ImmutableBitSet>) aggregate.getGroupSets())),
                newCalls);
        relBuilder.convert(aggregate.getRowType(), true);
        call.transformTo(relBuilder.build());
    }

    /**
     * Computes a bit mask describing which columns of {@code fullGroupSet} are
     * missing from {@code groupSet}. The first column of {@code fullGroupSet}
     * maps to the highest of {@code fullGroupSet.cardinality()} bits; a bit is
     * set when the column is not in {@code groupSet}.
     *
     * <p>NOTE(review): presumably this mirrors the value the GROUPING call
     * produces for that grouping set, since the caller compares the two.
     *
     * @param fullGroupSet union of all grouping sets
     * @param groupSet     one grouping set; must be contained in {@code fullGroupSet}
     * @return the mask
     */
    private static long groupValue(ImmutableBitSet fullGroupSet, ImmutableBitSet groupSet) {
        long v = 0;
        // 1L, not 1: an int shift would wrap or sign-extend once there are 32
        // or more columns and corrupt the mask.
        long bit = 1L << (fullGroupSet.cardinality() - 1);
        assert fullGroupSet.contains(groupSet);
        for (int column : fullGroupSet) {
            if (!groupSet.get(column)) {
                v |= bit;
            }
            bit >>= 1;
        }
        return v;
    }

    /**
     * Translates every bit of {@code bitSet} to its position within
     * {@code groupSet}.
     */
    private static ImmutableBitSet remap(ImmutableBitSet groupSet, ImmutableBitSet bitSet) {
        final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        for (int bit : bitSet) {
            builder.set(remap(groupSet, bit));
        }
        return builder.build();
    }

    /** Applies {@link #remap(ImmutableBitSet, ImmutableBitSet)} to each bit set, keeping their order. */
    private static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet,
            Iterable<ImmutableBitSet> bitSets) {
        final ImmutableList.Builder<ImmutableBitSet> builder = ImmutableList.builder();
        for (ImmutableBitSet bitSet : bitSets) {
            builder.add(remap(groupSet, bitSet));
        }
        return builder.build();
    }

    /** Applies {@link #remap(ImmutableBitSet, int)} to each argument, keeping their order. */
    private static List<Integer> remap(ImmutableBitSet groupSet, List<Integer> argList) {
        ImmutableIntList list = ImmutableIntList.of();
        for (int arg : argList) {
            list = list.append(remap(groupSet, arg));
        }
        return list;
    }

    /**
     * Returns {@code groupSet.indexOf(arg)}, or -1 when {@code arg} is
     * negative (the convention used here for "no argument").
     */
    private static int remap(ImmutableBitSet groupSet, int arg) {
        return arg < 0 ? -1 : groupSet.indexOf(arg);
    }

    /**
     * Rewrites an aggregate whose calls are all distinct over the same
     * arguments: a "select distinct" over the group columns and arguments,
     * topped by a copy of the aggregate whose calls are no longer distinct.
     *
     * @param relBuilder builder that receives the rewritten plan
     * @param aggregate  original aggregate
     * @param argList    arguments shared by the distinct calls
     * @param filterArg  filter column shared by the distinct calls, or -1
     * @return {@code relBuilder}, with the new aggregate pushed
     */
    private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList,
            int filterArg) {
        // Filled by createSelectDistinct: input column -> column of the distinct relation.
        final Map<Integer, Integer> sourceOf = new HashMap<>();
        createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

        final List<AggregateCall> newAggCalls = Lists.newArrayList(aggregate.getAggCallList());
        rewriteAggCalls(newAggCalls, argList, sourceOf);

        // The group columns are the leading columns of the distinct relation.
        final int cardinality = aggregate.getGroupSet().cardinality();
        relBuilder.push(
                aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator,
                        ImmutableBitSet.range(cardinality), null, newAggCalls));
        return relBuilder;
    }

    /**
     * Builds the sub-plan for the distinct calls over one argument list and,
     * unless it is the first relation ({@code n == 0}), joins it to the
     * relation already on the builder using IS NOT DISTINCT FROM on the group
     * and indicator columns. Also records in {@code refs} where each rewritten
     * call can be found in the resulting relation.
     *
     * @param relBuilder builder holding the plan so far
     * @param aggregate  original aggregate
     * @param n          number of relations already built that this one joins to
     * @param argList    arguments of the distinct calls to rewrite
     * @param filterArg  filter column handed to {@link #createSelectDistinct}, or -1
     * @param refs       per-output-field references; slots of the rewritten
     *                   calls must be null on entry and are set here
     */
    private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList,
            int filterArg, List<RexInputRef> refs) {
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        // Fields of the relation to join to; captured before anything is pushed.
        final List<RelDataTypeField> leftFields =
                n == 0 ? null : relBuilder.peek().getRowType().getFieldList();

        final Map<Integer, Integer> sourceOf = new HashMap<>();
        createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

        final List<AggregateCall> aggCallList = new ArrayList<>();
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        int i = groupAndIndicatorCount - 1;
        for (AggregateCall aggCall : aggCalls) {
            ++i; // ordinal of this call's field in the original aggregate's output
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) {
                continue;
            }

            // Re-point the call at the columns of the distinct relation.
            final int argCount = aggCall.getArgList().size();
            final List<Integer> newArgs = new ArrayList<>(argCount);
            for (int j = 0; j < argCount; j++) {
                newArgs.add(sourceOf.get(aggCall.getArgList().get(j)));
            }
            final int newFilterArg = aggCall.filterArg >= 0 ? sourceOf.get(aggCall.filterArg) : -1;
            final AggregateCall newAggCall =
                    AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(),
                            newArgs, newFilterArg, aggCall.getType(), aggCall.getName());

            assert refs.get(i) == null;
            if (n == 0) {
                refs.set(i,
                        new RexInputRef(groupAndIndicatorCount + aggCallList.size(), newAggCall.getType()));
            } else {
                // After the join, this relation's fields follow the left fields.
                refs.set(i,
                        new RexInputRef(leftFields.size() + groupAndIndicatorCount + aggCallList.size(),
                                newAggCall.getType()));
            }
            aggCallList.add(newAggCall);
        }

        // Group columns become 0..k-1 in the distinct relation.
        final Map<Integer, Integer> map = new HashMap<>();
        for (Integer key : aggregate.getGroupSet()) {
            map.put(key, map.size());
        }
        final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
        assert newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
        final ImmutableList<ImmutableBitSet> newGroupingSets =
                aggregate.indicator
                        ? ImmutableBitSet.ORDERING.immutableSortedCopy(
                                ImmutableBitSet.permute(aggregate.getGroupSets(), map))
                        : null;
        relBuilder.push(
                aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator,
                        newGroupSet, newGroupingSets, aggCallList));

        if (n == 0) {
            // Nothing to join to yet.
            return;
        }

        final List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
        final List<RexNode> conditions = Lists.newArrayList();
        for (int j = 0; j < groupAndIndicatorCount; j++) {
            conditions.add(
                    rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
                            RexInputRef.of(j, leftFields),
                            new RexInputRef(leftFields.size() + j, distinctFields.get(j).getType())));
        }
        relBuilder.join(JoinRelType.INNER, conditions);
    }

    /**
     * Replaces, in place, every distinct call over {@code argList} by a
     * non-distinct, unfiltered call whose arguments are translated through
     * {@code sourceOf}. Other calls are left untouched.
     *
     * @param newAggCalls calls to rewrite; modified in place
     * @param argList     arguments identifying the calls to rewrite
     * @param sourceOf    input column -> column of the distinct relation
     */
    private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList,
            Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); i++) {
            final AggregateCall aggCall = newAggCalls.get(i);
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) {
                continue;
            }
            final int argCount = aggCall.getArgList().size();
            final List<Integer> newArgs = new ArrayList<>(argCount);
            for (int j = 0; j < argCount; j++) {
                newArgs.add(sourceOf.get(aggCall.getArgList().get(j)));
            }
            newAggCalls.set(i,
                    AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(),
                            newArgs, -1, aggCall.getType(), aggCall.getName()));
        }
    }

    /**
     * Pushes onto {@code relBuilder} a relation with the distinct combinations
     * of the group columns and the given arguments: a projection of the
     * aggregate's input, grouped by all projected columns with no aggregate
     * calls.
     *
     * <p>When {@code filterArg >= 0} each argument is projected as
     * {@code CASE WHEN filter THEN arg ELSE NULL END}, named {@code i$<arg name>}.
     *
     * @param relBuilder builder that receives the relation
     * @param aggregate  original aggregate
     * @param argList    argument columns of the distinct call(s)
     * @param filterArg  filter column, or a negative value for none
     * @param sourceOf   out-parameter: input column -> projected column
     * @return {@code relBuilder}
     */
    private RelBuilder createSelectDistinct(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList,
            int filterArg, Map<Integer, Integer> sourceOf) {
        relBuilder.push(aggregate.getInput());
        final List<Pair<RexNode, String>> projects = new ArrayList<>();
        final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();

        for (int groupColumn : aggregate.getGroupSet()) {
            sourceOf.put(groupColumn, projects.size());
            projects.add(RexInputRef.of2(groupColumn, childFields));
        }

        for (Integer arg : argList) {
            if (filterArg >= 0) {
                // Null out the argument on rows the filter rejects. The null
                // literal is cast so that both CASE branches have the same type.
                final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
                final Pair<RexNode, String> argRef = RexInputRef.of2(arg.intValue(), childFields);
                final RexNode condition =
                        rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef, argRef.left,
                                rexBuilder.ensureType(argRef.left.getType(),
                                        rexBuilder.makeCast(argRef.left.getType(), rexBuilder.constantNull()),
                                        true));
                sourceOf.put(arg, projects.size());
                projects.add(Pair.of(condition, "i$" + argRef.right));
            } else if (sourceOf.get(arg) == null) {
                // Skip arguments that are already projected.
                sourceOf.put(arg, projects.size());
                projects.add(RexInputRef.of2(arg.intValue(), childFields));
            }
        }

        relBuilder.project(Pair.left(projects), Pair.right(projects));
        // Grouping by every column with no aggregate calls yields the distinct rows.
        relBuilder.push(
                aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
                        ImmutableBitSet.range(projects.size()), null,
                        ImmutableList.<AggregateCall>of()));
        return relBuilder;
    }
}
