package org.apache.drill.exec.store.mongo;

import com.mongodb.client.model.Accumulators;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.BsonField;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.drill.exec.store.mongo.common.MongoOp;
import org.bson.BsonArray;
import org.bson.BsonDocument;
import org.bson.BsonElement;
import org.bson.BsonInt32;
import org.bson.BsonNull;
import org.bson.BsonString;
import org.bson.Document;
import org.bson.conversions.Bson;

/* loaded from: input_file:org/apache/drill/exec/store/mongo/MongoAggregateUtils.class */
/**
 * Static helpers that turn a Calcite {@link Aggregate} into MongoDB aggregation pipeline stages
 * (a {@code $group} stage, optionally followed by a {@code $project} stage).
 */
public class MongoAggregateUtils {

    /**
     * Returns the field names of the given row type in the form used for MongoDB documents.
     *
     * <p>Any name that starts with {@code "$"} is replaced by {@code "_"} followed by the name with
     * its first two characters removed (so {@code "$f0"} becomes {@code "_0"}); other names are kept
     * as they are. The resulting list is then passed through
     * {@link SqlValidatorUtil#uniquify(List, boolean)} with the flag set to {@code true}.
     *
     * <p>Note: a name consisting of the single character {@code "$"} makes {@code substring(2)}
     * throw {@link StringIndexOutOfBoundsException}; that behavior is unchanged here.
     *
     * @param relDataType row type whose field names are converted
     * @return the converted, uniquified field names, in field order
     */
    public static List<String> mongoFieldNames(RelDataType relDataType) {
        List<String> renamed = relDataType.getFieldNames().stream()
                .map(name -> name.startsWith("$") ? "_" + name.substring(2) : name)
                .collect(Collectors.toList());
        return SqlValidatorUtil.uniquify(renamed, true);
    }

    /**
     * Wraps the string in single quotes when {@link #needsQuote(String)} says it is required.
     *
     * @param str the name to quote conditionally
     * @return {@code str} unchanged, or {@code quote(str)}
     */
    public static String maybeQuote(String str) {
        return needsQuote(str) ? quote(str) : str;
    }

    /**
     * Surrounds the string with single quotes. Embedded quotes are not escaped.
     *
     * @param str the text to quote
     * @return {@code "'" + str + "'"}
     */
    public static String quote(String str) {
        return "'" + str + "'";
    }

