package org.apache.druid.sql.calcite.rule;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.java.util.common.ISE;

/* loaded from: input_file:org/apache/druid/sql/calcite/rule/ProjectAggregatePruneUnusedCallRule.class */
public class ProjectAggregatePruneUnusedCallRule extends RelOptRule {
    private static final ProjectAggregatePruneUnusedCallRule INSTANCE = new ProjectAggregatePruneUnusedCallRule();

    private ProjectAggregatePruneUnusedCallRule() {
        super(operand(Project.class, operand(Aggregate.class, any()), new RelOptRuleOperand[0]));
    }

    public static ProjectAggregatePruneUnusedCallRule instance() {
        return INSTANCE;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Project rel = relOptRuleCall.rel(0);
        Aggregate rel2 = relOptRuleCall.rel(1);
        ImmutableBitSet bits = RelOptUtil.InputFinder.bits(rel.getChildExps(), (RexNode) null);
        int groupCount = rel2.getGroupCount() + rel2.getAggCallList().size();
        if (groupCount != rel2.getRowType().getFieldCount()) {
            throw new ISE("WTF, expected[%s] to have[%s] fields but it had[%s]", new Object[]{rel2, Integer.valueOf(groupCount), Integer.valueOf(rel2.getRowType().getFieldCount())});
        }
        ImmutableBitSet intersect = bits.intersect(ImmutableBitSet.range(rel2.getGroupCount(), groupCount));
        if (intersect.cardinality() < rel2.getAggCallList().size()) {
            ArrayList arrayList = new ArrayList();
            Iterator it = intersect.iterator();
            while (it.hasNext()) {
                arrayList.add(rel2.getAggCallList().get(((Integer) it.next()).intValue() - rel2.getGroupCount()));
            }
            Aggregate copy = rel2.copy(rel2.getTraitSet(), rel2.getInput(), rel2.getGroupSet(), rel2.getGroupSets(), arrayList);
            ArrayList arrayList2 = new ArrayList();
            RexBuilder rexBuilder = rel2.getCluster().getRexBuilder();
            for (int i = 0; i < rel2.getGroupCount(); i++) {
                arrayList2.add(rexBuilder.makeInputRef(copy, i));
            }
            int groupCount2 = rel2.getGroupCount();
            for (int groupCount3 = rel2.getGroupCount(); groupCount3 < groupCount; groupCount3++) {
                if (intersect.get(groupCount3)) {
                    int i2 = groupCount2;
                    groupCount2++;
                    arrayList2.add(rexBuilder.makeInputRef(copy, i2));
                } else {
                    arrayList2.add(rexBuilder.makeNullLiteral(((RelDataTypeField) rel2.getRowType().getFieldList().get(groupCount3)).getType()));
                }
            }
            relOptRuleCall.transformTo(relOptRuleCall.builder().push(copy).project(arrayList2).project(rel.getChildExps()).build());
            relOptRuleCall.getPlanner().setImportance(rel, 0.0d);
        }
    }
}
