package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;

/* loaded from: input_file:org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.class */
/**
 * Planner rule that rewrites some aggregate functions of an {@link Aggregate}
 * in terms of simpler aggregate functions combined with scalar expressions.
 *
 * <p>Rewrites performed by {@link #reduceAgg}:
 * <ul>
 *   <li>{@code AVG(x)} becomes {@code SUM(x) / COUNT(x)}, where {@code SUM(x)}
 *       is first cast to the AVG result type (keeping SUM's nullability) and the
 *       quotient is cast to the AVG result type;</li>
 *   <li>{@code SUM(x)} becomes {@code $SUM0(x)} when the original call's type is
 *       NOT NULL, otherwise
 *       {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE $SUM0(x) END};</li>
 *   <li>{@code VAR_POP(x)} becomes
 *       {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)};</li>
 *   <li>{@code VAR_SAMP(x)} is like {@code VAR_POP(x)} but divides by
 *       {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END};</li>
 *   <li>{@code STDDEV_POP(x)} / {@code STDDEV_SAMP(x)} apply
 *       {@code POWER(..., 0.5)} to the {@code VAR_POP} / {@code VAR_SAMP}
 *       expression.</li>
 * </ul>
 *
 * <p>All other aggregate calls are passed through unchanged. The rewritten
 * plan is: the original input, then (only if new argument expressions such as
 * {@code x * x} were needed) a project that appends them, then a new aggregate
 * built by {@link #newAggregateRel}, and finally a project that reproduces the
 * original aggregate's output field names.
 */
public class AggregateReduceFunctionsRule extends RelOptRule {
    /** Default instance; matches any {@link LogicalAggregate}. */
    public static final AggregateReduceFunctionsRule INSTANCE =
        new AggregateReduceFunctionsRule(operand(LogicalAggregate.class, any()),
            RelFactories.LOGICAL_BUILDER);

