package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.milvus.MilvusCollectionDescription;
import dev.langchain4j.store.embedding.milvus.MilvusOperationsParams;
import io.milvus.client.MilvusServiceClient;
import io.milvus.param.MetricType;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/* loaded from: input_file:dev/langchain4j/store/embedding/Mapper.class */
/**
 * Static conversion helpers between langchain4j embedding-store types
 * ({@link Embedding}, {@link TextSegment}, {@link EmbeddingMatch}) and the
 * list/wrapper structures exchanged with the Milvus client.
 */
class Mapper {
    Mapper() {
    }

    /**
     * Converts embeddings to their vector representation.
     *
     * @param list embeddings to convert
     * @return one {@code List<Float>} per embedding, in the same order
     */
    public static List<List<Float>> toVectors(List<Embedding> list) {
        return list.stream()
                .map(Embedding::vectorAsList)
                .collect(Collectors.toList());
    }

    /**
     * Produces the scalar (text) column for a batch of segments.
     *
     * @param list text segments; may be {@code null} or empty
     * @param i    argument forwarded to {@code Generator.generateEmptyScalars}
     *             when no segments are supplied
     * @return the segment texts, or the result of
     *         {@code Generator.generateEmptyScalars(i)} when {@code list} is
     *         {@code null} or empty
     */
    public static List<String> toScalars(List<TextSegment> list, int i) {
        if (list == null || list.isEmpty()) {
            return Generator.generateEmptyScalars(i);
        }
        return textSegmentsToScalars(list);
    }

    /**
     * Extracts the text of every segment.
     *
     * @param list text segments to convert
     * @return the text of each segment, in the same order
     */
    public static List<String> textSegmentsToScalars(List<TextSegment> list) {
        return list.stream()
                .map(TextSegment::text)
                .collect(Collectors.toList());
    }

    /**
     * Builds embedding matches from a Milvus search result and filters them by
     * the given score threshold.
     *
     * @param milvusServiceClient         client used for the optional vector query
     * @param searchResultsWrapper        wrapper around the search response
     * @param milvusCollectionDescription supplies the id, scalar and vector field names
     * @param milvusOperationsParams      supplies the metric type, the consistency level
     *                                    and whether vectors are queried on search
     * @param d                           score threshold applied according to the metric type
     * @return the matches that satisfy the threshold, in result order
     * @throws IllegalArgumentException if the metric type is neither {@code L2} nor {@code IP}
     */
    public static List<EmbeddingMatch<TextSegment>> toEmbeddingMatches(MilvusServiceClient milvusServiceClient, SearchResultsWrapper searchResultsWrapper, MilvusCollectionDescription milvusCollectionDescription, MilvusOperationsParams milvusOperationsParams, double d) {
        @SuppressWarnings("unchecked") // the id field is consumed as strings by getVectors
        List<String> rowIds = (List<String>) searchResultsWrapper.getFieldWrapper(milvusCollectionDescription.idFieldName()).getFieldData();
        Map<String, List<Float>> vectors = getVectors(milvusServiceClient, milvusCollectionDescription, rowIds, milvusOperationsParams);

        int rowCount = searchResultsWrapper.getRowRecords().size();
        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(rowCount);
        // Only touch the per-target accessors when there is at least one row,
        // so an empty result never invokes them (same as the original loop).
        if (rowCount > 0) {
            // Hoisted out of the loop: these were previously re-evaluated on every iteration.
            List<SearchResultsWrapper.IDScore> idScores = searchResultsWrapper.getIDScore(0);
            List<?> scalars = searchResultsWrapper.getFieldData(milvusCollectionDescription.scalarFieldName(), 0);
            for (int row = 0; row < rowCount; row++) {
                SearchResultsWrapper.IDScore idScore = idScores.get(row);
                String id = idScore.getStrID();
                // When vectors were not queried the map is empty and the embedding is built from an empty list.
                Embedding embedding = Embedding.from(vectors.getOrDefault(id, Collections.emptyList()));
                TextSegment segment = TextSegment.from(String.valueOf(scalars.get(row)));
                matches.add(new EmbeddingMatch<>(Double.valueOf(idScore.getScore()), id, embedding, segment));
            }
        }
        return filterByMinSimilarity(matches, d, milvusOperationsParams.metricType().name());
    }

    /**
     * Keeps only the matches accepted by the predicate for the given metric type.
     */
    private static List<EmbeddingMatch<TextSegment>> filterByMinSimilarity(List<EmbeddingMatch<TextSegment>> matches, double threshold, String metricTypeName) {
        return matches.stream()
                .filter(getPredicate(metricTypeName, threshold))
                .collect(Collectors.toList());
    }

    /**
     * Selects the score filter for a metric type: {@code L2} keeps scores less than
     * or equal to the threshold, {@code IP} keeps scores greater than or equal to it.
     *
     * @throws IllegalArgumentException if the name is not a {@link MetricType} constant
     *                                  (raised by {@code MetricType.valueOf}) or is a
     *                                  constant other than {@code L2} / {@code IP}
     */
    private static Predicate<EmbeddingMatch<TextSegment>> getPredicate(String metricTypeName, double threshold) {
        MetricType metricType = MetricType.valueOf(metricTypeName);
        if (metricType == MetricType.L2) {
            return match -> match.score().doubleValue() <= threshold;
        }
        if (metricType == MetricType.IP) {
            return match -> match.score().doubleValue() >= threshold;
        }
        throw new IllegalArgumentException(String.format("Unsupported metricType: '%s'.%n", metricTypeName));
    }

    /**
     * Fetches the stored vectors for the given row ids, keyed by id.
     *
     * @return an empty map when {@code queryForVectorOnSearch()} is {@code false};
     *         otherwise one entry per row returned by
     *         {@code CollectionOperationsExecutor.queryForVectors}
     */
    private static Map<String, List<Float>> getVectors(MilvusServiceClient milvusServiceClient, MilvusCollectionDescription milvusCollectionDescription, List<String> rowIds, MilvusOperationsParams milvusOperationsParams) {
        if (!milvusOperationsParams.queryForVectorOnSearch()) {
            return Collections.emptyMap();
        }
        QueryResultsWrapper queryResults = CollectionOperationsExecutor.queryForVectors(milvusServiceClient, milvusCollectionDescription, rowIds, milvusOperationsParams.consistencyLevel().name());
        Map<String, List<Float>> vectorsById = new HashMap<>();
        for (QueryResultsWrapper.RowRecord rowRecord : queryResults.getRowRecords()) {
            String id = rowRecord.get(milvusCollectionDescription.idFieldName()).toString();
            @SuppressWarnings("unchecked") // RowRecord.get returns Object; the vector field is treated as List<Float>
            List<Float> vector = (List<Float>) rowRecord.get(milvusCollectionDescription.vectorFieldName());
            vectorsById.put(id, vector);
        }
        return vectorsById;
    }
}
