package dev.langchain4j.store.embedding.azure.cosmos.mongo.vcore;

import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCommandException;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.CreateCollectionOptions;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.bson.BsonArray;
import org.bson.BsonDocument;
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.codecs.configuration.CodecProvider;
import org.bson.codecs.configuration.CodecRegistries;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} backed by an Azure Cosmos DB for MongoDB vCore collection.
 *
 * <p>Each embedding is written as an {@code AzureCosmosDbMongoVCoreDocument}; the vector lives in the
 * {@code embedding} field, which is the field indexed (optionally created on startup) with a {@code cosmosSearch}
 * vector index of kind {@code vector-ivf} or {@code vector-hnsw}. Similarity search runs a
 * {@code $search}/{@code cosmosSearch} aggregation and treats the returned {@code searchScore} as a cosine
 * similarity, converting it with {@link RelevanceScore#fromCosineSimilarity(double)}.
 *
 * <p>Use {@link #builder()} to configure an instance.
 */
public class AzureCosmosDbMongoVCoreEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(AzureCosmosDbMongoVCoreEmbeddingStore.class);

    /** Target collection, configured with a codec registry that can (de)serialize the store's document POJO. */
    private final MongoCollection<AzureCosmosDbMongoVCoreDocument> collection;
    /** Name of the vector index on the {@code embedding} field. */
    private final String indexName;
    /** Vector index kind; selects which index options and search options are sent. */
    private final VectorIndexType kind;
    /** IVF only: {@code numLists} index option. */
    private final Integer numLists;
    /** Vector dimensionality declared on the index. */
    private final Integer dimensions;
    /** HNSW only: {@code m} index option. */
    private final Integer m;
    /** HNSW only: {@code efConstruction} index option. */
    private final Integer efConstruction;
    /** HNSW only: {@code efSearch} query option. */
    private final Integer efSearch;

    /**
     * Fluent builder for {@link AzureCosmosDbMongoVCoreEmbeddingStore}. Every setter returns this builder; unset
     * values are passed as {@code null} to the store constructor, which applies its defaults.
     */
    public static class Builder {
        private MongoClient mongoClient;
        private String connectionString;
        private String databaseName;
        private String collectionName;
        private String indexName;
        private String applicationName;
        private CreateCollectionOptions createCollectionOptions;
        private Boolean createIndex;
        private String kind;
        private Integer numLists;
        private Integer dimensions;
        private Integer m;
        private Integer efConstruction;
        private Integer efSearch;

        /** Client to use; takes precedence over {@link #connectionString(String)} when both are set. */
        public Builder mongoClient(MongoClient mongoClient) {
            this.mongoClient = mongoClient;
            return this;
        }

        /** Connection string used to create a client when no {@link #mongoClient(MongoClient)} is supplied. */
        public Builder connectionString(String connectionString) {
            this.connectionString = connectionString;
            return this;
        }

        /** Database holding the collection (required). */
        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /** Collection storing the embeddings (required); created if it does not exist. */
        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        /** Vector index name; defaults to {@code "defaultIndexAzureCosmos"}. */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /** Application name for an internally created client; defaults to {@code "LangChain4j"}. */
        public Builder applicationName(String applicationName) {
            this.applicationName = applicationName;
            return this;
        }

        /** Options used only when the collection has to be created. */
        public Builder createCollectionOptions(CreateCollectionOptions createCollectionOptions) {
            this.createCollectionOptions = createCollectionOptions;
            return this;
        }

        /** Whether to create the vector index if it is missing; defaults to {@code false}. */
        public Builder createIndex(Boolean createIndex) {
            this.createIndex = createIndex;
            return this;
        }

        /** Vector index kind: {@code "vector-ivf"} or {@code "vector-hnsw"} (required). */
        public Builder kind(String kind) {
            this.kind = kind;
            return this;
        }

        /** IVF {@code numLists}; defaults to {@code 1}. */
        public Builder numLists(Integer numLists) {
            this.numLists = numLists;
            return this;
        }

        /** Vector dimensions declared on the index; defaults to {@code 1536}. */
        public Builder dimensions(Integer dimensions) {
            this.dimensions = dimensions;
            return this;
        }

        /** HNSW {@code m}; defaults to {@code 16}. */
        public Builder m(Integer m) {
            this.m = m;
            return this;
        }

        /** HNSW {@code efConstruction}; defaults to {@code 64}. */
        public Builder efConstruction(Integer efConstruction) {
            this.efConstruction = efConstruction;
            return this;
        }

        /** HNSW {@code efSearch}; defaults to {@code 40}. */
        public Builder efSearch(Integer efSearch) {
            this.efSearch = efSearch;
            return this;
        }

        /**
         * Creates the store, connecting to the database and (optionally) creating the collection and index.
         *
         * @return a new store
         * @throws IllegalArgumentException if required settings are missing or {@code kind} is unsupported
         */
        public AzureCosmosDbMongoVCoreEmbeddingStore build() {
            return new AzureCosmosDbMongoVCoreEmbeddingStore(mongoClient, connectionString, databaseName,
                    collectionName, indexName, applicationName, createCollectionOptions, createIndex, kind,
                    numLists, dimensions, m, efConstruction, efSearch);
        }
    }

    /** Similarity metric sent in the index definition; only cosine is supported. */
    public enum SimilarityMetric {
        COS("COS");

        private final String value;

        SimilarityMetric(String value) {
            this.value = value;
        }

        /** Returns the wire value used in {@code cosmosSearchOptions.similarity}. */
        public String getValue() {
            return value;
        }

        /**
         * Looks up a metric by its wire value (case-sensitive).
         *
         * @throws IllegalArgumentException if no metric has that value
         */
        public static SimilarityMetric fromString(String value) {
            return Arrays.stream(values())
                    .filter(metric -> metric.getValue().equals(value))
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("This similarity metric is not supported: " + value));
        }
    }

    /** Supported Cosmos DB vector index kinds. */
    public enum VectorIndexType {
        VECTOR_IVF("vector-ivf"),
        VECTOR_HNSW("vector-hnsw");

        private final String value;

        VectorIndexType(String value) {
            this.value = value;
        }

        /** Returns the wire value used in {@code cosmosSearchOptions.kind}. */
        public String getValue() {
            return value;
        }

        /**
         * Looks up an index kind by its wire value (case-sensitive).
         *
         * @throws IllegalArgumentException if no kind has that value (including {@code null})
         */
        public static VectorIndexType fromString(String value) {
            return Arrays.stream(values())
                    .filter(type -> type.getValue().equals(value))
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("This vector index type is not supported: " + value));
        }
    }

    /**
     * Connects to the collection, creating it (and optionally the vector index) if needed.
     *
     * @param mongoClient             client to use; if {@code null}, one is created from {@code connectionString}
     * @param connectionString        used only when {@code mongoClient} is {@code null}
     * @param databaseName            database name (required, non-empty)
     * @param collectionName          collection name (required, non-empty)
     * @param indexName               vector index name; defaults to {@code "defaultIndexAzureCosmos"}
     * @param applicationName         application name for an internally created client; defaults to {@code "LangChain4j"}
     * @param createCollectionOptions options used when the collection is created; defaults to empty options
     * @param createIndex             create the vector index if it is missing; defaults to {@code false}
     * @param kind                    {@code "vector-ivf"} or {@code "vector-hnsw"} (required)
     * @param numLists                IVF {@code numLists}; defaults to {@code 1}
     * @param dimensions              index dimensions; defaults to {@code 1536}
     * @param m                       HNSW {@code m}; defaults to {@code 16}
     * @param efConstruction          HNSW {@code efConstruction}; defaults to {@code 64}
     * @param efSearch                HNSW {@code efSearch}; defaults to {@code 40}
     * @throws IllegalArgumentException if neither client nor connection string is given, if the database or
     *                                  collection name is missing, or if {@code kind} is unsupported
     */
    public AzureCosmosDbMongoVCoreEmbeddingStore(MongoClient mongoClient,
                                                 String connectionString,
                                                 String databaseName,
                                                 String collectionName,
                                                 String indexName,
                                                 String applicationName,
                                                 CreateCollectionOptions createCollectionOptions,
                                                 Boolean createIndex,
                                                 String kind,
                                                 Integer numLists,
                                                 Integer dimensions,
                                                 Integer m,
                                                 Integer efConstruction,
                                                 Integer efSearch) {
        if (mongoClient == null && (connectionString == null || connectionString.isEmpty())) {
            throw new IllegalArgumentException("You need to pass either the mongoClient or the connectionString required for connecting to Azure CosmosDB Mongo vCore");
        }
        if (databaseName == null || databaseName.isEmpty() || collectionName == null || collectionName.isEmpty()) {
            throw new IllegalArgumentException("databaseName and collectionName needs to be provided.");
        }

        boolean shouldCreateIndex = Utils.getOrDefault(createIndex, false);
        this.indexName = Utils.getOrDefault(indexName, "defaultIndexAzureCosmos");
        String effectiveApplicationName = Utils.getOrDefault(applicationName, "LangChain4j");
        // Validated before any connection is opened, so a bad kind never leaks a client.
        this.kind = VectorIndexType.fromString(kind);
        this.numLists = Utils.getOrDefault(numLists, 1);
        this.dimensions = Utils.getOrDefault(dimensions, 1536);
        this.m = Utils.getOrDefault(m, 16);
        this.efConstruction = Utils.getOrDefault(efConstruction, 64);
        this.efSearch = Utils.getOrDefault(efSearch, 40);

        // Default codecs plus a POJO codec so the collection can read/write the store's document class directly.
        CodecRegistry codecRegistry = CodecRegistries.fromRegistries(
                MongoClientSettings.getDefaultCodecRegistry(),
                CodecRegistries.fromProviders(PojoCodecProvider.builder()
                        .register(AzureCosmosDbMongoVCoreDocument.class, BsonDocument.class)
                        .build()));

        boolean ownsClient = mongoClient == null;
        MongoClient client = ownsClient
                ? MongoClients.create(MongoClientSettings.builder()
                        .applyConnectionString(new ConnectionString(connectionString))
                        .applicationName(effectiveApplicationName)
                        .build())
                : mongoClient;

        try {
            MongoDatabase database = client.getDatabase(databaseName);
            if (!isCollectionExist(database, collectionName)) {
                createCollection(database, collectionName,
                        Utils.getOrDefault(createCollectionOptions, new CreateCollectionOptions()));
            }
            this.collection = database.getCollection(collectionName, AzureCosmosDbMongoVCoreDocument.class)
                    .withCodecRegistry(codecRegistry);
            if (shouldCreateIndex && !isIndexExist(this.indexName)) {
                createIndex(this.indexName, collectionName, database);
            }
        } catch (RuntimeException e) {
            // A client we created ourselves would otherwise be unreachable (and leak) once construction fails.
            if (ownsClient) {
                closeOnFailure(client, e);
            }
            throw e;
        }
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Stores {@code embedding} under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores {@code embedding} under the given {@code id}, without a text segment. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Stores {@code embedding} with its {@code textSegment} under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores all {@code embeddings} under random UUIDs and returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = generateIds(embeddings.size());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores {@code embeddings} paired index-by-index with {@code textSegments} under random UUIDs.
     *
     * @return the generated ids, in input order
     * @throws IllegalArgumentException if the two lists differ in size (and both are non-empty)
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
        List<String> ids = generateIds(embeddings.size());
        addAllInternal(ids, embeddings, textSegments);
        return ids;
    }

    /**
     * Runs a vector search and returns matches whose relevance score is at least {@code minScore}.
     *
     * @param referenceEmbedding query vector
     * @param maxResults         passed to Cosmos as {@code k}
     * @param minScore           minimum relevance score, compared after converting the server's
     *                           {@code searchScore} with {@link RelevanceScore#fromCosineSimilarity(double)}
     * @return matching entries, in the order returned by the server
     * @throws RuntimeException wrapping a {@link MongoCommandException} raised by the search command
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        List<Bson> pipeline = buildSearchPipeline(referenceEmbedding, maxResults);

        // into() drains the (at most k-sized) result set and closes the server cursor before any mapping happens.
        List<BsonDocument> results;
        try {
            results = this.collection.aggregate(pipeline, BsonDocument.class).into(new ArrayList<BsonDocument>());
        } catch (MongoCommandException e) {
            throw new RuntimeException("Error in AzureCosmosDbMongoVCoreEmbeddingStore.findRelevant", e);
        }

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(results.size());
        for (BsonDocument result : results) {
            double similarity = result.getNumber("similarityScore").doubleValue();
            if (RelevanceScore.fromCosineSimilarity(similarity) >= minScore) {
                matches.add(MappingUtils.toEmbeddingMatch(toMatchedDocument(result.getDocument("document"), similarity)));
            }
        }
        return matches;
    }

    /**
     * Builds the {@code $search} + {@code $project} aggregation for the configured index kind. IVF requests the
     * stored source ({@code returnStoredSource}); HNSW adds {@code efSearch} to the {@code cosmosSearch} options.
     * The projection exposes the score as {@code similarityScore} and the whole document as {@code document}.
     */
    private List<Bson> buildSearchPipeline(Embedding referenceEmbedding, int maxResults) {
        Document cosmosSearch = new Document("vector", referenceEmbedding.vectorAsList())
                .append("path", "embedding")
                .append("k", maxResults);
        Document searchStage = new Document("cosmosSearch", cosmosSearch);
        switch (this.kind) {
            case VECTOR_IVF:
                searchStage.append("returnStoredSource", true);
                break;
            case VECTOR_HNSW:
                cosmosSearch.append("efSearch", this.efSearch);
                break;
            default:
                throw new IllegalStateException("Unsupported vector index type: " + this.kind);
        }

        List<Bson> pipeline = new ArrayList<>(2);
        pipeline.add(new Document("$search", searchStage));
        pipeline.add(new Document("$project", new Document("similarityScore", new Document("$meta", "searchScore"))
                .append("document", "$$ROOT")));
        return pipeline;
    }

    /**
     * Converts a stored document returned by the search into a matched-document POJO.
     *
     * @param document   the stored document ({@code _id}, {@code embedding}, optional {@code text}/{@code metadata})
     * @param similarity raw {@code searchScore}, converted to a relevance score for {@code setScore}
     */
    private AzureCosmosDbMongoVCoreMatchedDocument toMatchedDocument(BsonDocument document, double similarity) {
        AzureCosmosDbMongoVCoreMatchedDocument matched = new AzureCosmosDbMongoVCoreMatchedDocument();
        matched.setId(document.getString("_id").getValue());

        // asNumber() accepts any numeric BSON type, not only doubles.
        BsonArray storedVector = document.getArray("embedding");
        List<Float> vector = new ArrayList<>(storedVector.size());
        for (BsonValue component : storedVector) {
            vector.add((float) component.asNumber().doubleValue());
        }
        matched.setEmbedding(vector);

        // A null or non-string "text" is treated as absent rather than failing the whole search.
        BsonValue text = document.get("text");
        if (text != null && text.isString()) {
            matched.setText(text.asString().getValue());
        }

        BsonValue metadata = document.get("metadata");
        if (metadata != null && metadata.isDocument()) {
            BsonDocument metadataDocument = metadata.asDocument();
            HashMap<String, String> values = new HashMap<>();
            for (String key : metadataDocument.keySet()) {
                // NOTE(review): assumes every metadata value was stored as a BSON string (a non-string value
                // throws here) — confirm against MappingUtils.toMongoDbDocument.
                values.put(key, metadataDocument.getString(key).getValue());
            }
            matched.setMetadata(values);
        }

        matched.setScore(RelevanceScore.fromCosineSimilarity(similarity));
        return matched;
    }

    /** Single-item form of {@link #addAllInternal(List, List, List)}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Converts and inserts the given entries in one {@code insertMany}. Empty input is logged and ignored.
     *
     * @param textSegments may be {@code null}; otherwise must have the same size as {@code embeddings}
     * @throws IllegalArgumentException if list sizes differ
     * @throws RuntimeException         if the insert is not acknowledged
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to Azure CosmosDB  Mongo vCore");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(textSegments == null || embeddings.size() == textSegments.size(), "embeddings size is not equal to embedded size");

        List<AzureCosmosDbMongoVCoreDocument> documents = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); i++) {
            TextSegment segment = textSegments == null ? null : textSegments.get(i);
            documents.add(MappingUtils.toMongoDbDocument(ids.get(i), embeddings.get(i), segment));
        }
        if (!this.collection.insertMany(documents).wasAcknowledged()) {
            throw new RuntimeException(String.format("[AzureCosmosDbMongoVCoreEmbeddingStore] Add document failed, Document=%s", documents));
        }
    }

    /** Returns {@code count} fresh random ids in a mutable list. */
    private static List<String> generateIds(int count) {
        List<String> ids = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            ids.add(Utils.randomUUID());
        }
        return ids;
    }

    /** Returns whether {@code collectionName} exists in {@code database}; into() drains and closes the cursor. */
    private boolean isCollectionExist(MongoDatabase database, String collectionName) {
        return database.listCollectionNames().into(new ArrayList<String>()).contains(collectionName);
    }

    /** Creates {@code collectionName} with the given options. */
    private void createCollection(MongoDatabase database, String collectionName, CreateCollectionOptions options) {
        database.createCollection(collectionName, options);
    }

    /** Returns whether the collection already has an index named {@code indexName}. */
    private boolean isIndexExist(String indexName) {
        // into() drains the listing so the cursor is closed even though we stop at the first match.
        for (Document index : this.collection.listIndexes().into(new ArrayList<Document>())) {
            if (indexName.equals(index.getString("name"))) {
                return true;
            }
        }
        return false;
    }

    /** Runs the {@code createIndexes} command for the configured vector index. */
    private void createIndex(String indexName, String collectionName, MongoDatabase database) {
        database.runCommand(buildCreateIndexCommand(indexName, collectionName));
    }

    /**
     * Builds the {@code createIndexes} command for a {@code cosmosSearch} index on {@code embedding}. IVF sends
     * {@code numLists}; HNSW sends {@code m} and {@code efConstruction}; both send similarity and dimensions.
     */
    private BsonDocument buildCreateIndexCommand(String indexName, String collectionName) {
        Document searchOptions = new Document("kind", this.kind.getValue());
        switch (this.kind) {
            case VECTOR_IVF:
                searchOptions.append("numLists", this.numLists);
                break;
            case VECTOR_HNSW:
                searchOptions.append("m", this.m).append("efConstruction", this.efConstruction);
                break;
            default:
                throw new IllegalStateException("Unsupported vector index type: " + this.kind);
        }
        // Send the plain string value: encoding the enum constant itself requires an enum codec in the registry.
        searchOptions.append("similarity", SimilarityMetric.COS.getValue())
                .append("dimensions", this.dimensions);

        BsonDocument indexDefinition = new Document("name", indexName)
                .append("key", new Document("embedding", "cosmosSearch"))
                .append("cosmosSearchOptions", searchOptions)
                .toBsonDocument();
        BsonArray indexes = new BsonArray();
        indexes.add(indexDefinition);
        return new Document("createIndexes", collectionName).append("indexes", indexes).toBsonDocument();
    }

    /** Closes an internally created client after a construction failure, keeping the original exception primary. */
    private static void closeOnFailure(MongoClient client, RuntimeException failure) {
        try {
            client.close();
        } catch (RuntimeException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }
}
