/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.vector.core;

import cn.hutool.core.util.IdUtil;
import cn.smartjavaai.common.config.Config;
import cn.smartjavaai.common.entity.face.FaceSearchResult;
import cn.smartjavaai.common.enums.SimilarityType;
import cn.smartjavaai.common.utils.SimilarityUtil;
import cn.smartjavaai.face.dao.FaceDao;
import cn.smartjavaai.face.entity.FaceSearchParams;
import cn.smartjavaai.face.vector.config.SQLiteConfig;
import cn.smartjavaai.face.vector.core.VectorDBClient;
import cn.smartjavaai.face.vector.entity.FaceVector;
import cn.smartjavaai.face.vector.exception.VectorDBException;
import java.io.File;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link VectorDBClient} implementation backed by a SQLite database (through {@link FaceDao})
 * plus an in-memory index of every stored face vector.
 *
 * <p>Writes go to SQLite first and are then mirrored into the in-memory index. Similarity search
 * is a brute-force scan over the in-memory index, with candidates scored on a small fixed thread
 * pool. Most operations require {@link #initialize()} or {@link #loadFaceFeatures()} to have run
 * first and throw {@link VectorDBException} otherwise.
 */
public class SQLiteClient
implements VectorDBClient {
    private static final Logger log = LoggerFactory.getLogger(SQLiteClient.class);
    /** Error message used whenever an operation is attempted before the index is loaded. */
    private static final String NOT_LOADED_MESSAGE = "\u4eba\u8138\u5e93\u672a\u52a0\u8f7d\u5b8c\u6bd5";
    /** Number of rows fetched per page when loading the whole table into memory. */
    private static final int LOAD_PAGE_SIZE = 1000;
    /** Number of worker threads used to score candidates during search. */
    private static final int SEARCH_THREADS = 4;

    private final FaceDao faceDao;
    /** In-memory copy of the stored face vectors, keyed by face id. */
    private final ConcurrentHashMap<String, FaceVector> memoryIndex = new ConcurrentHashMap<>();
    /** Dimension recorded by {@link #createCollection}; not otherwise read in this class. */
    private int featureDimension;
    private final ExecutorService executor = Executors.newFixedThreadPool(SEARCH_THREADS);
    private final SQLiteConfig config;
    /** True once the in-memory index has been loaded; volatile so every thread sees the flag. */
    private volatile boolean isInit;

    /**
     * Creates a client for the SQLite database described by {@code config}.
     *
     * <p>If the configured database path is blank, {@code face.db} under
     * {@link Config#getCachePath()} is used instead. The in-memory index is not loaded here.
     *
     * @param config SQLite configuration (database path and similarity type are read from it)
     */
    public SQLiteClient(SQLiteConfig config) {
        this.config = config;
        String dbPath = config.getDbPath();
        if (StringUtils.isBlank(dbPath)) {
            dbPath = Config.getCachePath() + File.separator + "face.db";
            log.debug("\u4f7f\u7528\u9ed8\u8ba4SQLite\u4eba\u8138\u5e93\u8def\u5f84: {}", dbPath);
        }
        this.faceDao = FaceDao.getInstance(dbPath);
    }

    /**
     * Loads all stored vectors into memory and marks the client as ready.
     *
     * @throws VectorDBException if loading fails
     */
    @Override
    public void initialize() {
        try {
            this.loadAllFeaturesToMemory();
            this.isInit = true;
            log.debug("SQLiteVectorDB initialized with {} faces", this.memoryIndex.size());
        }
        catch (Exception e) {
            throw new VectorDBException("\u521d\u59cb\u5316\u5931\u8d25", e);
        }
    }

    /**
     * Records the feature dimension. SQLite has no collections, so the name is ignored.
     *
     * @param collectionName ignored
     * @param dimension feature vector dimension
     */
    @Override
    public void createCollection(String collectionName, int dimension) {
        this.featureDimension = dimension;
        log.debug("\u7279\u5f81\u7ef4\u5ea6\u8bbe\u7f6e\u4e3a: {}", dimension);
    }

    /**
     * Deletes every stored face, both in SQLite and in memory. The collection name is ignored.
     *
     * <p>Failures while clearing are logged rather than thrown. The "all data cleared" warning is
     * only logged when clearing actually succeeded.
     *
     * @param collectionName ignored
     * @throws VectorDBException if the index has not been loaded
     */
    @Override
    public void dropCollection(String collectionName) {
        this.ensureInitialized();
        if (this.clearAllData()) {
            log.warn("\u6240\u6709\u6570\u636e\u5df2\u88ab\u6e05\u7a7a");
        }
    }

    /**
     * Not supported by the SQLite backend.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean hasCollection(String collectionName) {
        throw new UnsupportedOperationException("Sqlite \u4e0d\u652f\u6301\u6b64\u64cd\u4f5c");
    }

    /**
     * Inserts (or overwrites) a single face vector.
     *
     * @param faceVector vector to store; an id is generated when it has none
     * @return the id the vector was stored under
     * @throws VectorDBException if the index is not loaded or the write fails
     */
    @Override
    public String insert(FaceVector faceVector) {
        this.ensureInitialized();
        return this.insertBatch(Collections.singletonList(faceVector)).get(0);
    }

    /**
     * Same as {@link #insert(FaceVector)}, which already performs insert-or-update.
     */
    @Override
    public void upsert(FaceVector faceVector) {
        this.insert(faceVector);
    }

    /**
     * Inserts (or overwrites) each vector in SQLite and then in memory.
     *
     * <p>Vectors without an id are assigned a fresh UUID, which is written back onto the vector.
     *
     * @param faceVectors vectors to store
     * @return the ids of the stored vectors, in input order
     * @throws VectorDBException if the index is not loaded or any write fails
     */
    @Override
    public List<String> insertBatch(List<FaceVector> faceVectors) {
        this.ensureInitialized();
        ArrayList<String> ids = new ArrayList<String>();
        try {
            for (FaceVector faceVector : faceVectors) {
                String id = faceVector.getId() != null ? faceVector.getId() : IdUtil.simpleUUID();
                faceVector.setId(id);
                this.faceDao.insertOrUpdate(faceVector);
                this.addToMemoryIndex(faceVector);
                ids.add(id);
            }
            log.debug("\u63d2\u5165\u4e86 {} \u4e2a\u4eba\u8138\u5411\u91cf", faceVectors.size());
            return ids;
        }
        catch (Exception e) {
            throw new VectorDBException("\u6279\u91cf\u63d2\u5165\u5931\u8d25", e);
        }
    }

    /**
     * Deletes a single face by id.
     *
     * @throws VectorDBException if the index is not loaded or the delete fails
     */
    @Override
    public void delete(String id) {
        this.ensureInitialized();
        this.deleteBatch(Collections.singletonList(id));
    }

    /**
     * Deletes the given ids from SQLite and from the in-memory index.
     *
     * @param ids ids to delete
     * @throws VectorDBException if the index is not loaded, if the DAO throws (wrapped as a batch
     *     delete failure), or if the DAO reports the delete as unsuccessful
     */
    @Override
    public void deleteBatch(List<String> ids) {
        this.ensureInitialized();
        boolean isSuccess;
        try {
            isSuccess = this.faceDao.deleteFace(ids.toArray(new String[0]));
            ids.forEach(this.memoryIndex::remove);
        }
        catch (Exception e) {
            throw new VectorDBException("\u6279\u91cf\u5220\u9664\u5931\u8d25", e);
        }
        // Checked outside the try so this exception is not re-wrapped by the catch above.
        if (!isSuccess) {
            throw new VectorDBException("\u5220\u9664\u5931\u8d25");
        }
    }

    /**
     * Brute-force similarity search over the in-memory index.
     *
     * @param queryVector query feature vector
     * @param faceSearchParams threshold, top-K and normalization settings
     * @return matches with similarity at or above the threshold, best first, at most top-K
     * @throws VectorDBException if the index is not loaded
     */
    @Override
    public List<FaceSearchResult> search(float[] queryVector, FaceSearchParams faceSearchParams) {
        this.ensureInitialized();
        if (this.memoryIndex.isEmpty()) {
            return Collections.emptyList();
        }
        // These settings are the same for every candidate, so read them once up front.
        final SimilarityType similarityType = this.config.getSimilarityType();
        final boolean normalize = faceSearchParams.getNormalizeSimilarity();
        final float threshold = faceSearchParams.getThreshold().floatValue();
        List<CompletableFuture<FaceSearchResult>> futures = this.memoryIndex.values().stream()
                .map(vector -> CompletableFuture.supplyAsync(
                        () -> this.scoreCandidate(queryVector, vector, similarityType, normalize, threshold),
                        this.executor))
                .collect(Collectors.toList());
        return futures.stream()
                .map(CompletableFuture::join)
                .filter(Objects::nonNull)
                .sorted(Comparator.comparingDouble(FaceSearchResult::getSimilarity).reversed())
                .limit(faceSearchParams.getTopK().intValue())
                .collect(Collectors.toList());
    }

    /**
     * Scores one stored vector against the query.
     *
     * @return a search result if the similarity reaches {@code threshold}, otherwise {@code null}
     */
    private FaceSearchResult scoreCandidate(float[] queryVector, FaceVector candidate,
            SimilarityType similarityType, boolean normalize, float threshold) {
        float similarity = SimilarityUtil.calculate(queryVector, candidate.getVector(), similarityType, normalize);
        return similarity >= threshold
                ? new FaceSearchResult(candidate.getId(), similarity, candidate.getMetadata())
                : null;
    }

    /**
     * Returns the number of faces currently held in memory. The collection name is ignored.
     */
    @Override
    public long count(String collectionName) {
        return this.memoryIndex.size();
    }

    /**
     * Shuts down the search thread pool, forcing shutdown if it does not finish within 5 seconds.
     */
    @Override
    public void close() {
        this.executor.shutdown();
        try {
            if (!this.executor.awaitTermination(5L, TimeUnit.SECONDS)) {
                this.executor.shutdownNow();
            }
        }
        catch (InterruptedException e) {
            this.executor.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Looks a face up by id, first in memory and then in SQLite.
     *
     * @return the stored vector, or whatever {@code FaceDao.findById} returns when it is not in memory
     * @throws VectorDBException if the index is not loaded or the SQLite query fails
     */
    @Override
    public FaceVector getFaceInfoById(String id) {
        this.ensureInitialized();
        FaceVector faceVector = this.memoryIndex.get(id);
        if (faceVector == null) {
            try {
                faceVector = this.faceDao.findById(id);
            }
            catch (ClassNotFoundException | SQLException e) {
                throw new VectorDBException("SQLite\u67e5\u8be2\u5f02\u5e38", e);
            }
        }
        return faceVector;
    }

    /**
     * Returns one page of stored faces straight from SQLite.
     *
     * @param pageNum 1-based page number
     * @param pageSize number of faces per page
     * @throws IllegalArgumentException if either argument is below 1 or exceeds {@code Integer.MAX_VALUE}
     * @throws VectorDBException if the index is not loaded or the query fails
     */
    @Override
    public List<FaceVector> listFaces(long pageNum, long pageSize) {
        this.ensureInitialized();
        if (pageNum < 1L || pageSize < 1L) {
            throw new IllegalArgumentException("pageNum\u548cpageSize\u5fc5\u987b\u5927\u4e8e0");
        }
        // The DAO takes ints; reject values that a plain cast would silently truncate.
        if (pageNum > Integer.MAX_VALUE || pageSize > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("pageNum and pageSize must not exceed " + Integer.MAX_VALUE);
        }
        try {
            return this.faceDao.findFace((int)pageNum, (int)pageSize);
        }
        catch (Exception e) {
            throw new VectorDBException("\u5206\u9875\u67e5\u8be2\u5931\u8d25", e);
        }
    }

    /**
     * Pages through the whole table and copies every row into the in-memory index.
     *
     * <p>Stops at the first empty page.
     *
     * @throws VectorDBException if any page query fails
     */
    private void loadAllFeaturesToMemory() {
        try {
            List<FaceVector> batch;
            int page = 1;
            while (!CollectionUtils.isEmpty(batch = this.faceDao.findFace(page, LOAD_PAGE_SIZE))) {
                for (FaceVector vector : batch) {
                    this.addToMemoryIndex(vector);
                }
                ++page;
            }
            log.debug("\u4ece\u6570\u636e\u5e93\u52a0\u8f7d\u4e86 {} \u4e2a\u7279\u5f81\u5411\u91cf\u5230\u5185\u5b58", this.memoryIndex.size());
        }
        catch (Exception e) {
            throw new VectorDBException("\u52a0\u8f7d\u7279\u5f81\u5230\u5185\u5b58\u5931\u8d25", e);
        }
    }

    /** Puts (or replaces) a vector in the in-memory index under its id. */
    private void addToMemoryIndex(FaceVector faceVector) {
        this.memoryIndex.put(faceVector.getId(), faceVector);
    }

    /**
     * Deletes all rows from SQLite and clears the in-memory index.
     *
     * <p>Failures are logged, not thrown.
     *
     * @return {@code true} if both steps completed, {@code false} if an exception occurred
     * @throws VectorDBException if the index has not been loaded
     */
    private boolean clearAllData() {
        this.ensureInitialized();
        try {
            this.faceDao.deleteAll();
            this.memoryIndex.clear();
            return true;
        }
        catch (Exception e) {
            log.error("\u6e05\u7a7a\u6570\u636e\u5e93\u5931\u8d25", (Throwable)e);
            return false;
        }
    }

    /**
     * (Re)loads all stored vectors into memory and marks the client as ready.
     *
     * @throws VectorDBException if loading fails
     */
    @Override
    public void loadFaceFeatures() {
        this.loadAllFeaturesToMemory();
        this.isInit = true;
        log.debug("SQLiteVectorDB load success {} faces", this.memoryIndex.size());
    }

    /**
     * Drops the in-memory index and marks the client as not loaded. SQLite data is untouched.
     *
     * @throws VectorDBException if the index is not currently loaded
     */
    @Override
    public void releaseFaceFeatures() {
        this.ensureInitialized();
        this.memoryIndex.clear();
        this.isInit = false;
    }

    /** Throws unless the in-memory index has been loaded. */
    private void ensureInitialized() {
        if (!this.isInit) {
            throw new VectorDBException(NOT_LOADED_MESSAGE);
        }
    }
}

