package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.com.google.common.collect.Lists;

/* loaded from: input_file:org/apache/calcite/rel/rules/AggregateUnionTransposeRule.class */
/**
 * Planner rule that pushes an {@link Aggregate} past a {@link Union} whose {@code all}
 * flag is set (UNION ALL).
 *
 * <p>Every input of the union receives a partial copy of the aggregate, grouped by the
 * original group set. A roll-up aggregate placed on top of the new union combines the
 * partial results. MIN, MAX, SUM and SUM0 are re-applied as themselves, while the standard
 * COUNT is rolled up with SUM0.
 *
 * <p>The rule does nothing in three cases: when an aggregate call is DISTINCT, when it uses
 * a function whose class is not registered in {@code SUPPORTED_AGGREGATES`, or when no union
 * input would actually be reduced by the partial aggregate.
 */
public class AggregateUnionTransposeRule extends RelOptRule {
    /** Default instance, matching a {@link LogicalAggregate} on top of a {@link LogicalUnion}. */
    public static final AggregateUnionTransposeRule INSTANCE = new AggregateUnionTransposeRule(
        LogicalAggregate.class, LogicalUnion.class, RelFactories.LOGICAL_BUILDER);

    /**
     * Aggregate function classes that can be split into a per-input partial aggregate and a
     * roll-up aggregate. Lookup is by class identity; only the keys are consulted.
     */
    private static final Map<Class<? extends SqlAggFunction>, Boolean> SUPPORTED_AGGREGATES =
        new IdentityHashMap<>();

    static {
        SUPPORTED_AGGREGATES.put(SqlMinMaxAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlCountAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumEmptyIsZeroAggFunction.class, true);
    }