    /**
     * Creates an AggregateReduceFunctionsRule.
     *
     * @param operand           root operand; the first matched rel is cast to
     *                          {@link Aggregate} by {@link #matches} and
     *                          {@link #onMatch}
     * @param relBuilderFactory factory for the builder used to construct the
     *                          rewritten plan
     */
    public AggregateReduceFunctionsRule(RelOptRuleOperand operand,
            RelBuilderFactory relBuilderFactory) {
        super(operand, relBuilderFactory, null);
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        final Aggregate oldAggRel = (Aggregate) call.rels[0];
        return containsAvgStddevVarCall(oldAggRel.getAggCallList());
    }

    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        final Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
        reduceAggs(ruleCall, oldAggRel);
    }

    /**
     * Returns whether any of the given calls is one this rule can reduce.
     *
     * @param aggCallList aggregate calls of the matched aggregate
     * @return true if at least one call's kind satisfies {@link #isReducible}
     */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
        for (AggregateCall call : aggCallList) {
            if (isReducible(call.getAggregation().getKind())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether an aggregate function of the given kind is rewritten by
     * this rule: any kind in {@link SqlKind#AVG_AGG_FUNCTIONS}, or SUM.
     */
    private boolean isReducible(SqlKind kind) {
        return SqlKind.AVG_AGG_FUNCTIONS.contains(kind) || kind == SqlKind.SUM;
    }

    /**
     * Builds the rewritten plan for {@code oldAggRel} and registers it with the
     * planner via {@link RelOptRuleCall#transformTo}.
     *
     * @param ruleCall  the rule call that matched
     * @param oldAggRel the aggregate being rewritten
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final int keyCount = oldAggRel.getGroupCount() + oldAggRel.getIndicatorCount();

        // Calls of the new aggregate, plus the mapping that addAggCall maintains
        // from each call to the expression that references it.
        final List<AggregateCall> newCalls = new ArrayList<>();
        final Map<AggregateCall, RexNode> aggCallMapping = Maps.newHashMap();

        // Expressions of the final project. Group keys (and indicator fields)
        // are passed through by position; each original aggregate call is then
        // replaced by the expression returned from reduceAgg.
        final List<RexNode> projList = new ArrayList<>();
        for (int i = 0; i < keyCount; i++) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        final RelBuilder relBuilder = ruleCall.builder();
        relBuilder.push(oldAggRel.getInput());

        // Starts as the input's fields; reductions may append new argument
        // expressions (e.g. x * x) that the new aggregate calls refer to.
        final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());

        for (AggregateCall oldCall : oldAggRel.getAggCallList()) {
            projList.add(
                reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        final int extraExprCount =
            inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
        if (extraExprCount > 0) {
            // Keep the input's field names; appended expressions get null names.
            relBuilder.project(inputExprs,
                CompositeList.of(relBuilder.peek().getRowType().getFieldNames(),
                    Collections.<String>nCopies(extraExprCount, null)));
        }
        newAggregateRel(relBuilder, oldAggRel, newCalls);
        relBuilder.project(projList, oldAggRel.getRowType().getFieldNames());
        ruleCall.transformTo(relBuilder.build());
    }

    /**
     * Rewrites one aggregate call.
     *
     * <p>A non-reducible call is added to {@code newCalls} unchanged. A
     * reducible call is replaced by simpler calls plus scalar arithmetic.
     *
     * @param oldAggRel      the aggregate being rewritten
     * @param oldCall        the call to rewrite
     * @param newCalls       calls of the new aggregate; appended to
     * @param aggCallMapping call-to-reference mapping used by addAggCall
     * @param inputExprs     input expressions; may be appended to
     * @return expression over the new aggregate's output that replaces
     *         {@code oldCall} in the final project
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        final SqlKind kind = oldCall.getAggregation().getKind();
        if (!isReducible(kind)) {
            // Not rewritten: carry the call over into the new aggregate as is.
            return oldAggRel.getCluster().getRexBuilder().addAggCall(oldCall,
                oldAggRel.getGroupCount(), oldAggRel.indicator, newCalls,
                aggCallMapping,
                SqlTypeUtil.projectTypes(oldAggRel.getInput().getRowType(),
                    oldCall.getArgList()));
        }
        switch (kind) {
            case SUM:
                return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
            case AVG:
                return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
            case STDDEV_POP:
                return reduceStddev(oldAggRel, oldCall, true, true, newCalls,
                    aggCallMapping, inputExprs);
            case STDDEV_SAMP:
                return reduceStddev(oldAggRel, oldCall, false, true, newCalls,
                    aggCallMapping, inputExprs);
            case VAR_POP:
                return reduceStddev(oldAggRel, oldCall, true, false, newCalls,
                    aggCallMapping, inputExprs);
            case VAR_SAMP:
                return reduceStddev(oldAggRel, oldCall, false, false, newCalls,
                    aggCallMapping, inputExprs);
            default:
                throw Util.unexpected(kind);
        }
    }

    /**
     * Creates a single-argument call to {@code aggFunction}. It copies the
     * distinct, approximate and filter settings of {@code oldCall}, and infers
     * its type from an explicit argument type instead of from an input row.
     * This is used when the argument is an expression appended to the input
     * (such as {@code x * x}).
     *
     * @param typeFactory type factory
     * @param aggFunction aggregate function to call
     * @param argType     type of the argument expression
     * @param oldAggRel   the aggregate being rewritten (for its group count)
     * @param oldCall     call whose flags and filter are copied
     * @param argOrdinal  ordinal of the argument among the input expressions
     * @return the new call (with a null name)
     */
    private AggregateCall createAggregateCallWithBinding(RelDataTypeFactory typeFactory,
            SqlAggFunction aggFunction, RelDataType argType, Aggregate oldAggRel,
            AggregateCall oldCall, int argOrdinal) {
        final Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(
            typeFactory, aggFunction, ImmutableList.of(argType),
            oldAggRel.getGroupCount(), oldCall.filterArg >= 0);
        return AggregateCall.create(aggFunction, oldCall.isDistinct(),
            oldCall.isApproximate(), ImmutableIntList.of(argOrdinal),
            oldCall.filterArg, aggFunction.inferReturnType(binding), (String) null);
    }

    /**
     * Rewrites {@code AVG(x)} as {@code CAST(CAST(SUM(x) AS t) / COUNT(x) AS t)}.
     * Here {@code t} is the AVG call's type; the inner cast keeps SUM's
     * nullability.
     */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        final int nGroups = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final RelNode input = oldAggRel.getInput();
        final RelDataType argType = getFieldType(input, oldCall.getArgList().get(0));

        final AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM,
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
            oldCall.filterArg, nGroups, input, null, null);
        final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT,
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
            oldCall.filterArg, nGroups, input, null, null);

        final RexNode sumRef = rexBuilder.addAggCall(sumCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        final RexNode countRef = rexBuilder.addAggCall(countCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

        // Bring the numerator to the AVG result type before dividing, but keep
        // whatever nullability SUM has.
        final RelDataType numeratorType = typeFactory.createTypeWithNullability(
            oldCall.getType(), sumRef.getType().isNullable());
        final RexNode numerator = rexBuilder.ensureType(numeratorType, sumRef, true);
        final RexNode quotient =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, countRef);
        return rexBuilder.makeCast(oldCall.getType(), quotient);
    }

    /**
     * Rewrites {@code SUM(x)} using {@code $SUM0(x)}.
     *
     * <p>If the SUM call's type is NOT NULL, {@code $SUM0(x)} is returned
     * directly. Otherwise the result is
     * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE $SUM0(x) END}.
     */
    private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        final int nGroups = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelNode input = oldAggRel.getInput();
        final RelDataType argType = getFieldType(input, oldCall.getArgList().get(0));

        final AggregateCall sumZeroCall = AggregateCall.create(SqlStdOperatorTable.SUM0,
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
            oldCall.filterArg, nGroups, input, null, oldCall.name);
        final RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        if (!oldCall.getType().isNullable()) {
            return sumZeroRef;
        }

        // Only needed for a nullable result: produce NULL when COUNT(x) = 0.
        // The arg list refers to the aggregate's *input*, so bind against it.
        final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT,
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
            oldCall.filterArg, nGroups, input, null, null);
        final RexNode countRef = rexBuilder.addAggCall(countCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        final RexNode countIsZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
            countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
        final RexNode typedNull =
            rexBuilder.makeCast(sumZeroRef.getType(), rexBuilder.constantNull());
        return rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, typedNull,
            sumZeroRef);
    }

    /**
     * Rewrites VAR_POP, VAR_SAMP, STDDEV_POP and STDDEV_SAMP.
     *
     * <p>Let {@code x} be the argument cast to the call's type (with the
     * argument's nullability). The variance is computed as
     * {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / d}. The divisor
     * {@code d} is {@code COUNT(x)} when {@code biased}, otherwise
     * {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END}. When
     * {@code sqrt} is set, {@code POWER(variance, 0.5)} is taken. The result is
     * cast to the original call's type.
     *
     * @param biased true for the population variant, false for the sample variant
     * @param sqrt   true for standard deviation, false for variance
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall,
            boolean biased, boolean sqrt, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final int nGroups = oldAggRel.getGroupCount();
        final RelOptCluster cluster = oldAggRel.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
        final RelNode input = oldAggRel.getInput();

        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argType = getFieldType(input, argOrdinal);
        final RelDataType oldCallType =
            typeFactory.createTypeWithNullability(oldCall.getType(), argType.isNullable());

        final RexNode argRef =
            rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
        // Registers the (possibly cast) argument among the input expressions.
        // The returned ordinal is not used below; the call is kept because it
        // may add a column to the pre-aggregate project.
        lookupOrAdd(inputExprs, argRef);

        // SUM(x * x), where x * x is appended to the input expressions.
        final RexNode argSquared =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
        final AggregateCall sumArgSquaredCall = createAggregateCallWithBinding(
            typeFactory, SqlStdOperatorTable.SUM, argSquared.getType(), oldAggRel,
            oldCall, argSquaredOrdinal);
        final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping,
            ImmutableList.of(sumArgSquaredCall.getType()));

        // SUM(x) * SUM(x), with SUM(x) cast to the call's type.
        final AggregateCall sumArgCall = AggregateCall.create(SqlStdOperatorTable.SUM,
            oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argOrdinal),
            oldCall.filterArg, nGroups, input, null, null);
        final RexNode sumArg = rexBuilder.ensureType(oldCallType,
            rexBuilder.addAggCall(sumArgCall, nGroups, oldAggRel.indicator, newCalls,
                aggCallMapping, ImmutableList.of(sumArgCall.getType())),
            true);
        final RexNode sumArgSquaredFromSum =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

        // COUNT(x); arg ordinals refer to the aggregate's input row type.
        final AggregateCall countArgCall = AggregateCall.create(SqlStdOperatorTable.COUNT,
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
            oldCall.filterArg, nGroups, input, null, null);
        final RexNode countArg = rexBuilder.addAggCall(countArgCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

        final RexNode numerator = rexBuilder.makeCall(SqlStdOperatorTable.MINUS,
            sumArgSquared,
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumArgSquaredFromSum, countArg));

        final RexNode denominator;
        if (biased) {
            denominator = countArg;
        } else {
            // Sample variant: divide by COUNT - 1, yielding NULL when COUNT = 1.
            final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            final RexNode typedNull =
                rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
            final RexNode countMinusOne =
                rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
            final RexNode countIsOne =
                rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsOne,
                typedNull, countMinusOne);
        }

        final RexNode variance =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
        final RexNode result = sqrt
            ? rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance,
                rexBuilder.makeExactLiteral(new BigDecimal("0.5")))
            : variance;
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Returns the index of {@code element} in {@code list}, appending it first
     * if {@link List#indexOf} does not find it.
     */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        int ordinal = list.indexOf(element);
        if (ordinal < 0) {
            ordinal = list.size();
            list.add(element);
        }
        return ordinal;
    }

    /**
     * Pushes the replacement aggregate onto {@code relBuilder}. It has the same
     * group set and grouping sets as {@code oldAggRel}, with {@code newCalls} as
     * its aggregate calls. Subclasses may override this to build a different
     * kind of aggregate.
     *
     * @param relBuilder builder whose top of stack is the aggregate's input
     * @param oldAggRel  the aggregate being rewritten
     * @param newCalls   aggregate calls of the replacement aggregate
     */
    protected void newAggregateRel(RelBuilder relBuilder, Aggregate oldAggRel,
            List<AggregateCall> newCalls) {
        relBuilder.aggregate(
            relBuilder.groupKey(oldAggRel.getGroupSet(), oldAggRel.getGroupSets()),
            newCalls);
    }

    /** Returns the type of field {@code i} of {@code rel}'s row type. */
    private RelDataType getFieldType(RelNode rel, int i) {
        return rel.getRowType().getFieldList().get(i).getType();
    }
}
