/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.relnode;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.adapter.enumerable.EnumerableAggregate;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.schema.AggregateFunction;
import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.metadata.model.ColumnDesc;
import org.apache.kylin.metadata.model.FunctionDesc;
import org.apache.kylin.metadata.model.MeasureDesc;
import org.apache.kylin.metadata.model.ParameterDesc;
import org.apache.kylin.metadata.model.TableDesc;
import org.apache.kylin.metadata.model.TblColRef;
import org.apache.kylin.query.relnode.ColumnRowType;
import org.apache.kylin.query.relnode.OLAPContext;
import org.apache.kylin.query.relnode.OLAPRel;
import org.apache.kylin.query.schema.OLAPTable;

public class OLAPAggregateRel
extends Aggregate
implements OLAPRel {
    private static final Map<String, String> AGGR_FUNC_MAP = new HashMap<String, String>();
    private OLAPContext context;
    private ColumnRowType columnRowType;
    private boolean afterAggregate;
    private List<AggregateCall> rewriteAggCalls;
    private List<TblColRef> groups;
    private List<FunctionDesc> aggregations;

    private static String getFuncName(AggregateCall aggCall) {
        String funcName;
        String aggName = aggCall.getAggregation().getName();
        if (aggCall.isDistinct()) {
            aggName = aggName + "_DISTINCT";
        }
        if ((funcName = AGGR_FUNC_MAP.get(aggName)) == null) {
            throw new IllegalStateException("Don't suppoprt aggregation " + aggName);
        }
        return funcName;
    }

    public OLAPAggregateRel(RelOptCluster cluster, RelTraitSet traits, RelNode child, ImmutableBitSet groupSet, List<AggregateCall> aggCalls) throws InvalidRelException {
        super(cluster, traits, child, false, groupSet, OLAPAggregateRel.asList(groupSet), aggCalls);
        Preconditions.checkArgument((this.getConvention() == OLAPRel.CONVENTION ? 1 : 0) != 0);
        this.afterAggregate = false;
        this.rewriteAggCalls = aggCalls;
        this.rowType = this.getRowType();
    }

    private static List<ImmutableBitSet> asList(ImmutableBitSet groupSet) {
        ArrayList<ImmutableBitSet> l = new ArrayList<ImmutableBitSet>(1);
        l.add(groupSet);
        return l;
    }

    public Aggregate copy(RelTraitSet traitSet, RelNode input, boolean indicator, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
        try {
            return new OLAPAggregateRel(this.getCluster(), traitSet, input, groupSet, aggCalls);
        }
        catch (InvalidRelException e) {
            throw new IllegalStateException("Can't create OLAPAggregateRel!", e);
        }
    }

    public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        return super.computeSelfCost(planner, mq).multiplyBy(0.05);
    }

    @Override
    public void implementOLAP(OLAPRel.OLAPImplementor implementor) {
        implementor.visitChild(this.getInput(), this);
        this.context = implementor.getContext();
        this.columnRowType = this.buildColumnRowType();
        this.afterAggregate = this.context.afterAggregate;
        if (!this.afterAggregate) {
            this.translateGroupBy();
            this.context.aggregations.addAll(this.aggregations);
            this.context.afterAggregate = true;
        } else {
            for (AggregateCall aggCall : this.aggCalls) {
                if (!aggCall.isDistinct()) continue;
                throw new IllegalStateException("Distinct count is only allowed in innermost sub-query.");
            }
        }
    }

    private ColumnRowType buildColumnRowType() {
        this.buildGroups();
        this.buildAggregations();
        ColumnRowType inputColumnRowType = ((OLAPRel)this.getInput()).getColumnRowType();
        ArrayList<TblColRef> columns = new ArrayList<TblColRef>(this.rowType.getFieldCount());
        columns.addAll(this.groups);
        for (int i = 0; i < this.aggregations.size(); ++i) {
            FunctionDesc aggFunc = this.aggregations.get(i);
            TblColRef aggCol = null;
            if (aggFunc.needRewriteField()) {
                aggCol = this.buildRewriteColumn(aggFunc);
            } else {
                AggregateCall aggCall = this.rewriteAggCalls.get(i);
                if (!aggCall.getArgList().isEmpty()) {
                    int index = (Integer)aggCall.getArgList().get(0);
                    aggCol = inputColumnRowType.getColumnByIndex(index);
                }
            }
            columns.add(aggCol);
        }
        return new ColumnRowType(columns);
    }

    private TblColRef buildRewriteColumn(FunctionDesc aggFunc) {
        if (!aggFunc.needRewriteField()) {
            throw new IllegalStateException("buildRewriteColumn on a aggrFunc that does not need rewrite " + aggFunc);
        }
        ColumnDesc column = new ColumnDesc();
        column.setName(aggFunc.getRewriteFieldName());
        TableDesc table = this.context.firstTableScan.getOlapTable().getSourceTable();
        column.setTable(table);
        TblColRef colRef = column.getRef();
        return colRef;
    }

    private void buildGroups() {
        ColumnRowType inputColumnRowType = ((OLAPRel)this.getInput()).getColumnRowType();
        this.groups = new ArrayList<TblColRef>();
        int i = this.getGroupSet().nextSetBit(0);
        while (i >= 0) {
            Set<TblColRef> columns = inputColumnRowType.getSourceColumnsByIndex(i);
            this.groups.addAll(columns);
            i = this.getGroupSet().nextSetBit(i + 1);
        }
    }

    private void buildAggregations() {
        ColumnRowType inputColumnRowType = ((OLAPRel)this.getInput()).getColumnRowType();
        this.aggregations = new ArrayList<FunctionDesc>();
        for (AggregateCall aggCall : this.rewriteAggCalls) {
            int index;
            TblColRef column;
            ParameterDesc parameter = null;
            if (!aggCall.getArgList().isEmpty() && !(column = inputColumnRowType.getColumnByIndex(index = ((Integer)aggCall.getArgList().get(0)).intValue())).isInnerColumn()) {
                parameter = new ParameterDesc();
                parameter.setValue(column.getName());
                parameter.setType("column");
                parameter.setColRefs(Arrays.asList(column));
            }
            FunctionDesc aggFunc = new FunctionDesc();
            String funcName = OLAPAggregateRel.getFuncName(aggCall);
            aggFunc.setExpression(funcName);
            aggFunc.setParameter(parameter);
            this.aggregations.add(aggFunc);
        }
    }

    private void translateGroupBy() {
        this.context.groupByColumns.addAll(this.groups);
    }

    @Override
    public void implementRewrite(OLAPRel.RewriteImplementor implementor) {
        if (!this.afterAggregate) {
            this.translateAggregation();
            this.buildRewriteFieldsAndMetricsColumns();
        }
        implementor.visitChild(this, this.getInput());
        if (!this.afterAggregate && OLAPRel.RewriteImplementor.needRewrite(this.context)) {
            this.rewriteAggCalls = new ArrayList<AggregateCall>(this.aggCalls.size());
            for (int i = 0; i < this.aggCalls.size(); ++i) {
                AggregateCall aggCall = (AggregateCall)this.aggCalls.get(i);
                FunctionDesc cubeFunc = this.context.aggregations.get(i);
                if (cubeFunc.needRewrite()) {
                    aggCall = this.rewriteAggregateCall(aggCall, cubeFunc);
                }
                this.rewriteAggCalls.add(aggCall);
            }
        }
        this.rowType = this.deriveRowType();
        this.columnRowType = this.buildColumnRowType();
    }

    private void translateAggregation() {
        List measures = this.context.realization.getMeasures();
        ArrayList newAggrs = Lists.newArrayList();
        for (FunctionDesc aggFunc : this.aggregations) {
            newAggrs.add(this.findInMeasures(aggFunc, measures));
        }
        this.aggregations.clear();
        this.aggregations.addAll(newAggrs);
        this.context.aggregations.clear();
        this.context.aggregations.addAll(newAggrs);
    }

    private FunctionDesc findInMeasures(FunctionDesc aggFunc, List<MeasureDesc> measures) {
        for (MeasureDesc m : measures) {
            if (!aggFunc.equals((Object)m.getFunction())) continue;
            return m.getFunction();
        }
        return aggFunc;
    }

    private void buildRewriteFieldsAndMetricsColumns() {
        this.fillbackOptimizedColumn();
        ColumnRowType inputColumnRowType = ((OLAPRel)this.getInput()).getColumnRowType();
        RelDataTypeFactory typeFactory = this.getCluster().getTypeFactory();
        for (int i = 0; i < this.aggregations.size(); ++i) {
            int index;
            AggregateCall aggCall;
            TblColRef column;
            FunctionDesc aggFunc = this.aggregations.get(i);
            if (aggFunc.isDimensionAsMetric()) {
                this.context.groupByColumns.addAll(aggFunc.getParameter().getColRefs());
                continue;
            }
            if (aggFunc.needRewriteField()) {
                String rewriteFieldName = aggFunc.getRewriteFieldName();
                RelDataType rewriteFieldType = OLAPTable.createSqlType(typeFactory, aggFunc.getRewriteFieldType(), true);
                this.context.rewriteFields.put(rewriteFieldName, rewriteFieldType);
                column = this.buildRewriteColumn(aggFunc);
                this.context.metricsColumns.add(column);
            }
            if ((aggCall = this.rewriteAggCalls.get(i)).getArgList().isEmpty() || (column = inputColumnRowType.getColumnByIndex(index = ((Integer)aggCall.getArgList().get(0)).intValue())).isInnerColumn()) continue;
            this.context.metricsColumns.add(column);
        }
    }

    private void fillbackOptimizedColumn() {
        RelDataType inputAggRow = this.getInput().getRowType();
        RelDataType outputAggRow = this.getRowType();
        if (inputAggRow.getFieldCount() != outputAggRow.getFieldCount()) {
            for (RelDataTypeField inputField : inputAggRow.getFieldList()) {
                String inputFieldName = inputField.getName();
                if (inputFieldName.startsWith("$") || outputAggRow.getField(inputFieldName, true, false) != null) continue;
                TblColRef column = this.columnRowType.getColumnByIndex(inputField.getIndex());
                this.context.metricsColumns.add(column);
            }
        }
    }

    private AggregateCall rewriteAggregateCall(AggregateCall aggCall, FunctionDesc func) {
        ArrayList newArgList = Lists.newArrayList((Iterable)aggCall.getArgList());
        if (func.needRewriteField()) {
            RelDataTypeField field = this.getInput().getRowType().getField(func.getRewriteFieldName(), true, false);
            if (newArgList.isEmpty()) {
                newArgList.add(field.getIndex());
            } else {
                newArgList.set(0, field.getIndex());
            }
        }
        RelDataType fieldType = aggCall.getType();
        SqlAggFunction newAgg = aggCall.getAggregation();
        if (func.isCount()) {
            newAgg = SqlStdOperatorTable.SUM0;
        } else if (func.getMeasureType().getRewriteCalciteAggrFunctionClass() != null) {
            newAgg = this.createCustomAggFunction(func.getExpression(), fieldType, func.getMeasureType().getRewriteCalciteAggrFunctionClass());
        }
        AggregateCall newAggCall = new AggregateCall(newAgg, false, (List)newArgList, fieldType, newAgg.getName());
        return newAggCall;
    }

    private SqlAggFunction createCustomAggFunction(String funcName, RelDataType returnType, Class<?> customAggFuncClz) {
        RelDataTypeFactory typeFactory = this.getCluster().getTypeFactory();
        SqlIdentifier sqlIdentifier = new SqlIdentifier(funcName, new SqlParserPos(1, 1));
        AggregateFunctionImpl aggFunction = AggregateFunctionImpl.create(customAggFuncClz);
        ArrayList<RelDataType> argTypes = new ArrayList<RelDataType>();
        ArrayList<Object> typeFamilies = new ArrayList<Object>();
        for (FunctionParameter o : aggFunction.getParameters()) {
            RelDataType type = o.getType(typeFactory);
            argTypes.add(type);
            typeFamilies.add(Util.first((Object)type.getSqlTypeName().getFamily(), (Object)SqlTypeFamily.ANY));
        }
        return new SqlUserDefinedAggFunction(sqlIdentifier, (SqlReturnTypeInference)ReturnTypes.explicit((RelDataType)returnType), InferTypes.explicit(argTypes), (SqlOperandTypeChecker)OperandTypes.family(typeFamilies), (AggregateFunction)aggFunction);
    }

    @Override
    public EnumerableRel implementEnumerable(List<EnumerableRel> inputs) {
        try {
            return new EnumerableAggregate(this.getCluster(), this.getCluster().traitSetOf((RelTrait)EnumerableConvention.INSTANCE), (RelNode)OLAPAggregateRel.sole(inputs), false, this.groupSet, (List)this.groupSets, this.rewriteAggCalls);
        }
        catch (InvalidRelException e) {
            throw new IllegalStateException("Can't create EnumerableAggregate!", e);
        }
    }

    @Override
    public OLAPContext getContext() {
        return this.context;
    }

    @Override
    public ColumnRowType getColumnRowType() {
        return this.columnRowType;
    }

    @Override
    public boolean hasSubQuery() {
        OLAPRel olapChild = (OLAPRel)this.getInput();
        return olapChild.hasSubQuery();
    }

    @Override
    public RelTraitSet replaceTraitSet(RelTrait trait) {
        RelTraitSet oldTraitSet = this.traitSet;
        this.traitSet = this.traitSet.replace(trait);
        return oldTraitSet;
    }

    static {
        AGGR_FUNC_MAP.put("SUM", "SUM");
        AGGR_FUNC_MAP.put("$SUM0", "SUM");
        AGGR_FUNC_MAP.put("COUNT", "COUNT");
        AGGR_FUNC_MAP.put("COUNT_DISTINCT", "COUNT_DISTINCT");
        AGGR_FUNC_MAP.put("MAX", "MAX");
        AGGR_FUNC_MAP.put("MIN", "MIN");
        for (String customAggrFunc : KylinConfig.getInstanceFromEnv().getCubeCustomMeasureTypes().keySet()) {
            AGGR_FUNC_MAP.put(customAggrFunc.trim().toUpperCase(), customAggrFunc.trim().toUpperCase());
        }
    }
}

