package dev.langchain4j.store.embedding.pgvector;

import com.pgvector.PGvector;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.sql.DataSource;
import org.postgresql.ds.PGSimpleDataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore.class */
public class PgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class);
    protected final DataSource datasource;
    protected final String table;
    final MetadataHandler metadataHandler;

    /* loaded from: input_file:dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore$DatasourceBuilder.class */
    /**
     * Fluent builder that configures a {@link PgVectorEmbeddingStore} on top of a caller-supplied
     * {@link DataSource}.
     *
     * <p>Each setter stores its argument unchanged and returns this builder. Validation and
     * defaulting are performed by the store constructor that {@link #build()} invokes.
     */
    public static class DatasourceBuilder {
        private DataSource datasource;
        private String table;
        private Integer dimension;
        private Boolean useIndex;
        private Integer indexListSize;
        private Boolean createTable;
        private Boolean dropTableFirst;
        private MetadataStorageConfig metadataStorageConfig;

        /** Package-private: instances are handed out by {@code PgVectorEmbeddingStore.datasourceBuilder()}. */
        DatasourceBuilder() {
        }

        /** Sets the data source connections are obtained from. */
        public DatasourceBuilder datasource(DataSource dataSource) {
            this.datasource = dataSource;
            return this;
        }

        /** Sets the name of the table that holds the embeddings. */
        public DatasourceBuilder table(String str) {
            this.table = str;
            return this;
        }

        /** Sets the vector dimension used for the {@code embedding} column when the table is created. */
        public DatasourceBuilder dimension(Integer num) {
            this.dimension = num;
            return this;
        }

        /** Sets whether an ivfflat index is created on the embedding column. */
        public DatasourceBuilder useIndex(Boolean bool) {
            this.useIndex = bool;
            return this;
        }

        /** Sets the {@code lists} value of the ivfflat index. */
        public DatasourceBuilder indexListSize(Integer num) {
            this.indexListSize = num;
            return this;
        }

        /** Sets whether the table is created (if missing) during store construction. */
        public DatasourceBuilder createTable(Boolean bool) {
            this.createTable = bool;
            return this;
        }

        /** Sets whether the table is dropped before it is (re)created. */
        public DatasourceBuilder dropTableFirst(Boolean bool) {
            this.dropTableFirst = bool;
            return this;
        }

        /** Sets the metadata storage configuration; the store falls back to a default when this is null. */
        public DatasourceBuilder metadataStorageConfig(MetadataStorageConfig metadataStorageConfig) {
            this.metadataStorageConfig = metadataStorageConfig;
            return this;
        }

        /**
         * Creates the store from the values collected so far.
         *
         * @return a new store backed by the configured data source
         */
        public PgVectorEmbeddingStore build() {
            return new PgVectorEmbeddingStore(
                    this.datasource,
                    this.table,
                    this.dimension,
                    this.useIndex,
                    this.indexListSize,
                    this.createTable,
                    this.dropTableFirst,
                    this.metadataStorageConfig);
        }

        /** Renders every configured value, in declaration order, for diagnostics. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("PgVectorEmbeddingStore.DatasourceBuilder(");
            text.append("datasource=").append(this.datasource);
            text.append(", table=").append(this.table);
            text.append(", dimension=").append(this.dimension);
            text.append(", useIndex=").append(this.useIndex);
            text.append(", indexListSize=").append(this.indexListSize);
            text.append(", createTable=").append(this.createTable);
            text.append(", dropTableFirst=").append(this.dropTableFirst);
            text.append(", metadataStorageConfig=").append(this.metadataStorageConfig);
            return text.append(")").toString();
        }
    }

    /* loaded from: input_file:dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore$PgVectorEmbeddingStoreBuilder.class */
    /**
     * Fluent builder that configures a {@link PgVectorEmbeddingStore} from individual connection
     * settings (host, port, user, password, database).
     *
     * <p>Each setter stores its argument unchanged and returns this builder. Validation and
     * defaulting are performed by the store constructor that {@link #build()} invokes.
     */
    public static class PgVectorEmbeddingStoreBuilder {
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String database;
        private String table;
        private Integer dimension;
        private Boolean useIndex;
        private Integer indexListSize;
        private Boolean createTable;
        private Boolean dropTableFirst;
        private MetadataStorageConfig metadataStorageConfig;

        /** Package-private: instances are handed out by {@code PgVectorEmbeddingStore.builder()}. */
        PgVectorEmbeddingStoreBuilder() {
        }

        /** Sets the database server host name. */
        public PgVectorEmbeddingStoreBuilder host(String str) {
            this.host = str;
            return this;
        }

        /** Sets the database server port. */
        public PgVectorEmbeddingStoreBuilder port(Integer num) {
            this.port = num;
            return this;
        }

        /** Sets the database user name. */
        public PgVectorEmbeddingStoreBuilder user(String str) {
            this.user = str;
            return this;
        }

        /** Sets the database password. */
        public PgVectorEmbeddingStoreBuilder password(String str) {
            this.password = str;
            return this;
        }

        /** Sets the database name. */
        public PgVectorEmbeddingStoreBuilder database(String str) {
            this.database = str;
            return this;
        }

        /** Sets the name of the table that holds the embeddings. */
        public PgVectorEmbeddingStoreBuilder table(String str) {
            this.table = str;
            return this;
        }

        /** Sets the vector dimension used for the {@code embedding} column when the table is created. */
        public PgVectorEmbeddingStoreBuilder dimension(Integer num) {
            this.dimension = num;
            return this;
        }

        /** Sets whether an ivfflat index is created on the embedding column. */
        public PgVectorEmbeddingStoreBuilder useIndex(Boolean bool) {
            this.useIndex = bool;
            return this;
        }

        /** Sets the {@code lists} value of the ivfflat index. */
        public PgVectorEmbeddingStoreBuilder indexListSize(Integer num) {
            this.indexListSize = num;
            return this;
        }

        /** Sets whether the table is created (if missing) during store construction. */
        public PgVectorEmbeddingStoreBuilder createTable(Boolean bool) {
            this.createTable = bool;
            return this;
        }

        /** Sets whether the table is dropped before it is (re)created. */
        public PgVectorEmbeddingStoreBuilder dropTableFirst(Boolean bool) {
            this.dropTableFirst = bool;
            return this;
        }

        /** Sets the metadata storage configuration; the store falls back to a default when this is null. */
        public PgVectorEmbeddingStoreBuilder metadataStorageConfig(MetadataStorageConfig metadataStorageConfig) {
            this.metadataStorageConfig = metadataStorageConfig;
            return this;
        }

        /**
         * Creates the store from the values collected so far.
         *
         * @return a new store connected with the configured settings
         */
        public PgVectorEmbeddingStore build() {
            return new PgVectorEmbeddingStore(
                    this.host,
                    this.port,
                    this.user,
                    this.password,
                    this.database,
                    this.table,
                    this.dimension,
                    this.useIndex,
                    this.indexListSize,
                    this.createTable,
                    this.dropTableFirst,
                    this.metadataStorageConfig);
        }

        /**
         * Renders the configured values for diagnostics.
         *
         * <p>The password is masked: {@code toString()} output commonly ends up in logs and
         * exception messages, so the real credential must never be echoed here. A password that
         * was never set is still rendered as {@code null}.
         */
        @Override
        public String toString() {
            String maskedPassword = this.password == null ? null : "****";
            return "PgVectorEmbeddingStore.PgVectorEmbeddingStoreBuilder(host=" + this.host
                    + ", port=" + this.port
                    + ", user=" + this.user
                    + ", password=" + maskedPassword
                    + ", database=" + this.database
                    + ", table=" + this.table
                    + ", dimension=" + this.dimension
                    + ", useIndex=" + this.useIndex
                    + ", indexListSize=" + this.indexListSize
                    + ", createTable=" + this.createTable
                    + ", dropTableFirst=" + this.dropTableFirst
                    + ", metadataStorageConfig=" + String.valueOf(this.metadataStorageConfig) + ")";
        }
    }

    /**
     * Creates a store on top of an existing data source and initializes its table.
     *
     * @param dataSource data source to obtain connections from; must not be null
     * @param str table name; must not be blank
     * @param num vector dimension, forwarded to {@link #initTable}
     * @param bool whether to create the ivfflat index; null means {@code false}
     * @param num2 ivfflat {@code lists} value, forwarded to {@link #initTable}
     * @param bool2 whether to create the table; null means {@code true}
     * @param bool3 whether to drop the table first; null means {@code false}
     * @param metadataStorageConfig metadata storage configuration; null selects
     *     {@code DefaultMetadataStorageConfig.defaultConfig()}
     */
    protected PgVectorEmbeddingStore(DataSource dataSource, String str, Integer num, Boolean bool, Integer num2, Boolean bool2, Boolean bool3, MetadataStorageConfig metadataStorageConfig) {
        this.datasource = (DataSource) ValidationUtils.ensureNotNull(dataSource, "datasource");
        this.table = ValidationUtils.ensureNotBlank(str, "table");

        MetadataStorageConfig effectiveConfig =
                (MetadataStorageConfig) Utils.getOrDefault(metadataStorageConfig, DefaultMetadataStorageConfig.defaultConfig());
        this.metadataHandler = MetadataHandlerFactory.get(effectiveConfig);

        // Resolve the optional flags to their defaults before touching the database.
        Boolean createIndex = (Boolean) Utils.getOrDefault(bool, false);
        Boolean dropFirst = (Boolean) Utils.getOrDefault(bool3, false);
        Boolean createTableIfMissing = (Boolean) Utils.getOrDefault(bool2, true);
        initTable(dropFirst, createTableIfMissing, createIndex, num, num2);
    }

    /**
     * Creates a store from individual connection settings by building a data source with
     * {@code createDataSource} and delegating to the data-source based constructor.
     *
     * @param str database host
     * @param num database port
     * @param str2 database user
     * @param str3 database password
     * @param str4 database name
     * @param str5 table name
     * @param num2 vector dimension
     * @param bool whether to create the ivfflat index
     * @param num3 ivfflat {@code lists} value
     * @param bool2 whether to create the table
     * @param bool3 whether to drop the table first
     * @param metadataStorageConfig metadata storage configuration
     */
    protected PgVectorEmbeddingStore(String str, Integer num, String str2, String str3, String str4, String str5, Integer num2, Boolean bool, Integer num3, Boolean bool2, Boolean bool3, MetadataStorageConfig metadataStorageConfig) {
        this(createDataSource(str, num, str2, str3, str4), str5, num2, bool, num3, bool2, bool3, metadataStorageConfig);
    }

    /**
     * Validates the connection settings and wraps them in a {@link PGSimpleDataSource}.
     *
     * @param str database host; must not be blank
     * @param num database port; must be greater than zero
     * @param str2 database user; must not be blank
     * @param str3 database password; must not be blank
     * @param str4 database name; must not be blank
     * @return a data source configured with the given settings
     */
    private static DataSource createDataSource(String str, Integer num, String str2, String str3, String str4) {
        // Validate in declaration order so the first invalid argument is the one reported.
        final String host = ValidationUtils.ensureNotBlank(str, "host");
        final int port = ValidationUtils.ensureGreaterThanZero(num, "port");
        final String user = ValidationUtils.ensureNotBlank(str2, "user");
        final String password = ValidationUtils.ensureNotBlank(str3, "password");
        final String database = ValidationUtils.ensureNotBlank(str4, "database");

        PGSimpleDataSource source = new PGSimpleDataSource();
        source.setServerNames(new String[] {host});
        source.setPortNumbers(new int[] {port});
        source.setDatabaseName(database);
        source.setUser(user);
        source.setPassword(password);
        return source;
    }

    /**
     * Optionally drops, creates, and indexes the embeddings table.
     *
     * <p>The connection and statement are managed with try-with-resources so both are released
     * even when one of the DDL statements fails.
     *
     * @param bool whether to run {@code DROP TABLE IF EXISTS} first
     * @param bool2 whether to run {@code CREATE TABLE IF NOT EXISTS} and create the metadata indexes
     * @param bool3 whether to create the ivfflat index on the embedding column
     * @param num vector dimension; validated (greater than zero) only when the table is created
     * @param num2 ivfflat {@code lists} value; validated (greater than zero) only when the index is created
     * @throws RuntimeException wrapping any {@link SQLException} raised while initializing
     */
    protected void initTable(Boolean bool, Boolean bool2, Boolean bool3, Integer num, Integer num2) {
        try (Connection connection = getConnection();
                Statement statement = connection.createStatement()) {
            if (bool.booleanValue()) {
                statement.executeUpdate(String.format("DROP TABLE IF EXISTS %s", this.table));
            }
            if (bool2.booleanValue()) {
                int dimension = ValidationUtils.ensureGreaterThanZero(num, "dimension");
                statement.executeUpdate(String.format(
                        "CREATE TABLE IF NOT EXISTS %s (embedding_id UUID PRIMARY KEY, embedding vector(%s), text TEXT NULL, %s )",
                        this.table, dimension, this.metadataHandler.columnDefinitionsString()));
                this.metadataHandler.createMetadataIndexes(statement, this.table);
            }
            if (bool3.booleanValue()) {
                int lists = ValidationUtils.ensureGreaterThanZero(num2, "indexListSize");
                statement.executeUpdate(String.format(
                        "CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s)",
                        this.table + "_ivfflat_index", this.table, lists));
            }
        } catch (SQLException e) {
            throw new RuntimeException(String.format("Failed to execute '%s'", "init"), e);
        }
    }

    /**
     * Stores an embedding without an associated text segment.
     *
     * @param embedding the embedding to store
     * @return the id assigned to the stored embedding, obtained from {@code Utils.randomUUID()}
     */
    public String add(Embedding embedding) {
        final String id = Utils.randomUUID();
        addInternal(id, embedding, null);
        return id;
    }

    /**
     * Stores an embedding under a caller-supplied id, without an associated text segment.
     *
     * @param str id to store the embedding under; it is parsed with {@code UUID.fromString}
     *     further down the call chain, so it has to be a UUID string
     * @param embedding the embedding to store
     */
    public void add(String str, Embedding embedding) {
        addInternal(str, embedding, null);
    }

    /**
     * Stores an embedding together with the text segment it was computed from.
     *
     * @param embedding the embedding to store
     * @param textSegment the segment to store alongside it; may be null
     * @return the id assigned to the stored embedding, obtained from {@code Utils.randomUUID()}
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        final String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores several embeddings without text segments, assigning each a fresh id.
     *
     * @param list the embeddings to store
     * @return the assigned ids, in the same order as {@code list}
     */
    public List<String> addAll(List<Embedding> list) {
        List<String> ids = new ArrayList<>(list.size());
        for (int i = 0; i < list.size(); i++) {
            // One id per embedding; the embedding itself does not influence the id.
            ids.add(Utils.randomUUID());
        }
        addAll(ids, list, null);
        return ids;
    }

    /**
     * Deletes the embeddings with the given ids.
     *
     * <p>Ids are parsed before a connection is opened, so a malformed id fails fast without
     * touching the database. The connection and statement are managed with try-with-resources
     * so both are released even when the delete fails.
     *
     * @param collection ids to delete; must not be empty, and every element must be a UUID string
     * @throws IllegalArgumentException if an id is not a valid UUID string
     * @throws RuntimeException wrapping any {@link SQLException} raised by the delete
     */
    public void removeAll(Collection<String> collection) {
        ValidationUtils.ensureNotEmpty(collection, "ids");
        Object[] uuids = collection.stream().map(UUID::fromString).toArray();
        String sql = String.format("DELETE FROM %s WHERE embedding_id = ANY (?)", this.table);
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.setArray(1, connection.createArrayOf("uuid", uuids));
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deletes every embedding whose metadata matches the given filter.
     *
     * <p>The filter is translated to a SQL {@code WHERE} clause by the metadata handler. The
     * connection and statement are managed with try-with-resources so both are released even
     * when the delete fails.
     *
     * @param filter the filter selecting rows to delete; must not be null
     * @throws RuntimeException wrapping any {@link SQLException} raised by the delete
     */
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String sql = String.format("DELETE FROM %s WHERE %s", this.table, this.metadataHandler.whereClause(filter));
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Removes every row from the embeddings table with {@code TRUNCATE TABLE}.
     *
     * <p>The connection and statement are managed with try-with-resources so both are released
     * even when the statement fails.
     *
     * @throws RuntimeException wrapping any {@link SQLException} raised by the truncate
     */
    public void removeAll() {
        try (Connection connection = getConnection();
                Statement statement = connection.createStatement()) {
            statement.executeUpdate(String.format("TRUNCATE TABLE %s", this.table));
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Finds the stored embeddings most similar to the query embedding.
     *
     * <p>The score is computed in SQL as {@code (2 - (embedding <=> query)) / 2}, where
     * {@code <=>} is pgvector's cosine-distance operator. Rows scoring below the request's
     * minimum are discarded, the rest are ordered by descending score and limited to the
     * requested maximum. When the request carries a filter, it is translated to a
     * {@code WHERE} clause by the metadata handler.
     *
     * <p>The connection, statement, and result set are managed with try-with-resources so all
     * three are released even when the query or row mapping fails.
     *
     * @param embeddingSearchRequest query embedding, result limit, minimum score, and optional filter
     * @return the matches; a match has no text segment when the stored text is null or blank
     * @throws RuntimeException wrapping any {@link SQLException} raised by the query
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        Embedding queryEmbedding = embeddingSearchRequest.queryEmbedding();
        int maxResults = embeddingSearchRequest.maxResults();
        double minScore = embeddingSearchRequest.minScore();
        Filter filter = embeddingSearchRequest.filter();

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
        try (Connection connection = getConnection()) {
            // Arrays.toString yields "[v1, v2, ...]", which is embedded as the vector literal.
            String vectorLiteral = Arrays.toString(queryEmbedding.vector());
            String whereClause = filter == null ? "" : this.metadataHandler.whereClause(filter);
            String sql = String.format(
                    "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, %s FROM %s %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
                    vectorLiteral,
                    String.join(",", this.metadataHandler.columnsNames()),
                    this.table,
                    whereClause.isEmpty() ? "" : "WHERE " + whereClause,
                    minScore,
                    maxResults);
            try (PreparedStatement statement = connection.prepareStatement(sql);
                    ResultSet rows = statement.executeQuery()) {
                while (rows.next()) {
                    double score = rows.getDouble("score");
                    String id = rows.getString("embedding_id");
                    Embedding embedding = new Embedding(((PGvector) rows.getObject("embedding")).toArray());
                    String text = rows.getString("text");
                    TextSegment segment = null;
                    if (Utils.isNotNullOrBlank(text)) {
                        segment = TextSegment.from(text, this.metadataHandler.fromResultSet(rows));
                    }
                    EmbeddingMatch<TextSegment> match = new EmbeddingMatch<>(score, id, embedding, segment);
                    matches.add(match);
                }
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return new EmbeddingSearchResult<>(matches);
    }

    /**
     * Stores a single embedding by delegating to the batch {@code addAll} with one-element lists.
     *
     * @param str id to store the embedding under
     * @param embedding the embedding to store
     * @param textSegment the segment to store alongside it, or null for none
     */
    private void addInternal(String str, Embedding embedding, TextSegment textSegment) {
        // A null list (not a list holding null) tells addAll there are no segments at all.
        List<TextSegment> segments = null;
        if (textSegment != null) {
            segments = Collections.singletonList(textSegment);
        }
        addAll(Collections.singletonList(str), Collections.singletonList(embedding), segments);
    }

    /**
     * Upserts a batch of embeddings, optionally with their text segments and metadata.
     *
     * <p>Rows are inserted with {@code ON CONFLICT (embedding_id) DO UPDATE}, so an existing id
     * has its embedding, text, and metadata overwritten. When either id or embedding list is
     * null or empty the call only logs and returns.
     *
     * <p>The connection and statement are managed with try-with-resources so both are released
     * even when binding or executing the batch fails.
     *
     * @param list ids, one per embedding; every element must be a UUID string
     * @param list2 embeddings; must have the same size as {@code list}
     * @param list3 text segments, or null for none; when non-null it must have the same size as
     *     {@code list2}, and individual elements may be null
     * @throws RuntimeException wrapping any exception raised while preparing or executing the batch
     */
    public void addAll(List<String> list, List<Embedding> list2, List<TextSegment> list3) {
        if (Utils.isNullOrEmpty(list) || Utils.isNullOrEmpty(list2)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(list.size() == list2.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(list3 == null || list2.size() == list3.size(), "embeddings size is not equal to embedded size");

        try (Connection connection = getConnection()) {
            int metadataColumnCount = this.metadataHandler.columnsNames().size();
            String sql = String.format(
                    "INSERT INTO %s (embedding_id, embedding, text, %s) VALUES (?, ?, ?, %s)ON CONFLICT (embedding_id) DO UPDATE SET embedding = EXCLUDED.embedding,text = EXCLUDED.text,%s;",
                    this.table,
                    String.join(",", this.metadataHandler.columnsNames()),
                    String.join(",", Collections.nCopies(metadataColumnCount, "?")),
                    this.metadataHandler.insertClause());
            try (PreparedStatement statement = connection.prepareStatement(sql)) {
                for (int i = 0; i < list.size(); i++) {
                    statement.setObject(1, UUID.fromString(list.get(i)));
                    statement.setObject(2, new PGvector(list2.get(i).vector()));
                    TextSegment segment = list3 == null ? null : list3.get(i);
                    if (segment == null) {
                        // No segment: null out the text column and every metadata column (4..n).
                        statement.setNull(3, java.sql.Types.VARCHAR);
                        for (int column = 4; column < 4 + metadataColumnCount; column++) {
                            statement.setNull(column, java.sql.Types.OTHER);
                        }
                    } else {
                        statement.setObject(3, segment.text());
                        this.metadataHandler.setMetadata(statement, 4, segment.metadata());
                    }
                    statement.addBatch();
                }
                statement.executeBatch();
            }
        } catch (Exception e) {
            // Deliberately broad: unchecked failures (e.g. a malformed id) are wrapped as well.
            throw new RuntimeException(e);
        }
    }

    /**
     * Obtains a connection from the data source and prepares it for vector operations.
     *
     * <p>Preparation runs {@code CREATE EXTENSION IF NOT EXISTS vector} and registers the
     * vector type on the connection via {@code PGvector.addVectorType}. If either step fails,
     * the connection is closed before the failure propagates, so a failed setup does not leak it.
     *
     * @return an open connection; the caller is responsible for closing it
     * @throws SQLException if the connection cannot be obtained or prepared
     */
    protected Connection getConnection() throws SQLException {
        Connection connection = this.datasource.getConnection();
        try {
            try (Statement statement = connection.createStatement()) {
                statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            }
            PGvector.addVectorType(connection);
            return connection;
        } catch (SQLException | RuntimeException failure) {
            try {
                connection.close();
            } catch (SQLException | RuntimeException closeFailure) {
                // Keep the original failure as the primary exception.
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
    }

    /**
     * Starts building a store that uses a caller-supplied {@code DataSource}.
     *
     * @return a new, empty {@link DatasourceBuilder}
     */
    public static DatasourceBuilder datasourceBuilder() {
        return new DatasourceBuilder();
    }

    /**
     * Starts building a store from individual connection settings (host, port, user, ...).
     *
     * @return a new, empty {@link PgVectorEmbeddingStoreBuilder}
     */
    public static PgVectorEmbeddingStoreBuilder builder() {
        return new PgVectorEmbeddingStoreBuilder();
    }

    /**
     * Creates an instance whose datasource, table, and metadata handler are all null; no table
     * initialization is performed.
     *
     * <p>NOTE(review): presumably present for frameworks or subclasses that require a no-arg
     * constructor — confirm. On such an instance {@code getConnection()} dereferences the null
     * datasource and fails with a {@link NullPointerException}.
     */
    public PgVectorEmbeddingStore() {
        this.datasource = null;
        this.table = null;
        this.metadataHandler = null;
    }
}