    /**
     * Creates a rule that matches an aggregate of class {@code aggregateClass} whose input is
     * a union of class {@code unionClass}.
     *
     * @param aggregateClass    class of the aggregate to match
     * @param unionClass        class of the union directly beneath it
     * @param relBuilderFactory factory for the builder used to construct the replacement
     */
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass,
            Class<? extends Union> unionClass, RelBuilderFactory relBuilderFactory) {
        super(operand(aggregateClass, operand(unionClass, any())), relBuilderFactory, null);
    }

    /**
     * Creates a rule using separate aggregate and set-operation factories.
     *
     * @deprecated Use {@link #AggregateUnionTransposeRule(Class, Class, RelBuilderFactory)}.
     */
    @Deprecated
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass,
            RelFactories.AggregateFactory aggregateFactory, Class<? extends Union> unionClass,
            RelFactories.SetOpFactory setOpFactory) {
        this(aggregateClass, unionClass, RelBuilder.proto(aggregateFactory, setOpFactory));
    }

    /**
     * Rewrites {@code Aggregate(Union ALL(inputs))} as
     * {@code RollupAggregate(Union ALL(PartialAggregate(input) for each input))}.
     * The rewrite is skipped when it cannot be applied safely or would not reduce any input.
     *
     * @param call the rule call; ordinal 0 is the aggregate, ordinal 1 the union
     */
    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        final Union union = (Union) call.rel(1);
        if (!union.all) {
            // Only valid for UNION ALL. A distinct UNION removes rows duplicated across
            // inputs, which per-input aggregation cannot see. For example, with t1 = {5, 5}
            // and t2 = {5, 10}, SUM over (t1 UNION t2) is 15, but the sum of the per-input
            // sums is 25.
            return;
        }

        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        // The copy (no grouping sets) serves as the type-inference input for the roll-up calls.
        final List<AggregateCall> rollupCalls = transformAggCalls(
            aggregate.copy(aggregate.getTraitSet(), aggregate.getInput(), false,
                aggregate.getGroupSet(), null, aggCalls),
            aggregate.getGroupSet().cardinality(), aggCalls);
        if (rollupCalls == null) {
            // Some aggregate call cannot be split into partial and roll-up form.
            return;
        }

        final RelBuilder builder = call.builder();
        final RelMetadataQuery mq = call.getMetadataQuery();
        int reducedInputs = 0;
        for (RelNode input : union.getInputs()) {
            final boolean alreadyUnique =
                RelMdUtil.areColumnsDefinitelyUnique(mq, input, aggregate.getGroupSet());
            if (!alreadyUnique) {
                reducedInputs++;
            }
            builder.push(input);
            // An input that is already unique on the group keys gains nothing from a partial
            // aggregate. It may only be left as-is when its columns already line up with the
            // partial aggregate's output; otherwise the union inputs would disagree on row
            // layout, and the roll-up would read partial results from the wrong columns.
            if (!alreadyUnique || !matchesPartialAggregateLayout(aggregate, input)) {
                builder.aggregate(builder.groupKey(aggregate.getGroupSet(), null), aggCalls);
            }
        }
        if (reducedInputs == 0) {
            // No input is reduced by the push-down, so the rewrite would only add work
            // (and presumably could let the rule keep firing on its own output).
            return;
        }

        builder.union(true, union.getInputs().size());
        builder.aggregate(
            builder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()), rollupCalls);
        call.transformTo(builder.build());
    }

    /**
     * Returns whether {@code input} already has the column layout the partial aggregate
     * would produce. That holds when there are no aggregate calls and the group set covers
     * every column of the input, because grouping then keeps all columns in their original
     * order.
     *
     * @param aggregate the aggregate being pushed down
     * @param input     one input of the union
     * @return true if {@code input} can stand in for its partial aggregate unchanged
     */
    private static boolean matchesPartialAggregateLayout(Aggregate aggregate, RelNode input) {
        return aggregate.getAggCallList().isEmpty()
            && aggregate.getGroupSet().cardinality() == input.getRowType().getFieldCount();
    }

    /**
     * Builds the roll-up calls that combine the partial results of {@code origCalls}.
     *
     * @param input      relational expression used to infer roll-up call types
     * @param groupCount number of group keys; partial result {@code i} is read from column
     *                   {@code groupCount + i}
     * @param origCalls  the original aggregate calls
     * @return one roll-up call per original call, or {@code null} if any call is DISTINCT,
     *         uses an unsupported function class, or is a non-standard count function
     */
    private List<AggregateCall> transformAggCalls(RelNode input, int groupCount,
            List<AggregateCall> origCalls) {
        final List<AggregateCall> rollupCalls = Lists.newArrayListWithCapacity(origCalls.size());
        for (Ord<AggregateCall> ord : Ord.zip(origCalls)) {
            final AggregateCall origCall = ord.e;
            final SqlAggFunction origFunction = origCall.getAggregation();
            if (origCall.isDistinct()
                    || !SUPPORTED_AGGREGATES.containsKey(origFunction.getClass())) {
                return null;
            }
            final SqlAggFunction rollupFunction;
            final RelDataType rollupType;
            if (origFunction == SqlStdOperatorTable.COUNT) {
                // Partial counts are added up. SUM0 yields 0 rather than NULL on empty
                // input, matching COUNT.
                rollupFunction = SqlStdOperatorTable.SUM0;
                rollupType = null; // let AggregateCall.create infer the type
            } else if (origFunction instanceof SqlCountAggFunction) {
                // Re-applying a count to partial counts would count partial rows, and we
                // cannot assume SUM0 is right for a non-standard count function; give up.
                return null;
            } else {
                // MIN, MAX, SUM and SUM0 roll up with themselves.
                rollupFunction = origFunction;
                rollupType = origCall.getType();
            }
            // DISTINCT calls were rejected above; the roll-up applies no filter (-1).
            rollupCalls.add(AggregateCall.create(rollupFunction, false,
                ImmutableList.of(groupCount + ord.i), -1, groupCount, input,
                rollupType, origCall.getName()));
        }
        return rollupCalls;
    }
}
