/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.store.embedding.clickhouse;

import com.clickhouse.client.api.Client;
import com.clickhouse.client.api.data_formats.internal.BinaryStreamReader;
import com.clickhouse.client.api.insert.InsertResponse;
import com.clickhouse.client.api.metrics.ServerMetrics;
import com.clickhouse.client.api.query.GenericRecord;
import com.clickhouse.client.api.query.Records;
import com.clickhouse.data.ClickHouseDataType;
import com.clickhouse.data.ClickHouseFormat;
import dev.langchain4j.community.store.embedding.clickhouse.ClickHouseJsonUtils;
import dev.langchain4j.community.store.embedding.clickhouse.ClickHouseMetadataFilterMapper;
import dev.langchain4j.community.store.embedding.clickhouse.ClickHouseSettings;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ClickHouseEmbeddingStore
implements EmbeddingStore<TextSegment>,
AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(ClickHouseEmbeddingStore.class);
    private final Client client;
    private final ClickHouseSettings settings;
    private final ClickHouseMetadataFilterMapper filterMapper;

    public ClickHouseEmbeddingStore(Client client, ClickHouseSettings settings) {
        this.settings = (ClickHouseSettings)ValidationUtils.ensureNotNull((Object)settings, (String)"settings");
        this.filterMapper = new ClickHouseMetadataFilterMapper(settings.getColumnMap(), settings.getMetadataTypeMap());
        this.client = Optional.ofNullable(client).orElse(new Client.Builder().addEndpoint(settings.getUrl()).setUsername(settings.getUsername()).setPassword(settings.getPassword()).serverSetting("allow_experimental_vector_similarity_index", "1").build());
        this.createDatabase();
        this.createTable();
    }

    public static Builder builder() {
        return new Builder();
    }

    @Override
    public void close() throws Exception {
        this.client.close();
    }

    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAll(ids, embeddings, null);
        return ids;
    }

    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAll(ids, embeddings, embedded);
        return ids;
    }

    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        EmbeddingSearchResult embeddingSearchResult;
        block8: {
            Records records = (Records)this.client.queryRecords(this.buildQuerySql(request)).get(this.settings.getTimeout(), TimeUnit.MILLISECONDS);
            try {
                ArrayList relevantList = new ArrayList();
                records.forEach(r -> relevantList.add(this.toEmbeddingMatch((GenericRecord)r)));
                embeddingSearchResult = new EmbeddingSearchResult(relevantList.stream().filter(relevant -> relevant.score() >= request.minScore()).collect(Collectors.toList()));
                if (records == null) break block8;
            }
            catch (Throwable throwable) {
                try {
                    if (records != null) {
                        try {
                            records.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            records.close();
        }
        return embeddingSearchResult;
    }

    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, (String)"ids");
        this.removeAll(MetadataFilterBuilder.metadataKey((String)this.settings.getColumnMapping("id")).isIn(ids));
    }

    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull((Object)filter, (String)"filter");
        String whereClause = "WHERE " + this.filterMapper.map(filter);
        this.client.execute(String.format("DELETE FROM %s.%s %s", this.settings.getDatabase(), this.settings.getTable(), whereClause));
    }

    public void removeAll() {
        this.client.execute(String.format("TRUNCATE TABLE IF EXISTS %s.%s", this.settings.getDatabase(), this.settings.getTable()));
    }

    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
    }

    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("ClickhouseEmbeddingStore don't add empty embeddings to ClickHouse");
            return;
        }
        ValidationUtils.ensureTrue((ids.size() == embeddings.size() ? 1 : 0) != 0, (String)"ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue((embedded == null || embeddings.size() == embedded.size() ? 1 : 0) != 0, (String)"embeddings size is not equal to embedded size");
        int length = ids.size();
        ArrayList<Map<String, Object>> dataList = new ArrayList<Map<String, Object>>();
        for (int i = 0; i < length; ++i) {
            dataList.add(this.toInsertData(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i)));
        }
        String json = ClickHouseJsonUtils.toJson(dataList);
        ByteArrayInputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8));
        try (InsertResponse response = (InsertResponse)this.client.insert(this.settings.getTable(), (InputStream)inputStream, ClickHouseFormat.JSON).get(this.settings.getTimeout(), TimeUnit.MILLISECONDS);){
            if (log.isDebugEnabled()) {
                log.debug("Insert finished: {} rows written", (Object)response.getMetrics().getMetric(ServerMetrics.NUM_ROWS_WRITTEN).getLong());
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private void createDatabase() {
        this.client.execute(String.format("CREATE DATABASE IF NOT EXISTS %s", this.settings.getDatabase()));
    }

    private void createTable() {
        String TEMPLATE = "%s Nullable(%s)";
        ArrayList<String> metadataColumns = new ArrayList<String>();
        if (this.settings.containsMetadata()) {
            for (Map.Entry<String, ClickHouseDataType> entry : this.settings.getMetadataTypeMap().entrySet()) {
                metadataColumns.add(String.format("%s Nullable(%s)", entry.getKey(), entry.getValue().name()));
            }
        }
        String metadataCreateSql = metadataColumns.isEmpty() ? "" : String.join((CharSequence)",", metadataColumns) + ", ";
        String createTableSql = String.format("CREATE TABLE IF NOT EXISTS %s.%s(%s String,%s Nullable(String),%s Array(Float64),%sCONSTRAINT cons_vec_len CHECK length(%s) = %d,INDEX vec_idx %s TYPE vector_similarity('hnsw', 'cosineDistance', %d) GRANULARITY 1000) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 8192", this.settings.getDatabase(), this.settings.getTable(), this.settings.getColumnMapping("id"), this.settings.getColumnMapping("text"), this.settings.getColumnMapping("embedding"), metadataCreateSql, this.settings.getColumnMapping("embedding"), this.settings.getDimension(), this.settings.getColumnMapping("embedding"), this.settings.getDimension());
        this.client.execute(createTableSql);
    }

    private String buildQuerySql(EmbeddingSearchRequest request) {
        Embedding refEmbedding = request.queryEmbedding();
        int maxResults = request.maxResults();
        Filter filter = request.filter();
        String whereClause = filter == null ? "" : String.format("WHERE %s", this.filterMapper.map(filter));
        String refEmbeddingStr = "[" + refEmbedding.vectorAsList().stream().map(String::valueOf).collect(Collectors.joining(",")) + "]";
        ArrayList<String> queryColumnList = new ArrayList<String>(Arrays.asList(this.settings.getColumnMapping("id"), this.settings.getColumnMapping("text"), this.settings.getColumnMapping("embedding")));
        if (this.settings.containsMetadata()) {
            queryColumnList.addAll(this.settings.getMetadataTypeMap().keySet());
        }
        return String.format("WITH %s AS reference_vector SELECT %s, dist FROM %s.%s %s ORDER BY cosineDistance(%s, reference_vector) AS %s ASC LIMIT %d", refEmbeddingStr, String.join((CharSequence)",", queryColumnList), this.settings.getDatabase(), this.settings.getTable(), whereClause, this.settings.getColumnMapping("embedding"), "dist", maxResults);
    }

    private EmbeddingMatch<TextSegment> toEmbeddingMatch(GenericRecord r) {
        String id = r.getString(this.settings.getColumnMapping("id"));
        String text = r.getString(this.settings.getColumnMapping("text"));
        List doubleEmbedding = ((BinaryStreamReader.ArrayValue)r.getObject("embedding")).asList();
        float[] embedding = new float[doubleEmbedding.size()];
        for (int i = 0; i < doubleEmbedding.size(); ++i) {
            embedding[i] = ((Double)doubleEmbedding.get(i)).floatValue();
        }
        TextSegment textSegment = null;
        if (text != null) {
            Metadata metadata = new Metadata();
            if (this.settings.containsMetadata()) {
                HashMap<String, Object> searchedMetadata = new HashMap<String, Object>();
                for (String metadataKey : this.settings.getMetadataTypeMap().keySet()) {
                    Object val = r.getObject(metadataKey);
                    if (val == null) continue;
                    searchedMetadata.put(metadataKey, val);
                }
                metadata = Metadata.from(searchedMetadata);
            }
            textSegment = TextSegment.from((String)text, (Metadata)metadata);
        }
        double cosineDistance = r.getDouble("dist");
        return new EmbeddingMatch(Double.valueOf(RelevanceScore.fromCosineSimilarity((double)(1.0 - cosineDistance))), id, Embedding.from((float[])embedding), (Object)textSegment);
    }

    private Map<String, Object> toInsertData(String id, Embedding embedding, TextSegment segment) {
        HashMap<String, Object> data = new HashMap<String, Object>(4);
        Float[] insertEmbedding = embedding.vectorAsList().toArray(new Float[0]);
        Map metadata = segment == null ? null : segment.metadata().toMap();
        data.put(this.settings.getColumnMapping("id"), id);
        data.put(this.settings.getColumnMapping("embedding"), insertEmbedding);
        data.put(this.settings.getColumnMapping("text"), segment == null ? null : segment.text());
        if (this.settings.containsMetadata()) {
            Map<String, ClickHouseDataType> meatadataColumnMap = this.settings.getMetadataTypeMap();
            for (String key : meatadataColumnMap.keySet()) {
                data.put(key, Optional.ofNullable(metadata).map(m -> m.get(key)).orElse(null));
            }
        }
        return data;
    }

    public static class Builder {
        private Client client;
        private ClickHouseSettings settings;

        public Builder client(Client client) {
            this.client = client;
            return this;
        }

        public Builder settings(ClickHouseSettings settings) {
            this.settings = settings;
            return this;
        }

        public ClickHouseEmbeddingStore build() {
            return new ClickHouseEmbeddingStore(this.client, this.settings);
        }
    }
}

