package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.schema.FunctionParameter;
import org.apache.kylin.job.shaded.org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlAggFunction;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlIdentifier;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.BasicSqlType;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.InferTypes;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.OperandTypes;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.ReturnTypes;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.Util;
import org.apache.kylin.measure.bitmap.BitmapCountAggFunc;
import org.apache.kylin.metadata.model.FunctionDesc;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.apache.kylin.query.util.RuleUtils;

/* loaded from: input_file:io/kyligence/kap/query/optrule/CountDistinctCaseWhenFunctionRule.class */
/**
 * Planner rule for {@code COUNT(DISTINCT CASE WHEN cond THEN x ELSE NULL END)} (and the mirrored
 * {@code THEN NULL ELSE x} form) where the CASE is computed by a {@link KapProjectRel} directly beneath a
 * {@link KapAggregateRel}.
 *
 * <p>The generic rewrite is driven by {@link AbstractAggCaseWhenFunctionRule}. This subclass supplies the
 * count-distinct specifics:
 * <ul>
 *   <li>which aggregate calls and CASE shapes qualify;</li>
 *   <li>the bottom aggregate function, {@code FunctionDesc.FUNC_BITMAP_UUID};</li>
 *   <li>the top aggregate function: {@code FunctionDesc.FUNC_BITMAP_COUNT} for {@code COUNT(DISTINCT)},
 *       or {@code SUM0} for a plain {@code COUNT}.</li>
 * </ul>
 */
public class CountDistinctCaseWhenFunctionRule extends AbstractAggCaseWhenFunctionRule {

    /**
     * Shared rule instance. It matches a {@link KapAggregateRel} over a {@link KapProjectRel}, and rejects
     * projects for which {@code AggExpressionUtil.hasAggInput} returns true.
     */
    public static final CountDistinctCaseWhenFunctionRule INSTANCE = new CountDistinctCaseWhenFunctionRule(
            operand(KapAggregateRel.class,
                    operand(KapProjectRel.class, null,
                            kapProjectRel -> !AggExpressionUtil.hasAggInput(kapProjectRel), RelOptRule.any()),
                    new RelOptRuleOperand[0]),
            RelFactories.LOGICAL_BUILDER, "CountDistinctCaseWhenFunctionRule");

    /**
     * @param relOptRuleOperand root operand the rule matches
     * @param relBuilderFactory factory for the builder used when the rewrite is produced
     * @param str               rule description
     */
    public CountDistinctCaseWhenFunctionRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory,
            String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Returns whether {@code aggregateCall} is a single-argument {@code COUNT(DISTINCT ...)} over a
     * {@code CASE} with exactly one {@code WHEN}, where one result branch is a {@code NULL} literal and the
     * other branch is accepted by {@link #isSimpleCaseWhen}.
     */
    private boolean isCountDistinctCaseExpr(AggregateCall aggregateCall, Project project) {
        if (aggregateCall.getArgList().size() != 1 || aggregateCall.getAggregation().getKind() != SqlKind.COUNT
                || !aggregateCall.isDistinct()) {
            return false;
        }
        int argIndex = aggregateCall.getArgList().get(0);
        RexNode argExpr = project.getProjects().get(argIndex);
        if (argExpr.getKind() != SqlKind.CASE) {
            return false;
        }
        List<RexNode> operands = ((RexCall) argExpr).getOperands();
        // A CASE with exactly three operands is [condition, thenValue, elseValue].
        if (operands.size() != 3) {
            return false;
        }
        RexNode thenValue = operands.get(1);
        RexNode elseValue = operands.get(2);
        return isSimpleCaseWhen(project, thenValue, elseValue) || isSimpleCaseWhen(project, elseValue, thenValue);
    }

    /**
     * Returns whether {@code nullBranch} is a NULL literal and {@code valueBranch} is one of:
     * <ul>
     *   <li>a direct reference to a plain table column of the project's input, or</li>
     *   <li>a CAST of such a reference for which {@link #isNeedTackCast} is false.</li>
     * </ul>
     */
    private boolean isSimpleCaseWhen(Project project, RexNode nullBranch, RexNode valueBranch) {
        if (!RuleUtils.isNullLiteral(nullBranch)) {
            return false;
        }
        if (valueBranch instanceof RexInputRef) {
            return RuleUtils.isPlainTableColumn(((RexInputRef) valueBranch).getIndex(), project.getInput(0));
        }
        return isCastOfPlainColumn(project, valueBranch) && !isNeedTackCast(valueBranch);
    }

    /** Returns whether {@code rexNode} is a CAST whose operand references a plain table column of the input. */
    private boolean isCastOfPlainColumn(Project project, RexNode rexNode) {
        if (!RuleUtils.isCast(rexNode)) {
            return false;
        }
        RexNode castOperand = ((RexCall) rexNode).getOperands().get(0);
        return castOperand instanceof RexInputRef
                && RuleUtils.isPlainTableColumn(((RexInputRef) castOperand).getIndex(), project.getInput(0));
    }

    /** Returns true if at least one aggregate call of {@code aggregate} is a count-distinct-over-CASE. */
    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected boolean checkAggCaseExpression(Aggregate aggregate, Project project) {
        return aggregate.getAggCallList().stream().anyMatch(call -> isCountDistinctCaseExpr(call, project));
    }

    /**
     * Returns whether {@code aggregateCall} can appear alongside this rewrite. Accepted calls are:
     * <ul>
     *   <li>SUM, SUM0, MAX, MIN;</li>
     *   <li>non-distinct COUNT;</li>
     *   <li>a qualifying count-distinct-over-CASE;</li>
     *   <li>a call whose name equals {@code FunctionDesc.FUNC_BITMAP_UUID}, ignoring case.</li>
     * </ul>
     */
    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected boolean isApplicableWithSumCaseRule(AggregateCall aggregateCall, Project project) {
        SqlKind kind = aggregateCall.getAggregation().getKind();
        return kind == SqlKind.SUM || kind == SqlKind.SUM0 || kind == SqlKind.MAX || kind == SqlKind.MIN
                || (kind == SqlKind.COUNT && !aggregateCall.isDistinct())
                || isCountDistinctCaseExpr(aggregateCall, project)
                // AggregateCall names may be null; compare from the constant side to avoid an NPE.
                || FunctionDesc.FUNC_BITMAP_UUID.equalsIgnoreCase(aggregateCall.getName());
    }

    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected boolean isApplicableAggExpression(AggExpressionUtil.AggExpression aggExpression) {
        return aggExpression.isCountDistinctCase();
    }

    /** The bottom aggregate is always the custom {@code FUNC_BITMAP_UUID} function, whatever the call. */
    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected SqlAggFunction getBottomAggFunc(AggregateCall aggregateCall) {
        return createBitmapAggFunc();
    }

    /**
     * Chooses the top aggregate function:
     * <ul>
     *   <li>{@code COUNT(DISTINCT)} becomes {@code FUNC_BITMAP_COUNT};</li>
     *   <li>plain {@code COUNT} becomes {@code SUM0};</li>
     *   <li>any other aggregation is returned unchanged.</li>
     * </ul>
     */
    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected SqlAggFunction getTopAggFunc(AggregateCall aggregateCall) {
        SqlAggFunction aggregation = aggregateCall.getAggregation();
        if (aggregation.getKind() != SqlKind.COUNT) {
            return aggregation;
        }
        return aggregateCall.isDistinct() ? createBitmapCountAggFunc() : SqlStdOperatorTable.SUM0;
    }

    /** Builds the {@code FUNC_BITMAP_UUID} aggregate, which returns type ANY. */
    private static SqlAggFunction createBitmapAggFunc() {
        return createCustomAggFunction(FunctionDesc.FUNC_BITMAP_UUID,
                new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.ANY), BitmapCountAggFunc.class, null);
    }

