/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.store.embedding.duckdb;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.community.store.embedding.duckdb.DuckDBMetadataFilterMapper;
import dev.langchain4j.community.store.embedding.duckdb.DuckDBSQLException;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.sql.Array;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.duckdb.DuckDBConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} implementation that persists {@link TextSegment} embeddings in a DuckDB table.
 *
 * <p>The table ({@code id UUID, embedding FLOAT[], text TEXT NULL, metadata JSON NULL}) is created on
 * construction if it does not exist. Every operation runs on a connection obtained through
 * {@link DuckDBConnection#duplicate()} and closes it afterwards.
 *
 * <p>NOTE(review): the table name is interpolated directly into SQL, so it must come from trusted
 * configuration, never from end-user input.
 */
public class DuckDBEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(DuckDBEmbeddingStore.class);
    private static final String CREATE_TABLE_TEMPLATE = "create table if not exists %s (id UUID, embedding FLOAT[], text TEXT NULL, metadata JSON NULL);\n";
    /** Score is cosine similarity rescaled from [-1, 1] to [0, 1]. */
    private static final String SEARCH_QUERY_TEMPLATE = "select id, embedding, text, metadata, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score\nfrom %s\nwhere score >= %s %s\norder by score DESC\nlimit %d\n";
    private static final String INSERT_QUERY_TEMPLATE = "insert into %s (id, embedding, text, metadata) values (?,?,?,?)\n";
    private static final String DELETE_BY_IDS_QUERY_TEMPLATE = "delete from %s where id in ?\n";
    private static final String DELETE_QUERY_TEMPLATE = "delete from %s where %s\n";
    private static final String TRUNCATE_QUERY_TEMPLATE = "truncate table %s\n";
    /** Shared, immutable Jackson type token for deserializing the metadata JSON column. */
    private static final TypeReference<HashMap<String, Object>> METADATA_TYPE = new TypeReference<HashMap<String, Object>>() {};

    private final String tableName;
    private final DuckDBConnection duckDBConnection;
    private final DuckDBMetadataFilterMapper jsonFilterMapper = new DuckDBMetadataFilterMapper();
    private final ObjectMapper jsonMetadataSerializer = new ObjectMapper();

    /**
     * Opens (or creates) a DuckDB database and ensures the embeddings table exists.
     *
     * @param filePath  path of the DuckDB database file, or {@code null} for an in-memory database
     * @param tableName name of the table to use, or {@code null} for {@code "embeddings"}
     * @throws DuckDBSQLException if the connection cannot be opened or the table cannot be created
     */
    public DuckDBEmbeddingStore(String filePath, String tableName) {
        String dbUrl = filePath != null ? "jdbc:duckdb:" + filePath : "jdbc:duckdb:";
        this.tableName = Utils.getOrDefault(tableName, "embeddings");
        try {
            this.duckDBConnection = (DuckDBConnection) DriverManager.getConnection(dbUrl);
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to load duckdb connection", e);
        }
        try {
            initTable();
        } catch (RuntimeException e) {
            // Don't leak the freshly opened connection when the store cannot be initialized.
            try {
                this.duckDBConnection.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
    }

    /** Creates a store backed by an in-memory DuckDB database using the default table name. */
    public static DuckDBEmbeddingStore inMemory() {
        return new DuckDBEmbeddingStore(null, null);
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Adds an embedding under a randomly generated id and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Adds an embedding under the given id, without text or metadata. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Adds an embedding with its text segment under a random id and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Adds embeddings (without text) under random ids; returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        return addAll(embeddings, null);
    }

    /**
     * Adds embeddings with optional text segments under random ids.
     *
     * @param embeddings embeddings to add
     * @param embedded   matching text segments, or {@code null}; must have the same size when non-null
     * @return generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).toList();
        addAll(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Deletes all rows whose id is in {@code ids}.
     *
     * @throws DuckDBSQLException on database error
     */
    @Override
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        String sql = String.format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName);
        try (Connection connection = duckDBConnection.duplicate();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            Array idsParam = connection.createArrayOf("UUID", ids.toArray());
            statement.setObject(1, idsParam);
            statement.execute();
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to remove embeddings by ids", e);
        }
    }

    /**
     * Deletes all rows matching the metadata {@code filter}.
     *
     * @throws DuckDBSQLException on database error
     */
    @Override
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String whereClause = jsonFilterMapper.map(filter);
        String sql = String.format(DELETE_QUERY_TEMPLATE, tableName, whereClause);
        log.debug(sql);
        try (Connection connection = duckDBConnection.duplicate();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.execute();
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to remove embeddings with filter", e);
        }
    }

    /** Removes every row from the table (truncate). */
    @Override
    public void removeAll() {
        String sql = String.format(TRUNCATE_QUERY_TEMPLATE, tableName);
        try (Connection connection = duckDBConnection.duplicate();
             Statement statement = connection.createStatement()) {
            statement.execute(sql);
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to remove all embeddings", e);
        }
    }

    /**
     * Returns up to {@code request.maxResults()} matches whose rescaled cosine score is at least
     * {@code request.minScore()}, optionally restricted by {@code request.filter()}, best first.
     *
     * @throws DuckDBSQLException on database or metadata-deserialization error
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        String param = embeddingToParam(request.queryEmbedding());
        String filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : "";
        // Locale.ROOT: the numbers end up in SQL and must use ASCII digits regardless of default locale.
        String query = String.format(Locale.ROOT, SEARCH_QUERY_TEMPLATE,
                param, tableName, request.minScore(), filterClause, request.maxResults());
        log.debug(query);
        try (Connection connection = duckDBConnection.duplicate();
             PreparedStatement statement = connection.prepareStatement(query);
             ResultSet resultSet = statement.executeQuery()) {
            List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
            while (resultSet.next()) {
                matches.add(toEmbeddingMatch(resultSet));
            }
            return new EmbeddingSearchResult<>(matches);
        } catch (JsonProcessingException | SQLException e) {
            throw new DuckDBSQLException("Error while searching embeddings", e);
        }
    }

    /** Converts the current result-set row into an {@link EmbeddingMatch}; text-less rows get no segment. */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(ResultSet resultSet) throws SQLException, JsonProcessingException {
        String id = resultSet.getString("id");
        String text = resultSet.getString("text");
        double score = resultSet.getDouble("score");
        Embedding embedding = new Embedding(toFloatVector(resultSet.getArray("embedding")));
        Map<String, Object> metadata = parseMetadata(resultSet.getString("metadata"));
        TextSegment segment = text != null ? TextSegment.from(text, Metadata.from(metadata)) : null;
        return new EmbeddingMatch<>(score, id, embedding, segment);
    }

    /** Copies a SQL {@code FLOAT[]} value into a primitive array. */
    private static float[] toFloatVector(Array sqlArray) throws SQLException {
        Object[] values = (Object[]) sqlArray.getArray();
        float[] vector = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            vector[i] = ((Number) values[i]).floatValue();
        }
        return vector;
    }

    /** Parses the metadata JSON column; SQL NULL and JSON {@code null} both yield an empty map. */
    private Map<String, Object> parseMetadata(String metadataJson) throws JsonProcessingException {
        if (metadataJson == null) {
            return Collections.emptyMap();
        }
        Map<String, Object> parsed = jsonMetadataSerializer.readValue(metadataJson, METADATA_TYPE);
        return parsed != null ? parsed : Collections.emptyMap();
    }

    /** Single-row convenience wrapper around {@link #addAll(List, List, List)}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAll(Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Inserts embeddings under the given ids in a single JDBC batch.
     *
     * <p>Does nothing (and logs at info) when {@code ids} or {@code embeddings} is null or empty.
     *
     * @param ids        row ids, same size as {@code embeddings}
     * @param embeddings embeddings to store
     * @param embedded   matching text segments or {@code null}; null elements store no text/metadata
     * @throws IllegalArgumentException if the list sizes differ
     * @throws DuckDBSQLException       on database or metadata-serialization error
     */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("[no embeddings to add to DuckDB]");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");
        String sql = String.format(INSERT_QUERY_TEMPLATE, tableName);
        try (Connection connection = duckDBConnection.duplicate();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment segment = embedded != null ? embedded.get(i) : null;
                String text = segment != null ? segment.text() : null;
                // A missing segment serializes to the JSON literal "null" (unchanged legacy behavior).
                Map<String, Object> metadata = segment != null ? segment.metadata().toMap() : null;
                statement.setString(1, ids.get(i));
                statement.setObject(2, connection.createArrayOf("float", embeddings.get(i).vectorAsList().toArray()));
                statement.setString(3, text);
                statement.setString(4, jsonMetadataSerializer.writeValueAsString(metadata));
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (JsonProcessingException | SQLException e) {
            throw new DuckDBSQLException("Unable to add embeddings in DuckDB", e);
        }
    }

    /** Creates the embeddings table if it does not exist yet. */
    private void initTable() {
        String sql = String.format(CREATE_TABLE_TEMPLATE, tableName);
        log.debug(sql);
        try (Connection connection = duckDBConnection.duplicate();
             Statement statement = connection.createStatement()) {
            statement.execute(sql);
        } catch (SQLException e) {
            throw new DuckDBSQLException(String.format("Failed to init duckDB table:  '%s'", sql), e);
        }
    }

    /**
     * Renders an embedding as a DuckDB list literal cast to {@code float[]}, e.g. {@code [0.1,0.2]::float[]}.
     */
    protected String embeddingToParam(Embedding embedding) {
        return embedding.vectorAsList().stream()
                .map(Object::toString)
                .collect(Collectors.joining(",", "[", "]"))
                .concat("::float[]");
    }

    /** Builder for {@link DuckDBEmbeddingStore}. */
    public static class Builder {
        private String filePath;
        private String tableName;

        /** Sets the database file path; {@code null} means in-memory. */
        public Builder filePath(String filePath) {
            this.filePath = filePath;
            return this;
        }

        /** Sets the table name; {@code null} means {@code "embeddings"}. */
        public Builder tableName(String tableName) {
            this.tableName = tableName;
            return this;
        }

        /**
         * Selects an in-memory database and, when {@code tableName} is non-null, uses it as the table name.
         * The argument used to be silently ignored.
         */
        public Builder inMemory(String tableName) {
            if (tableName != null) {
                this.tableName = tableName;
            }
            return filePath(null);
        }

        /** Builds the store. */
        public DuckDBEmbeddingStore build() {
            return new DuckDBEmbeddingStore(filePath, tableName);
        }
    }
}

