package io.dingodb.calcite.rule.logical;

import com.google.common.collect.ImmutableSet;
import io.dingodb.calcite.rel.logical.LogicalReduceAggregate;
import io.dingodb.calcite.rel.logical.LogicalRelOp;
import io.dingodb.calcite.rule.logical.ImmutableLogicalSplitAggregateRule;
import io.dingodb.common.util.Utils;
import io.dingodb.exec.expr.DingoCompileContext;
import io.dingodb.expr.rel.RelOp;
import io.dingodb.expr.rel.op.RelOpBuilder;
import io.dingodb.expr.runtime.expr.Expr;
import io.dingodb.expr.runtime.expr.Exprs;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.immutables.value.Value;

@Value.Enclosing
/* loaded from: input_file:io/dingodb/calcite/rule/logical/LogicalSplitAggregateRule.class */
public class LogicalSplitAggregateRule extends RelRule<Config> implements SubstitutionRule {
    private static final Set<SqlKind> supportedAggregations = ImmutableSet.of(SqlKind.COUNT, SqlKind.SUM, SqlKind.SUM0, SqlKind.MAX, SqlKind.MIN, SqlKind.SINGLE_VALUE, new SqlKind[0]);

    @Value.Immutable
    /* loaded from: input_file:io/dingodb/calcite/rule/logical/LogicalSplitAggregateRule$Config.class */
    public interface Config extends RelRule.Config {
        public static final Config DEFAULT = ImmutableLogicalSplitAggregateRule.Config.builder().operandSupplier(operandBuilder -> {
            return operandBuilder.operand(LogicalAggregate.class).predicate(LogicalSplitAggregateRule::match).anyInputs();
        }).description("LogicalSplitAggregateRule").build();

        @Override // org.apache.calcite.plan.RelRule.Config
        default LogicalSplitAggregateRule toRule() {
            return new LogicalSplitAggregateRule(this);
        }
    }

    protected LogicalSplitAggregateRule(Config config) {
        super(config);
    }

    public static boolean match(LogicalAggregate logicalAggregate) {
        return logicalAggregate.getAggCallList().stream().noneMatch(aggregateCall -> {
            SqlKind kind = aggregateCall.getAggregation().getKind();
            if (supportedAggregations.contains(kind)) {
                return aggregateCall.isDistinct() && (kind == SqlKind.COUNT || kind == SqlKind.SUM);
            }
            return true;
        });
    }

    static Expr getAgg(AggregateCall aggregateCall) {
        SqlKind kind = aggregateCall.getAggregation().getKind();
        List<Integer> argList = aggregateCall.getArgList();
        if (argList.isEmpty() && kind == SqlKind.COUNT) {
            return Exprs.op(Exprs.COUNT_ALL_AGG);
        }
        Expr createTupleVar = DingoCompileContext.createTupleVar(((Integer) Utils.sole(argList)).intValue());
        switch (kind) {
            case COUNT:
                return Exprs.op(Exprs.COUNT_AGG, createTupleVar);
            case SUM:
                return Exprs.op(Exprs.SUM_AGG, createTupleVar);
            case SUM0:
                return Exprs.op(Exprs.SUM0_AGG, createTupleVar);
            case MAX:
                return Exprs.op(Exprs.MAX_AGG, createTupleVar);
            case MIN:
                return Exprs.op(Exprs.MIN_AGG, createTupleVar);
            case SINGLE_VALUE:
                return Exprs.op(Exprs.SINGLE_VALUE_AGG, createTupleVar);
            default:
                throw new UnsupportedOperationException("Unsupported aggregation function \"" + kind + "\".");
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r7v0, types: [io.dingodb.expr.rel.RelOp] */
    /* JADX WARN: Type inference failed for: r7v1, types: [io.dingodb.expr.rel.RelOp] */
    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        LogicalAggregate logicalAggregate = (LogicalAggregate) relOptRuleCall.rel(0);
        for (AggregateCall aggregateCall : logicalAggregate.getAggCallList()) {
            SqlKind kind = aggregateCall.getAggregation().getKind();
            if (kind == SqlKind.SUM || kind == SqlKind.SUM0) {
                if (aggregateCall.type.getFamily() != SqlTypeFamily.NUMERIC) {
                    throw new IllegalArgumentException("Aggregation function \"" + kind + "\" requires numerical input but \"" + aggregateCall.type + "\" was given.");
                }
            }
        }
        int[] array = logicalAggregate.getGroupSet().asList().stream().mapToInt((v0) -> {
            return v0.intValue();
        }).toArray();
        Expr[] exprArr = (Expr[]) logicalAggregate.getAggCallList().stream().map(LogicalSplitAggregateRule::getAgg).toArray(i -> {
            return new Expr[i];
        });
        Object build = array.length == 0 ? RelOpBuilder.builder().agg(exprArr).build() : RelOpBuilder.builder().agg(array, exprArr).build();
        relOptRuleCall.transformTo(new LogicalReduceAggregate(logicalAggregate.getCluster(), logicalAggregate.getTraitSet(), logicalAggregate.getHints(), new LogicalRelOp(logicalAggregate.getCluster(), logicalAggregate.getTraitSet(), logicalAggregate.getHints(), logicalAggregate.getInput(), logicalAggregate.getRowType(), build, null), build, logicalAggregate.getInput().getRowType()));
        relOptRuleCall.getPlanner().prune(logicalAggregate);
    }
}
