package com.microsoft.semantickernel.connectors.data.postgres;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.jdbc.SQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.jdbc.SQLVectorStoreRecordCollectionSearchMapping;
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordField;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordKeyField;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField;
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
import com.microsoft.semantickernel.data.vectorstorage.options.UpsertRecordOptions;
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
import com.microsoft.semantickernel.exceptions.SKException;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.sql.DataSource;

/* loaded from: input_file:com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.class */
/**
 * PostgreSQL implementation of {@link SQLVectorStoreQueryProvider}.
 *
 * <p>Builds on the generic JDBC provider with PostgreSQL column types, {@code vector} columns
 * (via the {@code vector} extension installed in {@link #prepareVectorStore()}), vector indexes,
 * {@code INSERT ... ON CONFLICT} upserts and distance-operator based similarity search.
 * Instances are created through {@link #builder()}.
 */
public class PostgreSQLVectorStoreQueryProvider extends JDBCVectorStoreQueryProvider implements SQLVectorStoreQueryProvider {
    /** SQL type of the primary-key column of every collection table. */
    private static final String KEY_COLUMN_TYPE = "VARCHAR(255)";

    private final Map<Class<?>, String> supportedKeyTypes;
    private final Map<Class<?>, String> supportedDataTypes;
    private final Map<Class<?>, String> supportedVectorTypes;
    private final DataSource dataSource;
    private final String collectionsTable;
    private final String prefixForCollectionTables;
    private final ObjectMapper objectMapper;

    /** Fluent builder for {@link PostgreSQLVectorStoreQueryProvider}. */
    public static class Builder extends JDBCVectorStoreQueryProvider.Builder {
        private DataSource dataSource;
        private String collectionsTable = SQLVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE;
        private String prefixForCollectionTables = SQLVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES;
        private ObjectMapper objectMapper = new ObjectMapper();

        /**
         * Sets the data source connections are obtained from. Required.
         *
         * @param dataSource the JDBC data source
         * @return this builder
         */
        @Override
        @SuppressFBWarnings({"EI_EXPOSE_REP2"})
        public Builder withDataSource(DataSource dataSource) {
            this.dataSource = dataSource;
            return this;
        }

        /**
         * Sets the table that records the names of created collections.
         *
         * @param collectionsTable table name; validated with {@code validateSQLidentifier}
         * @return this builder
         */
        @Override
        public Builder withCollectionsTable(String collectionsTable) {
            this.collectionsTable = JDBCVectorStoreQueryProvider.validateSQLidentifier(collectionsTable);
            return this;
        }

        /**
         * Sets the prefix used when naming per-collection tables.
         *
         * @param prefixForCollectionTables prefix; validated with {@code validateSQLidentifier}
         * @return this builder
         */
        @Override
        public Builder withPrefixForCollectionTables(String prefixForCollectionTables) {
            this.prefixForCollectionTables = JDBCVectorStoreQueryProvider.validateSQLidentifier(prefixForCollectionTables);
            return this;
        }

        /**
         * Sets the Jackson mapper used to convert records and vectors to SQL parameters.
         *
         * @param objectMapper the mapper; defaults to {@code new ObjectMapper()}
         * @return this builder
         */
        @SuppressFBWarnings({"EI_EXPOSE_REP2"})
        public Builder withObjectMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
            return this;
        }

