/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.azure.cosmos.nosql;

import com.azure.cosmos.CosmosClient;
import com.azure.cosmos.CosmosContainer;
import com.azure.cosmos.CosmosDatabase;
import com.azure.cosmos.models.CosmosBulkOperations;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosItemOperation;
import com.azure.cosmos.models.CosmosQueryRequestOptions;
import com.azure.cosmos.models.CosmosVectorEmbedding;
import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy;
import com.azure.cosmos.models.CosmosVectorIndexSpec;
import com.azure.cosmos.models.PartitionKey;
import com.azure.cosmos.util.CosmosPagedIterable;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.azure.cosmos.nosql.AzureCosmosDbNoSqlMatchedDocument;
import dev.langchain4j.store.embedding.azure.cosmos.nosql.MappingUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} of {@link TextSegment}s backed by an Azure Cosmos DB NoSQL container.
 *
 * <p>Embeddings are written through bulk create operations and read back with a
 * {@code VectorDistance} query over the first vector path declared in the supplied
 * {@link CosmosVectorEmbeddingPolicy}.
 *
 * <p>Metadata filters ({@code EmbeddingSearchRequest.filter()}) are not supported.
 */
public class AzureCosmosDbNoSqlEmbeddingStore
implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(AzureCosmosDbNoSqlEmbeddingStore.class);
    private final CosmosClient cosmosClient;
    private final String databaseName;
    private final String containerName;
    private final CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy;
    private final List<CosmosVectorIndexSpec> cosmosVectorIndexes;
    private final CosmosContainerProperties containerProperties;
    /** Name of the document property holding the vector: the first policy path without its leading character. */
    private final String embeddingKey;
    private final CosmosDatabase database;
    private final CosmosContainer container;

    /**
     * Validates the arguments, then ensures the database and container exist and resolves
     * handles to them.
     *
     * <p>The given {@code containerProperties} instance is modified: the vector embedding
     * policy and the vector indexes are set on it before the container is created.
     *
     * @param cosmosClient                client used for all Cosmos DB calls; must not be null
     * @param databaseName                database to create if missing; must not be blank
     * @param containerName               container to look up after creation; must not be blank
     * @param cosmosVectorEmbeddingPolicy policy with at least one vector embedding; the first
     *                                    embedding's path determines the embedding property name
     * @param cosmosVectorIndexes         vector index specs; must not be null or empty
     * @param containerProperties         properties used to create the container; must not be null
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public AzureCosmosDbNoSqlEmbeddingStore(CosmosClient cosmosClient, String databaseName, String containerName, CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy, List<CosmosVectorIndexSpec> cosmosVectorIndexes, CosmosContainerProperties containerProperties) {
        // Validate everything up front so no remote call is made with a half-valid configuration.
        if (cosmosClient == null) {
            throw new IllegalArgumentException("cosmosClient cannot be null or empty for Azure CosmosDB NoSql Embedding Store.");
        }
        if (Utils.isNullOrBlank(databaseName) || Utils.isNullOrBlank(containerName)) {
            throw new IllegalArgumentException("databaseName and containerName needs to be provided.");
        }
        if (cosmosVectorEmbeddingPolicy == null || cosmosVectorEmbeddingPolicy.getVectorEmbeddings() == null || cosmosVectorEmbeddingPolicy.getVectorEmbeddings().isEmpty()) {
            throw new IllegalArgumentException("cosmosVectorEmbeddingPolicy cannot be null or empty for Azure CosmosDB NoSql Embedding Store.");
        }
        if (cosmosVectorIndexes == null || cosmosVectorIndexes.isEmpty()) {
            throw new IllegalArgumentException("cosmosVectorIndexes cannot be null or empty for Azure CosmosDB NoSql Embedding Store.");
        }
        // Previously a null here surfaced as a bare NullPointerException further down.
        if (containerProperties == null) {
            throw new IllegalArgumentException("containerProperties cannot be null for Azure CosmosDB NoSql Embedding Store.");
        }

        this.cosmosClient = cosmosClient;
        this.databaseName = databaseName;
        this.containerName = containerName;
        this.cosmosVectorEmbeddingPolicy = cosmosVectorEmbeddingPolicy;
        this.cosmosVectorIndexes = cosmosVectorIndexes;
        this.containerProperties = containerProperties;

        this.cosmosClient.createDatabaseIfNotExists(this.databaseName);
        this.database = this.cosmosClient.getDatabase(this.databaseName);

        this.containerProperties.setVectorEmbeddingPolicy(this.cosmosVectorEmbeddingPolicy);
        this.containerProperties.getIndexingPolicy().setVectorIndexes(this.cosmosVectorIndexes);
        this.database.createContainerIfNotExists(this.containerProperties);
        // NOTE(review): the container is created from containerProperties but looked up by
        // containerName; presumably both are expected to carry the same id - confirm with callers.
        this.container = this.database.getContainer(this.containerName);

        // Drop the first character of the path (presumably the leading '/') to get the property name.
        this.embeddingKey = this.cosmosVectorEmbeddingPolicy.getVectorEmbeddings().get(0).getPath().substring(1);
    }

    /**
     * Creates a new, empty builder for this store.
     *
     * @return a fresh builder
     */
    public static AzureCosmosDbNoSqlEmbeddingStoreBuilder builder() {
        return new AzureCosmosDbNoSqlEmbeddingStoreBuilder();
    }

    /**
     * Stores an embedding under a newly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    /**
     * Stores an embedding under the given id, without an associated text segment.
     *
     * @param id        id of the document to create
     * @param embedding the embedding to store
     */
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a newly generated id.
     *
     * @param embedding   the embedding to store
     * @param textSegment the segment the embedding was computed from
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores several embeddings, generating one id per embedding.
     *
     * @param embeddings the embeddings to store; a null or empty list stores nothing
     * @return the generated ids, in the same order as {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = new ArrayList<>();
        // Tolerate null the same way the three-argument overload does instead of throwing.
        if (embeddings != null) {
            for (int i = 0; i < embeddings.size(); ++i) {
                ids.add(Utils.randomUUID());
            }
        }
        this.addAll(ids, embeddings, null);
        return ids;
    }

    /**
     * Runs a vector search for the request's query embedding.
     *
     * @param request the search request; its filter must be null
     * @return the matches found, wrapped in an {@link EmbeddingSearchResult}
     * @throws UnsupportedOperationException if the request carries a filter
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        if (request.filter() != null) {
            throw new UnsupportedOperationException("EmbeddingSearchRequest.Filter is not supported yet.");
        }
        List<EmbeddingMatch<TextSegment>> matches = this.findRelevant(request.queryEmbedding(), request.maxResults(), request.minScore());
        return new EmbeddingSearchResult<>(matches);
    }

    /**
     * Queries the container for the documents closest to {@code referenceEmbedding}.
     *
     * <p>The query selects the top {@code maxResults} documents ordered by
     * {@code VectorDistance}; matches whose score is below {@code minScore} are then dropped.
     *
     * @param referenceEmbedding the query vector
     * @param maxResults         upper bound on the number of documents requested from Cosmos DB
     * @param minScore           minimum score (inclusive) a match must have to be returned
     * @return a list of matches, possibly empty, never null
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        String vectorLiteral = referenceEmbedding.vectorAsList().stream().map(Object::toString).collect(Collectors.joining(","));
        String vectorDistance = "VectorDistance(c." + this.embeddingKey + ",[" + vectorLiteral + "])";
        // Plain concatenation instead of String.format("%d", ...): the latter formats the number
        // with the default locale, which can yield non-ASCII digits in the query text.
        String query = "SELECT TOP " + maxResults + " c.id, c." + this.embeddingKey + ", c.text, c.metadata, "
                + vectorDistance + " AS score FROM c ORDER By " + vectorDistance;

        CosmosPagedIterable<AzureCosmosDbNoSqlMatchedDocument> results = this.container.queryItems(query, new CosmosQueryRequestOptions(), AzureCosmosDbNoSqlMatchedDocument.class);

        // Consume the iterable exactly once; the earlier findAny() probe followed by a second
        // stream() traversed the results twice (presumably re-running the query - confirm).
        // NOTE(review): the score compared here is whatever MappingUtils.toEmbeddingMatch exposes
        // (presumably the raw VectorDistance value). For distance functions where a smaller value
        // means "closer", this threshold may need a different interpretation - confirm.
        return results.stream()
                .map(MappingUtils::toEmbeddingMatch)
                .filter(match -> match.score() >= minScore)
                .collect(Collectors.toList());
    }

    /** Wraps a single embedding (and optional segment) into singleton lists and delegates to the bulk path. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        List<TextSegment> segments = embedded == null ? null : Collections.singletonList(embedded);
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), segments);
    }

    /**
     * Stores embeddings under the given ids through one bulk operation.
     *
     * <p>Each document is created with its id as the partition key value. If {@code ids} or
     * {@code embeddings} is null or empty, nothing is written and an info message is logged.
     *
     * @param ids        document ids; must have the same size as {@code embeddings}
     * @param embeddings embeddings to store
     * @param embedded   text segments matching {@code embeddings} by index, or null for none
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to Azure CosmosDB NoSQL");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        List<CosmosItemOperation> operations = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); ++i) {
            String id = ids.get(i);
            TextSegment segment = embedded == null ? null : embedded.get(i);
            operations.add(CosmosBulkOperations.getCreateItemOperation(MappingUtils.toNoSqlDbDocument(id, embeddings.get(i), segment), new PartitionKey(id)));
        }
        // NOTE(review): whatever executeBulkOperations returns is discarded, so per-operation
        // outcomes are not inspected here - consider checking them for failures.
        this.container.executeBulkOperations(operations);
    }

    /**
     * Fluent builder for {@link AzureCosmosDbNoSqlEmbeddingStore}. Each setter stores its
     * argument and returns this builder; {@link #build()} passes the values to the store's
     * constructor, which performs the validation.
     */
    public static class AzureCosmosDbNoSqlEmbeddingStoreBuilder {
        private CosmosClient cosmosClient;
        private String databaseName;
        private String containerName;
        private CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy;
        private List<CosmosVectorIndexSpec> cosmosVectorIndexes;
        private CosmosContainerProperties containerProperties;

        AzureCosmosDbNoSqlEmbeddingStoreBuilder() {
        }

        /** Sets the Cosmos client used by the store. */
        public AzureCosmosDbNoSqlEmbeddingStoreBuilder cosmosClient(CosmosClient cosmosClient) {
            this.cosmosClient = cosmosClient;
            return this;
        }

        /** Sets the name of the database the store works in. */
        public AzureCosmosDbNoSqlEmbeddingStoreBuilder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /** Sets the name of the container the store reads from and writes to. */
        public AzureCosmosDbNoSqlEmbeddingStoreBuilder containerName(String containerName) {
            this.containerName = containerName;
            return this;
        }

        /** Sets the vector embedding policy applied to the container properties. */
        public AzureCosmosDbNoSqlEmbeddingStoreBuilder cosmosVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy) {
            this.cosmosVectorEmbeddingPolicy = cosmosVectorEmbeddingPolicy;
            return this;
        }

        /** Sets the vector index specs applied to the container's indexing policy. */
        public AzureCosmosDbNoSqlEmbeddingStoreBuilder cosmosVectorIndexes(List<CosmosVectorIndexSpec> cosmosVectorIndexes) {
            this.cosmosVectorIndexes = cosmosVectorIndexes;
            return this;
        }

        /** Sets the container properties used when creating the container. */
        public AzureCosmosDbNoSqlEmbeddingStoreBuilder containerProperties(CosmosContainerProperties containerProperties) {
            this.containerProperties = containerProperties;
            return this;
        }

        /**
         * Creates the store from the values collected so far.
         *
         * @return a new store instance
         * @throws IllegalArgumentException if the store's constructor rejects a value
         */
        public AzureCosmosDbNoSqlEmbeddingStore build() {
            return new AzureCosmosDbNoSqlEmbeddingStore(this.cosmosClient, this.databaseName, this.containerName, this.cosmosVectorEmbeddingPolicy, this.cosmosVectorIndexes, this.containerProperties);
        }

        /** Returns a textual dump of the builder's current field values. */
        public String toString() {
            return "AzureCosmosDbNoSqlEmbeddingStore.AzureCosmosDbNoSqlEmbeddingStoreBuilder(cosmosClient=" + this.cosmosClient
                    + ", databaseName=" + this.databaseName
                    + ", containerName=" + this.containerName
                    + ", cosmosVectorEmbeddingPolicy=" + this.cosmosVectorEmbeddingPolicy
                    + ", cosmosVectorIndexes=" + this.cosmosVectorIndexes
                    + ", containerProperties=" + this.containerProperties + ")";
        }
    }
}

