package org.apache.druid.sql.calcite.aggregation.builtin;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.util.Optionality;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.planner.Calcites;

/* loaded from: input_file:org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.class */
/**
 * SQL {@code SUM} aggregator. It plans a Calcite {@code SUM} call into a Druid long, float or double sum
 * aggregator factory, chosen from the Druid column type of the aggregate call's return type.
 */
public class SumSqlAggregator extends SimpleSqlAggregator {
    /** The Calcite operator exposed by this aggregator; see {@link #calciteFunction()}. */
    private static final SqlAggFunction DRUID_SUM = new DruidSumAggFunction();

    /**
     * A {@code SUM} operator that behaves like a numeric-only SUM (return type from {@code ReturnTypes.AGG_SUM}).
     * When Calcite asks for a {@link SqlSplittableAggFunction}, it returns {@link DruidSumSplitter} instead of
     * Calcite's default splitter.
     */
    private static class DruidSumAggFunction extends SqlAggFunction {
        public DruidSumAggFunction() {
            super(
                "SUM",
                null,
                SqlKind.SUM,
                ReturnTypes.AGG_SUM,
                null,
                OperandTypes.NUMERIC,
                SqlFunctionCategory.NUMERIC,
                false,
                false,
                Optionality.FORBIDDEN
            );
        }

        @Override
        public <T> T unwrap(Class<T> cls) {
            if (cls == SqlSplittableAggFunction.class) {
                return cls.cast(DruidSumSplitter.INSTANCE);
            }
            return super.unwrap(cls);
        }
    }

    /**
     * Splitter that lets Calcite split a SUM into partial sums. The partial sums are merged with
     * {@link #DRUID_SUM}, so a split aggregate keeps using this Druid operator rather than another SUM operator.
     */
    private static class DruidSumSplitter extends SqlSplittableAggFunction.AbstractSumSplitter {
        // final: this shared singleton is handed out by DruidSumAggFunction.unwrap and must not be reassignable.
        private static final DruidSumSplitter INSTANCE = new DruidSumSplitter();

        private DruidSumSplitter() {
        }

        /**
         * Builds the expression for SUM applied to a single row: a reference to the aggregate's (first) input field.
         * The reference is cast to the aggregate call's type when the field's type differs from it.
         */
        @Override
        public RexNode singleton(RexBuilder rexBuilder, RelDataType relDataType, AggregateCall aggregateCall) {
            final int arg = aggregateCall.getArgList().get(0);
            final RelDataTypeField field = relDataType.getFieldList().get(arg);
            final RexInputRef inputRef = rexBuilder.makeInputRef(field.getType(), arg);
            if (!aggregateCall.getType().equals(field.getType())) {
                return rexBuilder.makeCast(aggregateCall.getType(), inputRef);
            }
            return inputRef;
        }

        /** Partial sums produced by a split are merged with the Druid SUM operator. */
        @Override
        protected SqlAggFunction getMergeAggFunctionOfTopSplit() {
            return SumSqlAggregator.DRUID_SUM;
        }
    }

    @Override
    public SqlAggFunction calciteFunction() {
        return DRUID_SUM;
    }

    /**
     * Creates the Druid aggregation for a SUM call.
     *
     * @param name          output name of the aggregator
     * @param aggregateCall the Calcite aggregate call; its return type selects the sum flavor
     * @param macroTable    expression macro table forwarded to the aggregator factory
     * @param fieldName     input field to sum
     * @return the aggregation, or {@code null} if the call's SQL return type has no Druid column type
     */
    @Override
    Aggregation getAggregation(String name, AggregateCall aggregateCall, ExprMacroTable macroTable, String fieldName) {
        final ColumnType aggregationType = Calcites.getColumnTypeForRelDataType(aggregateCall.getType());
        if (aggregationType == null) {
            return null;
        }
        return Aggregation.create(createSumAggregatorFactory(aggregationType, name, fieldName, macroTable));
    }

    /**
     * Creates a sum aggregator factory matching {@code aggregationType}: long, float or double.
     *
     * @param aggregationType column type of the sum result; must be LONG, FLOAT or DOUBLE
     * @param name            output name of the aggregator
     * @param fieldName       input field to sum
     * @param macroTable      expression macro table forwarded to the factory
     * @return the sum aggregator factory
     * @throws RuntimeException the exception built by {@code SimpleSqlAggregator.badTypeException} for any
     *                          other type
     */
    public static AggregatorFactory createSumAggregatorFactory(
            ColumnType aggregationType,
            String name,
            String fieldName,
            ExprMacroTable macroTable
    ) {
        switch (aggregationType.getType()) {
            case LONG:
                return new LongSumAggregatorFactory(name, fieldName, null, macroTable);
            case FLOAT:
                return new FloatSumAggregatorFactory(name, fieldName, null, macroTable);
            case DOUBLE:
                return new DoubleSumAggregatorFactory(name, fieldName, null, macroTable);
            default:
                throw SimpleSqlAggregator.badTypeException(fieldName, "SUM", aggregationType);
        }
    }
}