    /** Builds the {@code FUNC_BITMAP_COUNT} aggregate, which returns type BIGINT. */
    private static SqlAggFunction createBitmapCountAggFunc() {
        return createCustomAggFunction(FunctionDesc.FUNC_BITMAP_COUNT,
                new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), BitmapCountAggFunc.class, null);
    }

    /**
     * Wraps {@code cls} as a user-defined aggregate function named {@code str}.
     *
     * @param str                function name
     * @param relDataType        explicit return type, or {@code null} for no return-type inference
     * @param cls                implementation class passed to {@link AggregateFunctionImpl#create}
     * @param relDataTypeFactory used to resolve parameter types; when {@code null}, the operand type lists
     *                           are left empty
     * @throws IllegalArgumentException if {@code cls} cannot be turned into an aggregate implementation
     */
    private static SqlAggFunction createCustomAggFunction(String str, RelDataType relDataType, Class<?> cls,
            RelDataTypeFactory relDataTypeFactory) {
        SqlIdentifier identifier = new SqlIdentifier(str, new SqlParserPos(1, 1));
        AggregateFunctionImpl function = AggregateFunctionImpl.create(cls);
        if (function == null) {
            throw new IllegalArgumentException("Not a valid aggregate implementation class: " + cls);
        }
        List<RelDataType> paramTypes = new ArrayList<>();
        List<SqlTypeFamily> paramFamilies = new ArrayList<>();
        if (relDataTypeFactory != null) {
            for (FunctionParameter parameter : function.getParameters()) {
                RelDataType type = parameter.getType(relDataTypeFactory);
                paramTypes.add(type);
                paramFamilies.add(Util.first(type.getSqlTypeName().getFamily(), SqlTypeFamily.ANY));
            }
        }
        return new SqlUserDefinedAggFunction(identifier,
                relDataType != null ? ReturnTypes.explicit(relDataType) : null, InferTypes.explicit(paramTypes),
                OperandTypes.family(paramFamilies), function, false, false, relDataTypeFactory);
    }

    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected String getBottomAggPrefix() {
        return "COUNT_DISTINCT_CASE$";
    }

    /** An aggregate column expression is valid unless it is a NULL literal. */
    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected boolean isValidAggColumnExpr(RexNode rexNode) {
        return !RuleUtils.isNullLiteral(rexNode);
    }

    /**
     * Returns true when {@code rexNode} is a CAST and {@code SqlTypeUtil.canCastFrom} (with
     * {@code coerce = false}) rejects casting from the operand's type to the CAST's target type.
     */
    @Override // io.kyligence.kap.query.optrule.AbstractAggCaseWhenFunctionRule
    protected boolean isNeedTackCast(RexNode rexNode) {
        return RuleUtils.isCast(rexNode)
                && !SqlTypeUtil.canCastFrom(rexNode.getType(), ((RexCall) rexNode).getOperands().get(0).getType(), false);
    }
}