        /**
         * Builds the provider. This is the builder's {@code build()} method; the name was
         * assigned by the decompiler and is kept so existing callers still link.
         *
         * @return a new provider
         * @throws SKException if no data source or object mapper is configured
         */
        @Override
        public PostgreSQLVectorStoreQueryProvider mo2build() {
            if (this.dataSource == null) {
                throw new SKException("DataSource is required");
            }
            if (this.objectMapper == null) {
                throw new SKException("ObjectMapper is required");
            }
            return new PostgreSQLVectorStoreQueryProvider(this.dataSource, this.collectionsTable, this.prefixForCollectionTables, this.objectMapper);
        }
    }

    @SuppressFBWarnings({"EI_EXPOSE_REP2"})
    private PostgreSQLVectorStoreQueryProvider(@Nonnull DataSource dataSource, @Nonnull String collectionsTable, @Nonnull String prefixForCollectionTables, @Nonnull ObjectMapper objectMapper) {
        super(dataSource, collectionsTable, prefixForCollectionTables);
        this.dataSource = dataSource;
        this.collectionsTable = collectionsTable;
        this.prefixForCollectionTables = prefixForCollectionTables;
        this.objectMapper = objectMapper;

        this.supportedKeyTypes = new HashMap<>();
        this.supportedKeyTypes.put(String.class, KEY_COLUMN_TYPE);

        this.supportedDataTypes = new HashMap<>();
        this.supportedDataTypes.put(String.class, "TEXT");
        this.supportedDataTypes.put(Integer.class, "INTEGER");
        this.supportedDataTypes.put(Integer.TYPE, "INTEGER");
        this.supportedDataTypes.put(Long.class, "BIGINT");
        this.supportedDataTypes.put(Long.TYPE, "BIGINT");
        this.supportedDataTypes.put(Float.class, "REAL");
        this.supportedDataTypes.put(Float.TYPE, "REAL");
        this.supportedDataTypes.put(Double.class, "DOUBLE PRECISION");
        this.supportedDataTypes.put(Double.TYPE, "DOUBLE PRECISION");
        this.supportedDataTypes.put(Boolean.class, "BOOLEAN");
        this.supportedDataTypes.put(Boolean.TYPE, "BOOLEAN");
        this.supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ");

        this.supportedVectorTypes = new HashMap<>();
        // String-typed vector fields are stored as TEXT. This entry belongs in the vector map:
        // getColumnNamesAndTypesForVectorFields looks String up here, and it was previously put
        // into the data-type map by mistake, producing a "null" column type.
        this.supportedVectorTypes.put(String.class, "TEXT");
        // %d is replaced by the field's dimension count.
        this.supportedVectorTypes.put(List.class, "VECTOR(%d)");
        this.supportedVectorTypes.put(Collection.class, "VECTOR(%d)");
    }

    /** @return a copy of the Java-type to SQL-type mapping for key fields */
    @Override
    public Map<Class<?>, String> getSupportedKeyTypes() {
        return new HashMap<>(this.supportedKeyTypes);
    }

    /** @return a copy of the Java-type to SQL-type mapping for data fields */
    @Override
    public Map<Class<?>, String> getSupportedDataTypes() {
        return new HashMap<>(this.supportedDataTypes);
    }

    /** @return a copy of the Java-type to SQL-type template mapping for vector fields */
    @Override
    public Map<Class<?>, String> getSupportedVectorTypes() {
        return new HashMap<>(this.supportedVectorTypes);
    }

    /** @return a new {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Runs the base preparation and then installs the {@code vector} extension if missing.
     *
     * @throws SKException if the extension statement fails
     */
    @Override
    public void prepareVectorStore() {
        super.prepareVectorStore();
        try (Connection connection = this.dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement("CREATE EXTENSION IF NOT EXISTS vector")) {
            statement.execute();
        } catch (SQLException e) {
            throw new SKException("Failed to prepare vector store", e);
        }
    }

    /** Renders the comma-separated "name TYPE" column definitions for the given vector fields. */
    private String getColumnNamesAndTypesForVectorFields(List<VectorStoreRecordVectorField> fields) {
        return fields.stream()
            .map(this::getColumnNameAndTypeForVectorField)
            .collect(Collectors.joining(", "));
    }

    /**
     * Renders a single "name TYPE" column definition for a vector field.
     *
     * @throws SKException if the field's Java type has no vector column mapping
     */
    private String getColumnNameAndTypeForVectorField(VectorStoreRecordVectorField field) {
        String columnName = validateSQLidentifier(field.getEffectiveStorageName());
        String columnType = this.supportedVectorTypes.get(field.getFieldType());
        if (columnType == null) {
            throw new SKException("Unsupported type " + field.getFieldType() + " for vector field: " + field.getName());
        }
        // Only the VECTOR(%d) templates carry a dimension placeholder; String maps to plain TEXT.
        if (!String.class.equals(field.getFieldType())) {
            columnType = String.format(columnType, field.getDimensions());
        }
        return columnName + " " + columnType;
    }

    /**
     * Builds the CREATE INDEX statement for a vector field.
     *
     * @return the statement, or {@code null} when the field's index kind has no PostgreSQL mapping
     * @throws SKException if an index is requested but the distance function is not mapped
     */
    private String createIndexForVectorField(String collectionName, VectorStoreRecordVectorField vectorField) {
        PostgreSQLVectorIndexKind indexKind = PostgreSQLVectorIndexKind.fromIndexKind(vectorField.getIndexKind());
        PostgreSQLVectorDistanceFunction distanceFunction = PostgreSQLVectorDistanceFunction.fromDistanceFunction(vectorField.getDistanceFunction());
        if (indexKind == null) {
            return null;
        }
        if (distanceFunction == null) {
            throw new SKException("Distance function is required for vector field: " + vectorField.getName());
        }
        String tableName = getCollectionTableName(collectionName);
        String columnName = validateSQLidentifier(vectorField.getEffectiveStorageName());
        // The index name includes the column so every vector field gets its own index. With a
        // shared "<table>_index" name, IF NOT EXISTS silently skipped all fields after the first.
        return formatQuery("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (%s %s);",
            tableName + "_" + columnName + "_index", tableName, indexKind.getValue(), columnName, distanceFunction.getValue());
    }

    /**
     * Creates the collection table and its vector indexes, then registers the collection name
     * in the collections table.
     *
     * @throws SKException "Failed to create collection" if table/index creation fails, or
     *     "Failed to insert collection" if registering the name fails
     */
    @Override
    @SuppressFBWarnings({"SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE", "SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING"})
    public void createCollection(String collectionName, VectorStoreRecordDefinition recordDefinition) {
        createCollectionTableAndIndexes(collectionName, recordDefinition);
        registerCollection(collectionName);
    }

    /** Executes CREATE TABLE plus one CREATE INDEX per indexable vector field as a single batch. */
    @SuppressFBWarnings({"SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE"})
    private void createCollectionTableAndIndexes(String collectionName, VectorStoreRecordDefinition recordDefinition) {
        List<String> columnDefinitions = new ArrayList<>();
        columnDefinitions.add(getKeyColumnName(recordDefinition.getKeyField()) + " " + KEY_COLUMN_TYPE + " PRIMARY KEY");
        String dataColumns = getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()), this.supportedDataTypes);
        // Skip empty groups so a definition without data (or vector) fields does not produce ", ,".
        if (!dataColumns.isEmpty()) {
            columnDefinitions.add(dataColumns);
        }
        String vectorColumns = getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields());
        if (!vectorColumns.isEmpty()) {
            columnDefinitions.add(vectorColumns);
        }
        String createTable = formatQuery("CREATE TABLE IF NOT EXISTS %s (%s);",
            getCollectionTableName(collectionName), String.join(", ", columnDefinitions));

        try (Connection connection = this.dataSource.getConnection();
             Statement statement = connection.createStatement()) {
            statement.addBatch(createTable);
            for (VectorStoreRecordVectorField vectorField : recordDefinition.getVectorFields()) {
                String createIndex = createIndexForVectorField(collectionName, vectorField);
                if (createIndex != null) {
                    statement.addBatch(createIndex);
                }
            }
            statement.executeBatch();
        } catch (SQLException e) {
            throw new SKException("Failed to create collection", e);
        }
    }

    /** Inserts the collection name into the collections table. */
    @SuppressFBWarnings({"SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING"})
    private void registerCollection(String collectionName) {
        String insert = formatQuery("INSERT INTO %s (collectionId) VALUES (?)", validateSQLidentifier(this.collectionsTable));
        try (Connection connection = this.dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(insert)) {
            statement.setObject(1, collectionName);
            statement.execute();
        } catch (SQLException e) {
            throw new SKException("Failed to insert collection", e);
        }
    }

    /**
     * Binds one record's field values to the upsert statement, in {@code fields} order.
     *
     * <p>Non-String vector fields are bound as their JSON text, which the SQL casts with
     * {@code ?::vector}. Missing or null vectors are bound as SQL NULL rather than the string
     * {@code "null"}, which would not cast. All other fields are converted to their declared
     * Java type with the object mapper.
     *
     * @throws SKException if a value cannot be serialized or bound
     */
    private void setUpsertStatementValues(PreparedStatement statement, Object record, List<VectorStoreRecordField> fields) {
        JsonNode recordNode = this.objectMapper.valueToTree(record);
        for (int i = 0; i < fields.size(); i++) {
            VectorStoreRecordField field = fields.get(i);
            int parameterIndex = i + 1;
            JsonNode valueNode = recordNode.get(field.getEffectiveStorageName());
            try {
                if (field instanceof VectorStoreRecordVectorField && !String.class.equals(field.getFieldType())) {
                    boolean isNull = valueNode == null || valueNode.isNull();
                    statement.setObject(parameterIndex, isNull ? null : this.objectMapper.writeValueAsString(valueNode));
                } else {
                    statement.setObject(parameterIndex, this.objectMapper.convertValue(valueNode, field.getFieldType()));
                }
            } catch (SQLException | JsonProcessingException e) {
                throw new SKException("Failed to set upsert value for field: " + field.getName(), e);
            }
        }
    }

    /** Returns one "?" placeholder per field; vector fields get "?::vector". */
    private String getWildcardStringWithCast(List<VectorStoreRecordField> fields) {
        return fields.stream()
            .map(field -> field instanceof VectorStoreRecordVectorField ? "?::vector" : "?")
            .collect(Collectors.joining(", "));
    }

    /**
     * Inserts the records, updating every non-key column when the key already exists.
     *
     * @throws SKException if the batch cannot be executed
     */
    @Override
    @SuppressFBWarnings({"SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING"})
    public void upsertRecords(String collectionName, List<?> records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) {
        String tableName = validateSQLidentifier(getCollectionTableName(collectionName));
        List<VectorStoreRecordField> allFields = recordDefinition.getAllFields();
        String updateAssignments = allFields.stream()
            .filter(field -> !(field instanceof VectorStoreRecordKeyField))
            .map(field -> formatQuery("%s = EXCLUDED.%s", validateSQLidentifier(field.getEffectiveStorageName()), field.getEffectiveStorageName()))
            .collect(Collectors.joining(", "));
        // A definition with only a key field has nothing to update; "DO UPDATE SET" with an empty
        // assignment list is invalid SQL.
        String conflictAction = updateAssignments.isEmpty() ? "DO NOTHING" : "DO UPDATE SET " + updateAssignments;
        String upsert = formatQuery("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s",
            tableName,
            getQueryColumnsFromFields(allFields),
            getWildcardStringWithCast(allFields),
            getKeyColumnName(recordDefinition.getKeyField()),
            conflictAction);

        try (Connection connection = this.dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(upsert)) {
            for (Object record : records) {
                setUpsertStatementValues(statement, record, allFields);
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (SQLException e) {
            throw new SKException("Failed to upsert records", e);
        }
    }

    /**
     * Runs a similarity search ordered by the vector field's distance operator.
     *
     * <p>Uses the vector field named in the options, or the first vector field of the definition.
     * When the field has no mapped distance function (and no index), the L2 operator is used.
     *
     * @return results in ascending score order, honoring the options' limit and offset
     * @throws SKException if the definition has no vector fields, the named field is not a vector
     *     field, an index exists without a distance function, or the query fails
     */
    @Override
    public <Record> List<VectorSearchResult<Record>> search(String collectionName, List<Float> vector, VectorSearchOptions options, VectorStoreRecordDefinition recordDefinition, VectorStoreRecordMapper<Record, ResultSet> mapper) {
        List<VectorStoreRecordVectorField> vectorFields = recordDefinition.getVectorFields();
        if (vectorFields.isEmpty()) {
            throw new SKException("No vector fields defined. Cannot perform vector search");
        }
        VectorStoreRecordVectorField firstVectorField = vectorFields.get(0);
        VectorSearchOptions searchOptions = options != null ? options : VectorSearchOptions.createDefault(firstVectorField.getName());

        VectorStoreRecordVectorField vectorField;
        if (searchOptions.getVectorFieldName() == null) {
            vectorField = firstVectorField;
        } else {
            VectorStoreRecordField namedField = recordDefinition.getField(searchOptions.getVectorFieldName());
            if (!(namedField instanceof VectorStoreRecordVectorField)) {
                throw new SKException("Not a vector field: " + searchOptions.getVectorFieldName());
            }
            vectorField = (VectorStoreRecordVectorField) namedField;
        }

        PostgreSQLVectorIndexKind indexKind = PostgreSQLVectorIndexKind.fromIndexKind(vectorField.getIndexKind());
        PostgreSQLVectorDistanceFunction distanceFunction = PostgreSQLVectorDistanceFunction.fromDistanceFunction(vectorField.getDistanceFunction());
        if (indexKind != null && distanceFunction == null) {
            throw new SKException("Distance function is required for vector field: " + vectorField.getName());
        }

        String filter = SQLVectorStoreRecordCollectionSearchMapping.buildFilter(searchOptions.getVectorSearchFilter(), recordDefinition);
        List<Object> filterParameters = SQLVectorStoreRecordCollectionSearchMapping.getFilterParameters(searchOptions.getVectorSearchFilter());
        String whereClause = filter.isEmpty() ? "" : "WHERE " + filter;
        String columns = getQueryColumnsFromFields(searchOptions.isIncludeVectors() ? recordDefinition.getAllFields() : recordDefinition.getNonVectorFields());
        String vectorColumn = validateSQLidentifier(vectorField.getEffectiveStorageName());
        String operator = (distanceFunction == null ? PostgreSQLVectorDistanceFunction.L2 : distanceFunction).getOperator();
        String query = formatQuery("SELECT %s, %s %s ?::vector AS score FROM %s %s ORDER BY score LIMIT ? OFFSET ?",
            columns, vectorColumn, operator, getCollectionTableName(collectionName), whereClause);

        try (Connection connection = this.dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(query)) {
            // Parameter order: query vector, filter values, LIMIT, OFFSET.
            int parameterIndex = 1;
            statement.setString(parameterIndex++, this.objectMapper.writeValueAsString(vector));
            for (Object filterParameter : filterParameters) {
                statement.setObject(parameterIndex++, filterParameter);
            }
            statement.setInt(parameterIndex++, searchOptions.getLimit());
            statement.setInt(parameterIndex, searchOptions.getOffset());

            GetRecordOptions recordOptions = new GetRecordOptions(searchOptions.isIncludeVectors());
            List<VectorSearchResult<Record>> results = new ArrayList<>();
            try (ResultSet resultSet = statement.executeQuery()) {
                while (resultSet.next()) {
                    Record record = mapper.mapStorageModelToRecord(resultSet, recordOptions);
                    results.add(new VectorSearchResult<>(record, resultSet.getDouble("score")));
                }
            }
            return results;
        } catch (SQLException | JsonProcessingException e) {
            throw new SKException("Failed to search records", e);
        }
    }
}
