package cn.smartjavaai.obb.model;

import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import cn.smartjavaai.common.cv.SmartImageFactory;
import cn.smartjavaai.common.entity.DetectionResponse;
import cn.smartjavaai.common.entity.R;
import cn.smartjavaai.common.pool.PredictorFactory;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.instanceseg.model.InstanceSegModelFactory;
import cn.smartjavaai.obb.config.ObbDetModelConfig;
import cn.smartjavaai.obb.criteria.ObbDetCriteriaFactory;
import cn.smartjavaai.obb.entity.ObbResult;
import cn.smartjavaai.obb.exception.ObbDetException;
import cn.smartjavaai.objectdetection.exception.DetectionException;
import cn.smartjavaai.vision.utils.DetectedObjectsFilter;
import cn.smartjavaai.vision.utils.DetectorUtils;
import cn.smartjavaai.vision.utils.ObbResultFilter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.opencv.core.Mat;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Objects;

/**
 * Rotated (oriented) bounding-box detection model.
 *
 * <p>Wraps a DJL {@link ZooModel} and serves inference through a {@link GenericObjectPool} of
 * {@link Predictor}s: each call to {@link #detectCore(Image)} borrows a predictor, runs it, and
 * returns it to the pool.
 *
 * @author dwj
 */
@Slf4j
public class CommonObbDetModel implements ObbDetModel {

    private ObbDetModelConfig config;

    private ZooModel<Image, ObbResult> model;

    private GenericObjectPool<Predictor<Image, ObbResult>> predictorPool;

    /** Set via {@link #setFromFactory(boolean)}; when true, {@link #close()} evicts this model from ObbDetModelFactory's cache. */
    private boolean fromFactory = false;

    /**
     * Loads the model described by {@code config} and builds the predictor pool.
     *
     * @param config model configuration; its model enum must be set
     * @throws DetectionException if no model enum is configured or the model cannot be loaded
     */
    @Override
    public void loadModel(ObbDetModelConfig config) {
        if (Objects.isNull(config.getModelEnum())) {
            throw new DetectionException("未配置模型枚举");
        }
        Criteria<Image, ObbResult> criteria = ObbDetCriteriaFactory.createCriteria(config);
        this.config = config;
        try {
            model = criteria.loadModel();
            // Pool of predictors: each borrowing thread gets exclusive use of one Predictor.
            this.predictorPool = new GenericObjectPool<>(new PredictorFactory<>(model));
            int predictorPoolSize = config.getPredictorPoolSize();
            if (predictorPoolSize <= 0) {
                // Non-positive size means "use the default": one predictor per available CPU core.
                predictorPoolSize = Runtime.getRuntime().availableProcessors();
            }
            predictorPool.setMaxTotal(predictorPoolSize);
            log.debug("当前设备: {}", model.getNDManager().getDevice());
            log.debug("当前引擎: {}", Engine.getInstance().getEngineName());
            log.debug("模型推理器线程池最大数量: {}", predictorPoolSize);
        } catch (IOException | ModelNotFoundException | MalformedModelException e) {
            throw new DetectionException("模型加载失败", e);
        }
    }

    /**
     * Runs detection and converts the rotated boxes into a {@link DetectionResponse}.
     *
     * @param image input image
     * @return successful result wrapping the detection response
     */
    @Override
    public R<DetectionResponse> detect(Image image) {
        ObbResult obbResult = detectCore(image);
        DetectionResponse detectionResponse = DetectorUtils.obbToToDetectionResponse(obbResult);
        return R.ok(detectionResponse);
    }

