/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.clip.model;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.Translator;
import ai.djl.util.Pair;
import cn.smartjavaai.clip.config.ClipModelConfig;
import cn.smartjavaai.clip.exception.ClipException;
import cn.smartjavaai.clip.model.ClipModel;
import cn.smartjavaai.clip.pool.ClipImagePredictorFactory;
import cn.smartjavaai.clip.pool.ClipImageTextPredictorFactory;
import cn.smartjavaai.clip.pool.ClipTextPredictorFactory;
import cn.smartjavaai.common.cv.SmartImageFactory;
import cn.smartjavaai.common.entity.R;
import cn.smartjavaai.common.enums.SimilarityType;
import cn.smartjavaai.common.utils.DJLCommonUtils;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.common.utils.SimilarityUtil;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ClipModel} implementation that loads an OpenAI CLIP model file ({@code clip.pt}) on the
 * DJL PyTorch engine (CPU), together with a HuggingFace tokenizer read from
 * {@code tokenizer.json} in the model directory.
 *
 * <p>Inference goes through three Commons-Pool predictor pools: image features, text features,
 * and image/text comparison. Each call borrows its own {@link Predictor} and returns it afterwards.
 * Call {@link #loadModel(ClipModelConfig)} before any inference method and {@link #close()} when
 * done.
 */
public class OpenAIClipModel
implements ClipModel {
    private static final Logger log = LoggerFactory.getLogger(OpenAIClipModel.class);
    private ClipModelConfig config;
    private ZooModel<NDList, NDList> model;
    private HuggingFaceTokenizer tokenizer;
    private GenericObjectPool<Predictor<Image, float[]>> imageFeaturePredictorPool;
    private GenericObjectPool<Predictor<String, float[]>> textFeaturePredictorPool;
    private GenericObjectPool<Predictor<Pair<Image, String>, float[]>> imgTextPredictorPool;
    /** Set through {@link #setFromFactory(boolean)}; presumably marks factory-managed instances (confirm). */
    private boolean fromFactory = false;

    /**
     * Loads the CLIP model and tokenizer described by {@code config} and creates the predictor pools.
     *
     * <p>The model path is treated as a URL when DJL recognizes its protocol, and as a local path
     * otherwise. The pool size is {@code config.getPredictorPoolSize()}, or the number of available
     * processors when that value is {@code <= 0}. If loading fails part-way, any model or tokenizer
     * already loaded by this call is closed. The previously loaded state, if any, is left untouched.
     *
     * @param config model configuration; must be non-null with a non-blank model path
     * @throws ClipException if the configuration is invalid or the model/tokenizer cannot be loaded
     */
    @Override
    public void loadModel(ClipModelConfig config) {
        if (config == null) {
            throw new ClipException("config\u4e3anull");
        }
        if (StringUtils.isBlank(config.getModelPath())) {
            throw new ClipException("modelPath\u4e3a\u7a7a");
        }
        this.config = config;
        ZooModel<NDList, NDList> loadedModel = null;
        HuggingFaceTokenizer loadedTokenizer = null;
        boolean success = false;
        try {
            String modelPath = config.getModelPath();
            boolean isUrl = DJLCommonUtils.hasSupportedProtocol(modelPath);
            Criteria<NDList, NDList> criteria = Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelUrls(isUrl ? modelPath : null)
                    .optModelName("clip.pt")
                    .optModelPath(isUrl ? null : Paths.get(modelPath))
                    .optTranslator(new NoopTranslator())
                    .optEngine("PyTorch")
                    .optDevice(Device.cpu())
                    .build();
            loadedModel = criteria.loadModel();
            // The tokenizer ships next to the (possibly downloaded/cached) model files.
            Path tokenizerPath = loadedModel.getWrappedModel().getModelPath().resolve("tokenizer.json");
            loadedTokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath);

            int predictorPoolSize = config.getPredictorPoolSize();
            if (predictorPoolSize <= 0) {
                predictorPoolSize = Runtime.getRuntime().availableProcessors();
            }
            GenericObjectPool<Predictor<Image, float[]>> imagePool =
                    newPool(new ClipImagePredictorFactory(loadedModel), predictorPoolSize);
            GenericObjectPool<Predictor<String, float[]>> textPool =
                    newPool(new ClipTextPredictorFactory(loadedModel, loadedTokenizer), predictorPoolSize);
            GenericObjectPool<Predictor<Pair<Image, String>, float[]>> imageTextPool =
                    newPool(new ClipImageTextPredictorFactory(loadedModel, loadedTokenizer), predictorPoolSize);

            // Commit only once every resource has been created successfully.
            this.model = loadedModel;
            this.tokenizer = loadedTokenizer;
            this.imageFeaturePredictorPool = imagePool;
            this.textFeaturePredictorPool = textPool;
            this.imgTextPredictorPool = imageTextPool;
            success = true;

            log.debug("\u5f53\u524d\u8bbe\u5907: {}", this.model.getNDManager().getDevice());
            log.debug("\u5f53\u524d\u5f15\u64ce: {}", Engine.getInstance().getEngineName());
            log.debug("\u6a21\u578b\u63a8\u7406\u5668\u7ebf\u7a0b\u6c60\u6700\u5927\u6570\u91cf: {}", predictorPoolSize);
        }
        catch (MalformedModelException | ModelNotFoundException | IOException e) {
            throw new ClipException("\u6a21\u578b\u52a0\u8f7d\u5931\u8d25", e);
        }
        finally {
            if (!success) {
                // Avoid leaking native resources loaded by this (failed) call.
                closeQuietly(loadedTokenizer, "\u5173\u95ed tokenizer \u5931\u8d25");
                closeQuietly(loadedModel, "\u5173\u95ed model \u5931\u8d25");
            }
        }
    }

    /**
     * Extracts the CLIP image embedding.
     *
     * @param image input image
     * @return successful result holding the predictor's output vector
     * @throws ClipException if borrowing a predictor or inference fails
     */
    @Override
    public R<float[]> extractImageFeatures(Image image) {
        return R.ok(predictWithPool(this.imageFeaturePredictorPool, image));
    }

    /**
     * Loads an image from disk, extracts its CLIP embedding, and always releases the loaded image.
     *
     * @param imagePath path of the image file
     * @return successful result holding the predictor's output vector
     * @throws ClipException if the file cannot be read or inference fails
     */
    @Override
    public R<float[]> extractImageFeatures(String imagePath) {
        Image image = null;
        try {
            image = SmartImageFactory.getInstance().fromFile(imagePath);
            return this.extractImageFeatures(image);
        }
        catch (IOException e) {
            throw new ClipException(e);
        }
        finally {
            ImageUtils.releaseOpenCVMat(image);
        }
    }

    /**
     * Extracts the CLIP text embedding.
     *
     * @param inputs input text
     * @return successful result holding the predictor's output vector
     * @throws ClipException if borrowing a predictor or inference fails
     */
    @Override
    public R<float[]> extractTextFeatures(String inputs) {
        return R.ok(predictWithPool(this.textFeaturePredictorPool, inputs));
    }

    /**
     * Scores an image against a text with the image/text predictor.
     *
     * @param image input image
     * @param text input text
     * @return the first element of the predictor output (presumably the image-text logit; confirm
     *     against {@code ClipImageTextPredictorFactory}), or a failed result when the output is empty
     * @throws ClipException if borrowing a predictor or inference fails
     */
    @Override
    public R<Float> compareTextAndImage(Image image, String text) {
        float[] output = predictWithPool(this.imgTextPredictorPool, new Pair<>(image, text));
        if (output == null || output.length == 0) {
            return R.fail(R.Status.Unknown.getCode(), "\u7279\u5f81\u4e3a\u7a7a");
        }
        return R.ok(output[0]);
    }

    /**
     * Computes the cosine similarity of two feature vectors, multiplied by {@code scale}.
     *
     * @param feature1 first feature vector
     * @param feature2 second feature vector
     * @param scale factor applied to the similarity
     * @return successful result holding {@code similarity * scale}
     */
    @Override
    public R<Float> compareFeatures(float[] feature1, float[] feature2, float scale) {
        // The boolean flag is passed through as-is; its meaning is defined by SimilarityUtil.
        float similarity = SimilarityUtil.calculate(feature1, feature2, SimilarityType.COSINE, false);
        return R.ok(similarity * scale);
    }

    /** Compares two images with a scale of {@code 1.0}. */
    @Override
    public R<Float> compareImage(Image image1, Image image2) {
        return this.compareImage(image1, image2, 1.0f);
    }

    /**
     * Compares two images by the cosine similarity of their CLIP embeddings.
     *
     * @return the scaled similarity, or the first failed feature-extraction result
     */
    @Override
    public R<Float> compareImage(Image image1, Image image2, float scale) {
        R<float[]> features1 = this.extractImageFeatures(image1);
        if (!features1.isSuccess()) {
            return R.fail(features1.getCode(), features1.getMessage());
        }
        R<float[]> features2 = this.extractImageFeatures(image2);
        if (!features2.isSuccess()) {
            return R.fail(features2.getCode(), features2.getMessage());
        }
        return this.compareFeatures(features1.getData(), features2.getData(), scale);
    }

    /** Compares two image files with a scale of {@code 1.0}. */
    @Override
    public R<Float> compareImage(String imagePath1, String imagePath2) {
        return this.compareImage(imagePath1, imagePath2, 1.0f);
    }

    /**
     * Loads two image files and compares them.
     *
     * <p>Both images are released on every exit path, including when feature extraction throws.
     *
     * @throws ClipException if a file cannot be read or inference fails
     */
    @Override
    public R<Float> compareImage(String imagePath1, String imagePath2, float scale) {
        Image image1 = null;
        Image image2 = null;
        try {
            image1 = SmartImageFactory.getInstance().fromFile(imagePath1);
            image2 = SmartImageFactory.getInstance().fromFile(imagePath2);
            return this.compareImage(image1, image2, scale);
        }
        catch (IOException e) {
            throw new ClipException(e);
        }
        finally {
            ImageUtils.releaseOpenCVMat(image1);
            ImageUtils.releaseOpenCVMat(image2);
        }
    }

    /** Compares two texts with a scale of {@code 1.0}. */
    @Override
    public R<Float> compareText(String input1, String input2) {
        return this.compareText(input1, input2, 1.0f);
    }

    /**
     * Compares two texts by the cosine similarity of their CLIP embeddings.
     *
     * @return the scaled similarity, or the first failed feature-extraction result
     */
    @Override
    public R<Float> compareText(String input1, String input2, float scale) {
        R<float[]> features1 = this.extractTextFeatures(input1);
        if (!features1.isSuccess()) {
            return R.fail(features1.getCode(), features1.getMessage());
        }
        R<float[]> features2 = this.extractTextFeatures(input2);
        if (!features2.isSuccess()) {
            return R.fail(features2.getCode(), features2.getMessage());
        }
        return this.compareFeatures(features1.getData(), features2.getData(), scale);
    }

    /**
     * Closes the predictor pools, then the model, then the tokenizer.
     *
     * <p>Each close is attempted independently. Failures are logged and never propagated.
     */
    @Override
    public void close() throws Exception {
        GenericObjectPool<?>[] pools = {
            this.imageFeaturePredictorPool, this.textFeaturePredictorPool, this.imgTextPredictorPool
        };
        for (GenericObjectPool<?> pool : pools) {
            if (pool == null) {
                continue;
            }
            try {
                pool.close();
            }
            catch (Exception e) {
                log.warn("\u5173\u95ed predictorPool \u5931\u8d25", e);
            }
        }
        closeQuietly(this.model, "\u5173\u95ed model \u5931\u8d25");
        closeQuietly(this.tokenizer, "\u5173\u95ed tokenizer \u5931\u8d25");
    }

    /** @return whether this instance was flagged as created by a factory */
    public boolean isFromFactory() {
        return this.fromFactory;
    }

    /** @param fromFactory whether this instance was created by a factory */
    @Override
    public void setFromFactory(boolean fromFactory) {
        this.fromFactory = fromFactory;
    }

    /**
     * Borrows a predictor from {@code pool}, runs it on {@code input}, and always hands the predictor back.
     *
     * <p>If returning the predictor to the pool fails, the predictor is closed instead.
     *
     * @throws ClipException wrapping any failure while borrowing or predicting
     */
    private static <I, O> O predictWithPool(GenericObjectPool<Predictor<I, O>> pool, I input) {
        Predictor<I, O> predictor = null;
        try {
            predictor = pool.borrowObject();
            return predictor.predict(input);
        }
        catch (Exception e) {
            throw new ClipException("\u7279\u5f81\u63d0\u53d6\u9519\u8bef", e);
        }
        finally {
            if (predictor != null) {
                try {
                    pool.returnObject(predictor);
                }
                catch (Exception e) {
                    log.warn("\u5f52\u8fd8Predictor\u5931\u8d25", e);
                    try {
                        predictor.close();
                    }
                    catch (Exception ex) {
                        log.error("\u5173\u95edPredictor\u5931\u8d25", ex);
                    }
                }
            }
        }
    }

    /**
     * Creates a pool around {@code factory} with the given maximum number of objects.
     *
     * <p>The unchecked cast mirrors the original raw-typed construction; the factory's element type
     * is assumed to match the target pool's type.
     */
    @SuppressWarnings("unchecked")
    private static <T> GenericObjectPool<T> newPool(PooledObjectFactory<?> factory, int maxTotal) {
        GenericObjectPool<T> pool = new GenericObjectPool<>((PooledObjectFactory<T>) factory);
        pool.setMaxTotal(maxTotal);
        return pool;
    }

    /** Closes {@code resource} if non-null and logs {@code failureMessage} on failure. */
    private static void closeQuietly(AutoCloseable resource, String failureMessage) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (Exception e) {
            log.warn(failureMessage, e);
        }
    }
}