    /**
     * Tells whether the string contains a character that is not a Java identifier part, or a
     * {@code '$'}. The empty string needs no quoting.
     */
    private static boolean needsQuote(String str) {
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            if (c == '$' || !Character.isJavaIdentifierPart(c)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the pipeline stages for the given aggregate.
     *
     * <p>The first stage is always a {@code $group}. Its {@code _id} is the string
     * {@code "$" + field} when the group set has exactly one key, otherwise a document holding one
     * {@code name: "$name"} entry per key (an empty document when there are no keys). One
     * accumulator is added per aggregate call, named after the output column at position
     * {@code groupCount + callIndex}.
     *
     * <p>When the group set is not empty a {@code $project} stage follows that maps the grouped
     * {@code _id} back onto the output column names.
     *
     * @param aggregate   the aggregate relational expression to translate
     * @param relDataType row type of the aggregate's input; supplies the input field names
     * @return the stages as BSON documents, in pipeline order
     */
    public static List<Bson> getAggregateOperations(Aggregate aggregate, RelDataType relDataType) {
        List<String> inNames = mongoFieldNames(relDataType);
        List<String> outNames = mongoFieldNames(aggregate.getRowType());
        int groupCount = aggregate.getGroupSet().cardinality();

        // The $group _id is either a single field path or a document of field paths.
        Object groupId;
        if (groupCount == 1) {
            groupId = "$" + inNames.get(aggregate.getGroupSet().nth(0));
        } else {
            List<BsonElement> idFields = StreamSupport.stream(aggregate.getGroupSet().spliterator(), false)
                    .map(inNames::get)
                    .map(inName -> new BsonElement(inName, new BsonString("$" + inName)))
                    .collect(Collectors.toList());
            groupId = new BsonDocument(idFields);
        }

        // Aggregate outputs are named after the output columns that follow the group keys.
        List<BsonField> accumulators = new ArrayList<>();
        int outIndex = groupCount;
        for (AggregateCall call : aggregate.getAggCallList()) {
            accumulators.add(bsonAggregate(inNames, outNames.get(outIndex++), call));
        }

        List<Bson> stages = new ArrayList<>();
        stages.add(Aggregates.group(groupId, accumulators).toBsonDocument());

        List<BsonElement> projection = new ArrayList<>();
        if (groupCount == 1) {
            // Output column 0 is read from _id; presumably the single group key is always the
            // first output column (verify against the planner). Other columns map to themselves.
            for (int i = 0; i < outNames.size(); i++) {
                String outName = outNames.get(i);
                String source = i == 0 ? DrillMongoConstants.ID : outName;
                projection.add(new BsonElement(maybeQuote(outName), new BsonString("$" + source)));
            }
        } else {
            // Hide the composite _id and lift each of its members to a top-level column.
            projection.add(new BsonElement(DrillMongoConstants.ID, new BsonInt32(0)));
            for (int group : aggregate.getGroupSet()) {
                String outName = outNames.get(group);
                projection.add(new BsonElement(maybeQuote(outName), new BsonString("$_id." + outName)));
            }
            int callCount = aggregate.getAggCallList().size();
            for (int i = 0; i < callCount; i++) {
                String outName = outNames.get(groupCount + i);
                projection.add(new BsonElement(maybeQuote(outName), new BsonString("$" + outName)));
            }
        }

        // Without group keys only the $group stage is emitted.
        if (!aggregate.getGroupSet().isEmpty()) {
            stages.add(Aggregates.project(new BsonDocument(projection)).toBsonDocument());
        }
        return stages;
    }

    /**
     * Creates the {@code $group} accumulator for one aggregate call.
     *
     * @param inNames input field names, indexed by the call's argument ordinals
     * @param outName name of the output column the accumulator writes to
     * @param aggCall the aggregate call to translate
     * @return the accumulator, or {@code null} when the aggregation is neither COUNT nor one of
     *         the functions known to {@link #mongoAccumulator(String)}
     */
    private static BsonField bsonAggregate(List<String> inNames, String outName, AggregateCall aggCall) {
        String aggName = aggCall.getAggregation().getName();
        List<Integer> args = aggCall.getArgList();

        if (aggName.equals(SqlStdOperatorTable.COUNT.getName())) {
            Object expr;
            if (args.isEmpty()) {
                // COUNT with no argument: add 1 per document.
                expr = 1;
            } else {
                assert args.size() == 1;
                String inName = inNames.get(args.get(0));
                // NOTE(review): the first $eq operand is the quoted column name as a BSON string
                // (e.g. 'col'), not a "$col" field path - confirm this is the intended null test.
                BsonDocument isNull = new Document(MongoOp.EQUAL.getCompareOp(),
                        new BsonArray(Arrays.asList(new BsonString(quote(inName)), BsonNull.VALUE)))
                        .toBsonDocument();
                expr = new BsonDocument(MongoOp.COND.getCompareOp(),
                        new BsonArray(Arrays.asList(isNull, new BsonInt32(0), new BsonInt32(1))));
            }
            return Accumulators.sum(maybeQuote(outName), expr);
        }

        BiFunction<String, Object, BsonField> accumulator = mongoAccumulator(aggName);
        if (accumulator == null) {
            return null;
        }
        return accumulator.apply(maybeQuote(outName), "$" + inNames.get(args.get(0)));
    }

    /**
     * Looks up the {@link Accumulators} factory that corresponds to a SQL aggregate function name.
     *
     * @param str the aggregate function name
     * @param <T> expression type accepted by the accumulator factory
     * @return the factory for SUM, SUM0, MIN, MAX, AVG, FIRST, LAST, STDDEV, STDDEV_SAMP or
     *         STDDEV_POP; {@code null} for any other name (including COUNT)
     */
    private static <T> BiFunction<String, T, BsonField> mongoAccumulator(String str) {
        if (str.equals(SqlStdOperatorTable.SUM.getName()) || str.equals(SqlStdOperatorTable.SUM0.getName())) {
            return Accumulators::sum;
        }
        if (str.equals(SqlStdOperatorTable.MIN.getName())) {
            return Accumulators::min;
        }
        if (str.equals(SqlStdOperatorTable.MAX.getName())) {
            return Accumulators::max;
        }
        if (str.equals(SqlStdOperatorTable.AVG.getName())) {
            return Accumulators::avg;
        }
        if (str.equals(SqlStdOperatorTable.FIRST.getName())) {
            return Accumulators::first;
        }
        if (str.equals(SqlStdOperatorTable.LAST.getName())) {
            return Accumulators::last;
        }
        if (str.equals(SqlStdOperatorTable.STDDEV.getName()) || str.equals(SqlStdOperatorTable.STDDEV_SAMP.getName())) {
            return Accumulators::stdDevSamp;
        }
        if (str.equals(SqlStdOperatorTable.STDDEV_POP.getName())) {
            return Accumulators::stdDevPop;
        }
        return null;
    }

    /**
     * Tells whether the aggregate call can be translated by {@link #getAggregateOperations}.
     *
     * <p>Supported functions are COUNT plus every function for which
     * {@link #mongoAccumulator(String)} returns a factory; delegating keeps the two lists in sync.
     *
     * @param aggregateCall the aggregate call to check
     * @return {@code true} when the call's aggregation name is supported
     */
    public static boolean supportsAggregation(AggregateCall aggregateCall) {
        String name = aggregateCall.getAggregation().getName();
        return name.equals(SqlStdOperatorTable.COUNT.getName()) || mongoAccumulator(name) != null;
    }
}