    /**
     * Core inference: borrows a predictor from the pool, runs it, and applies the configured
     * class / top-K filter when the result contains any boxes.
     *
     * @param image input image
     * @return the (possibly filtered) raw result; may be null if the predictor returns null
     * @throws DetectionException wrapping any failure during borrowing, inference or filtering
     */
    @Override
    public ObbResult detectCore(Image image) {
        Predictor<Image, ObbResult> predictor = null;
        try {
            predictor = predictorPool.borrowObject();
            ObbResult obbResult = predictor.predict(image);
            // Filter by allowed classes / top-K only when there is something to filter.
            if (Objects.nonNull(obbResult) && CollectionUtils.isNotEmpty(obbResult.getRotatedBoxeList())) {
                ObbResultFilter obbResultFilter = new ObbResultFilter(config.getAllowedClasses(), config.getTopK());
                obbResult = obbResultFilter.filter(obbResult);
            }
            return obbResult;
        } catch (Exception e) {
            throw new DetectionException("旋转框错误", e);
        } finally {
            if (predictor != null) {
                try {
                    predictorPool.returnObject(predictor);
                    log.debug("释放资源");
                } catch (Exception e) {
                    log.warn("归还Predictor失败", e);
                    try {
                        // Only destroy the predictor if it could not be returned to the pool.
                        predictor.close();
                    } catch (Exception ex) {
                        log.error("关闭Predictor失败", ex);
                    }
                }
            }
        }
    }

    /**
     * Detects objects and draws the rotated boxes with labels onto {@code image} in place.
     *
     * @param image input image; mutated by drawing
     * @return failure with NO_OBJECT_DETECTED when nothing is found, otherwise the response with
     *     the drawn image attached
     */
    @Override
    public R<DetectionResponse> detectAndDraw(Image image) {
        ObbResult obbResult = detectCore(image);
        if (Objects.isNull(obbResult) || CollectionUtils.isEmpty(obbResult.getRotatedBoxeList())) {
            return R.fail(R.Status.NO_OBJECT_DETECTED);
        }
        DetectorUtils.drawRectWithText(image, obbResult.getRotatedBoxeList());
        DetectionResponse detectionResponse = DetectorUtils.obbToToDetectionResponse(obbResult);
        detectionResponse.setDrawnImage(image);
        return R.ok(detectionResponse);
    }

    /**
     * Loads an image from disk, detects objects, draws the boxes and saves the result as PNG.
     *
     * @param imagePath path of the input image
     * @param outputPath path the annotated image is written to (always PNG-encoded)
     * @return failure with NO_OBJECT_DETECTED when nothing is found, otherwise the response
     * @throws ObbDetException wrapping any I/O failure while reading or writing
     */
    @Override
    public R<DetectionResponse> detectAndDraw(String imagePath, String outputPath) {
        Image img = null;
        try {
            img = SmartImageFactory.getInstance().fromFile(Paths.get(imagePath));
            ObbResult obbResult = detectCore(img);
            if (Objects.isNull(obbResult) || CollectionUtils.isEmpty(obbResult.getRotatedBoxeList())) {
                return R.fail(R.Status.NO_OBJECT_DETECTED);
            }
            DetectorUtils.drawRectWithText(img, obbResult.getRotatedBoxeList());
            // Close (and thereby flush) the file stream; it used to be leaked on every call.
            try (OutputStream out = Files.newOutputStream(Paths.get(outputPath))) {
                img.save(out, "png");
            }
            DetectionResponse detectionResponse = DetectorUtils.obbToToDetectionResponse(obbResult);
            return R.ok(detectionResponse);
        } catch (IOException e) {
            throw new ObbDetException(e);
        } finally {
            ImageUtils.releaseOpenCVMat(img);
        }
    }

    /**
     * Releases the predictor pool and the model. When the model was created by the factory,
     * it is also removed from the factory cache. Failures while closing are logged, not thrown.
     */
    @Override
    public void close() throws Exception {
        // config is null if loadModel was never called; avoid an NPE in that case.
        if (fromFactory && config != null) {
            ObbDetModelFactory.removeFromCache(config.getModelEnum());
        }
        try {
            if (predictorPool != null) {
                predictorPool.close();
            }
        } catch (Exception e) {
            log.warn("关闭 predictorPool 失败", e);
        }
        try {
            if (model != null) {
                model.close();
            }
        } catch (Exception e) {
            log.warn("关闭 model 失败", e);
        }
    }

    /**
     * Marks whether this instance was created by the model factory.
     *
     * @param fromFactory true when created via ObbDetModelFactory
     */
    @Override
    public void setFromFactory(boolean fromFactory) {
        this.fromFactory = fromFactory;
    }

    /**
     * @return true when this instance was created by the model factory
     */
    public boolean isFromFactory() {
        return fromFactory;
    }
}
