/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule;
import java.util.ArrayList;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.AggregateFunction;
import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.ExplicitReturnTypeInference;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.Util;
import org.apache.kylin.measure.bitmap.BitmapCountAggFunc;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.apache.kylin.query.util.RuleUtils;

/**
 * Optimizer rule that matches a {@link KapAggregateRel} sitting on top of a {@link KapProjectRel}
 * (whose input carries no aggregate, per {@link AggExpressionUtil#hasAggInput}) and targets
 * {@code COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)} style aggregate calls.
 *
 * <p>The generic split into a bottom and a top aggregate is implemented by
 * {@link AbstractAggCaseWhenFunctionRule} (not shown here). This class supplies the hooks:
 * <ul>
 *   <li>the bottom aggregate function is {@code BITMAP_UUID};</li>
 *   <li>the top aggregate for a distinct count is {@code BITMAP_COUNT};</li>
 *   <li>the top aggregate for a plain {@code COUNT} is {@code SUM0}.</li>
 * </ul>
 */
public class CountDistinctCaseWhenFunctionRule
extends AbstractAggCaseWhenFunctionRule {
    public static final CountDistinctCaseWhenFunctionRule INSTANCE = new CountDistinctCaseWhenFunctionRule(
            CountDistinctCaseWhenFunctionRule.operand(KapAggregateRel.class,
                    (RelOptRuleOperand) CountDistinctCaseWhenFunctionRule.operand(KapProjectRel.class, null,
                            input -> !AggExpressionUtil.hasAggInput((RelNode) input),
                            (RelOptRuleOperandChildren) RelOptRule.any()),
                    new RelOptRuleOperand[0]),
            RelFactories.LOGICAL_BUILDER, "CountDistinctCaseWhenFunctionRule");

    public CountDistinctCaseWhenFunctionRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Checks whether {@code aggregateCall} is a single-argument {@code COUNT(DISTINCT x)} whose
     * argument, as projected by {@code inputProject}, is a two-branch CASE expression. One branch
     * must be a NULL literal; the other must be a plain table column, possibly wrapped in a
     * CAST that does not need to be kept.
     */
    private boolean isCountDistinctCaseExpr(AggregateCall aggregateCall, Project inputProject) {
        if (aggregateCall.getArgList().size() != 1) {
            return false;
        }
        if (aggregateCall.getAggregation().getKind() != SqlKind.COUNT || !aggregateCall.isDistinct()) {
            return false;
        }
        int input = aggregateCall.getArgList().get(0);
        RexNode expression = inputProject.getChildExps().get(input);
        if (expression.getKind() != SqlKind.CASE) {
            return false;
        }
        RexCall caseCall = (RexCall) expression;
        // A CASE RexCall with exactly three operands is [condition, thenValue, elseValue].
        if (caseCall.getOperands().size() != 3) {
            return false;
        }
        RexNode thenValue = caseCall.getOperands().get(1);
        RexNode elseValue = caseCall.getOperands().get(2);
        // The NULL literal may appear in either branch.
        return this.isSimpleCaseWhen(inputProject, thenValue, elseValue)
                || this.isSimpleCaseWhen(inputProject, elseValue, thenValue);
    }

    /**
     * Returns true when {@code n1} is a NULL literal and {@code n2} is one of the following,
     * as judged by {@link RuleUtils#isPlainTableColumn} against the project's input:
     * <ul>
     *   <li>an input reference to a plain table column; or</li>
     *   <li>a CAST of such a reference where {@link #isNeedTackCast} is false.</li>
     * </ul>
     */
    private boolean isSimpleCaseWhen(Project inputProject, RexNode n1, RexNode n2) {
        if (RuleUtils.isNullLiteral(n1)) {
            if (n2 instanceof RexInputRef) {
                return RuleUtils.isPlainTableColumn(((RexInputRef) n2).getIndex(), inputProject.getInput(0));
            }
            if (RuleUtils.isCast(n2) && ((RexCall) n2).getOperands().get(0) instanceof RexInputRef) {
                RexInputRef castOperand = (RexInputRef) ((RexCall) n2).getOperands().get(0);
                return RuleUtils.isPlainTableColumn(castOperand.getIndex(), inputProject.getInput(0))
                        && !this.isNeedTackCast(n2);
            }
        }
        return false;
    }

    /** Returns true if at least one aggregate call of {@code oldAgg} is a COUNT DISTINCT CASE expression. */
    @Override
    protected boolean checkAggCaseExpression(Aggregate oldAgg, Project oldProject) {
        for (AggregateCall call : oldAgg.getAggCallList()) {
            if (this.isCountDistinctCaseExpr(call, oldProject)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true for aggregate calls this rule can handle side by side with the sum-case
     * rewrite. These are SUM, SUM0, MAX, MIN, non-distinct COUNT, a COUNT DISTINCT CASE
     * expression, or a call named {@code BITMAP_UUID}.
     */
    @Override
    protected boolean isApplicableWithSumCaseRule(AggregateCall aggregateCall, Project project) {
        SqlKind aggFunction = aggregateCall.getAggregation().getKind();
        if (aggFunction == SqlKind.SUM || aggFunction == SqlKind.SUM0
                || aggFunction == SqlKind.MAX || aggFunction == SqlKind.MIN) {
            return true;
        }
        if (aggFunction == SqlKind.COUNT && !aggregateCall.isDistinct()) {
            return true;
        }
        // AggregateCall#getName() may be null for unnamed calls, so compare from the literal side
        // to avoid a NullPointerException.
        return this.isCountDistinctCaseExpr(aggregateCall, project)
                || "BITMAP_UUID".equalsIgnoreCase(aggregateCall.getName());
    }

    @Override
    protected boolean isApplicableAggExpression(AggExpressionUtil.AggExpression aggExpr) {
        return aggExpr.isCountDistinctCase();
    }

    /** The bottom aggregate is always {@code BITMAP_UUID}, regardless of {@code aggCall}. */
    @Override
    protected SqlAggFunction getBottomAggFunc(AggregateCall aggCall) {
        return CountDistinctCaseWhenFunctionRule.createBitmapAggFunc();
    }

    /**
     * Picks the top aggregate function:
     * <ul>
     *   <li>{@code BITMAP_COUNT} for COUNT DISTINCT;</li>
     *   <li>{@code SUM0} for plain COUNT;</li>
     *   <li>otherwise the call's own aggregation.</li>
     * </ul>
     */
    @Override
    protected SqlAggFunction getTopAggFunc(AggregateCall aggCall) {
        SqlAggFunction aggFunction = aggCall.getAggregation();
        if (aggFunction.getKind() == SqlKind.COUNT) {
            return aggCall.isDistinct()
                    ? CountDistinctCaseWhenFunctionRule.createBitmapCountAggFunc()
                    : SqlStdOperatorTable.SUM0;
        }
        return aggFunction;
    }

    private static SqlAggFunction createBitmapAggFunc() {
        return CountDistinctCaseWhenFunctionRule.createCustomAggFunction("BITMAP_UUID",
                new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.ANY), BitmapCountAggFunc.class, null);
    }

    private static SqlAggFunction createBitmapCountAggFunc() {
        return CountDistinctCaseWhenFunctionRule.createCustomAggFunction("BITMAP_COUNT",
                new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), BitmapCountAggFunc.class, null);
    }

    /**
     * Builds a {@link SqlUserDefinedAggFunction} named {@code funcName} backed by {@code customAggFuncClz}.
     *
     * @param funcName        SQL name of the function
     * @param returnType      explicit return type; when null, no return-type inference is supplied
     * @param customAggFuncClz class passed to {@link AggregateFunctionImpl#create}
     * @param typeFactory     used to derive parameter types. When null (as in both callers in this
     *                        class), no argument types or operand families are derived.
     */
    private static SqlAggFunction createCustomAggFunction(String funcName, RelDataType returnType, Class<?> customAggFuncClz, RelDataTypeFactory typeFactory) {
        SqlIdentifier sqlIdentifier = new SqlIdentifier(funcName, new SqlParserPos(1, 1));
        AggregateFunctionImpl aggFunction = AggregateFunctionImpl.create(customAggFuncClz);
        ArrayList<RelDataType> argTypes = new ArrayList<RelDataType>();
        ArrayList<SqlTypeFamily> typeFamilies = new ArrayList<SqlTypeFamily>();
        if (typeFactory != null) {
            for (FunctionParameter o : aggFunction.getParameters()) {
                RelDataType type = o.getType(typeFactory);
                argTypes.add(type);
                typeFamilies.add(Util.first(type.getSqlTypeName().getFamily(), SqlTypeFamily.ANY));
            }
        }
        ExplicitReturnTypeInference explicitReturnTypeInference = null;
        if (returnType != null) {
            explicitReturnTypeInference = ReturnTypes.explicit(returnType);
        }
        return new SqlUserDefinedAggFunction(sqlIdentifier, (SqlReturnTypeInference) explicitReturnTypeInference,
                InferTypes.explicit(argTypes), (SqlOperandTypeChecker) OperandTypes.family(typeFamilies),
                (AggregateFunction) aggFunction, false, false, typeFactory);
    }

    /**
     * Prefix for the bottom aggregate. Presumably the parent rule uses it to name the bottom
     * aggregate columns; confirm in AbstractAggCaseWhenFunctionRule.
     */
    @Override
    protected String getBottomAggPrefix() {
        return "COUNT_DISTINCT_CASE$";
    }

    /** Any expression other than a NULL literal is accepted as an aggregate column expression. */
    @Override
    protected boolean isValidAggColumnExpr(RexNode rexNode) {
        return !RuleUtils.isNullLiteral(rexNode);
    }

    /**
     * Returns true when {@code rexNode} is a CAST whose target type cannot be cast from its
     * operand's type, as judged by {@link SqlTypeUtil#canCastFrom} with {@code coerce = false}.
     * Returns false for any non-CAST node.
     */
    @Override
    protected boolean isNeedTackCast(RexNode rexNode) {
        if (!RuleUtils.isCast(rexNode)) {
            return false;
        }
        RelDataType targetType = rexNode.getType();
        RelDataType operandType = ((RexCall) rexNode).getOperands().get(0).getType();
        return !SqlTypeUtil.canCastFrom(targetType, operandType, false);
    }
}

