package org.apache.druid.sql.calcite.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.druid.sql.calcite.planner.Calcites;

/* loaded from: input_file:org/apache/druid/sql/calcite/rule/CaseFilteredAggregatorRule.class */
public class CaseFilteredAggregatorRule extends RelOptRule {
    private static final CaseFilteredAggregatorRule INSTANCE = new CaseFilteredAggregatorRule();

    private CaseFilteredAggregatorRule() {
        super(operand(Aggregate.class, operand(Project.class, any()), new RelOptRuleOperand[0]));
    }

    public static CaseFilteredAggregatorRule instance() {
        return INSTANCE;
    }

    public boolean matches(RelOptRuleCall relOptRuleCall) {
        Aggregate rel = relOptRuleCall.rel(0);
        Project rel2 = relOptRuleCall.rel(1);
        if (rel.indicator || rel.getGroupSets().size() != 1) {
            return false;
        }
        for (AggregateCall aggregateCall : rel.getAggCallList()) {
            if (isOneArgAggregateCall(aggregateCall) && isThreeArgCase((RexNode) rel2.getChildExps().get(((Integer) Iterables.getOnlyElement(aggregateCall.getArgList())).intValue()))) {
                return true;
            }
        }
        return false;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Aggregate rel = relOptRuleCall.rel(0);
        Project rel2 = relOptRuleCall.rel(1);
        RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        ArrayList arrayList = new ArrayList(rel.getAggCallList().size());
        ArrayList arrayList2 = new ArrayList(rel2.getChildExps());
        ArrayList arrayList3 = new ArrayList(rel.getGroupCount() + rel.getAggCallList().size());
        RelDataTypeFactory typeFactory = rel.getCluster().getTypeFactory();
        Iterator it = rel.getGroupSet().iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            arrayList3.add(rexBuilder.makeInputRef(((RexNode) rel2.getChildExps().get(intValue)).getType(), intValue));
        }
        for (AggregateCall aggregateCall : rel.getAggCallList()) {
            AggregateCall aggregateCall2 = null;
            if (isOneArgAggregateCall(aggregateCall)) {
                RexCall rexCall = (RexNode) rel2.getChildExps().get(((Integer) Iterables.getOnlyElement(aggregateCall.getArgList())).intValue());
                if (isThreeArgCase(rexCall)) {
                    RexCall rexCall2 = rexCall;
                    boolean z = RexLiteral.isNullLiteral((RexNode) rexCall2.getOperands().get(1)) && !RexLiteral.isNullLiteral((RexNode) rexCall2.getOperands().get(2));
                    RexNode rexNode = (RexNode) rexCall2.getOperands().get(z ? 2 : 1);
                    RexNode rexNode2 = (RexNode) rexCall2.getOperands().get(z ? 1 : 2);
                    RelDataType createSqlType = Calcites.createSqlType(typeFactory, SqlTypeName.BOOLEAN);
                    RexNode makeCall = rexBuilder.makeCall(createSqlType, z ? SqlStdOperatorTable.IS_FALSE : SqlStdOperatorTable.IS_TRUE, ImmutableList.of(rexCall2.getOperands().get(0)));
                    RexNode makeCall2 = aggregateCall.filterArg >= 0 ? rexBuilder.makeCall(createSqlType, SqlStdOperatorTable.AND, ImmutableList.of(rel2.getProjects().get(aggregateCall.filterArg), makeCall)) : makeCall;
                    if (aggregateCall.isDistinct()) {
                        if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(rexNode2)) {
                            arrayList2.add(rexNode);
                            arrayList2.add(makeCall2);
                            aggregateCall2 = AggregateCall.create(SqlStdOperatorTable.COUNT, true, ImmutableList.of(Integer.valueOf(arrayList2.size() - 2)), arrayList2.size() - 1, aggregateCall.getType(), aggregateCall.getName());
                        }
                    } else if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && rexNode.isA(SqlKind.LITERAL) && !RexLiteral.isNullLiteral(rexNode) && RexLiteral.isNullLiteral(rexNode2)) {
                        arrayList2.add(makeCall2);
                        aggregateCall2 = AggregateCall.create(SqlStdOperatorTable.COUNT, false, ImmutableList.of(), arrayList2.size() - 1, aggregateCall.getType(), aggregateCall.getName());
                    } else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM && Calcites.isIntLiteral(rexNode) && RexLiteral.intValue(rexNode) == 1 && Calcites.isIntLiteral(rexNode2) && RexLiteral.intValue(rexNode2) == 0) {
                        arrayList2.add(makeCall2);
                        aggregateCall2 = AggregateCall.create(SqlStdOperatorTable.COUNT, false, ImmutableList.of(), arrayList2.size() - 1, Calcites.createSqlType(typeFactory, SqlTypeName.BIGINT), aggregateCall.getName());
                    } else if (RexLiteral.isNullLiteral(rexNode2) || (aggregateCall.getAggregation().getKind() == SqlKind.SUM && Calcites.isIntLiteral(rexNode2) && RexLiteral.intValue(rexNode2) == 0)) {
                        arrayList2.add(rexNode);
                        arrayList2.add(makeCall2);
                        aggregateCall2 = AggregateCall.create(aggregateCall.getAggregation(), false, ImmutableList.of(Integer.valueOf(arrayList2.size() - 2)), arrayList2.size() - 1, aggregateCall.getType(), aggregateCall.getName());
                    }
                }
            }
            arrayList.add(aggregateCall2 == null ? aggregateCall : aggregateCall2);
            int size = arrayList3.size();
            RelDataType type = ((RelDataTypeField) rel.getRowType().getFieldList().get(size)).getType();
            if (aggregateCall2 == null) {
                arrayList3.add(rexBuilder.makeInputRef(type, size));
            } else {
                arrayList3.add(rexBuilder.makeCast(type, rexBuilder.makeInputRef(aggregateCall2.getType(), size)));
            }
        }
        if (arrayList.equals(rel.getAggCallList())) {
            return;
        }
        RelBuilder project = relOptRuleCall.builder().push(rel2.getInput()).project(arrayList2);
        relOptRuleCall.transformTo(project.aggregate(project.groupKey(rel.getGroupSet(), rel.getGroupSets()), arrayList).project(arrayList3).build());
        relOptRuleCall.getPlanner().setImportance(rel, 0.0d);
    }

    private static boolean isOneArgAggregateCall(AggregateCall aggregateCall) {
        return aggregateCall.getArgList().size() == 1;
    }

    private static boolean isThreeArgCase(RexNode rexNode) {
        return rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3;
    }
}
