/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.cls.criteria;

import ai.djl.Device;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.action.exception.ActionException;
import cn.smartjavaai.cls.config.ClsModelConfig;
import cn.smartjavaai.cls.enums.ClsModelEnum;
import cn.smartjavaai.cls.translator.YoloClsTranslator;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.common.utils.DJLCommonUtils;
import java.nio.file.Paths;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;

/**
 * Builds DJL {@link Criteria} for image-classification models described by a {@link ClsModelConfig}.
 */
public class ClsCriteriaFactory {

    /**
     * Creates the DJL loading criteria for an image-classification model.
     *
     * <p>If {@link ClsModelConfig#getModelPath()} carries a protocol accepted by
     * {@link DJLCommonUtils#hasSupportedProtocol(String)}, it is passed to DJL as a model URL.
     * Otherwise it is treated as a local filesystem path.
     *
     * @param config model configuration; must not be {@code null}
     * @return criteria mapping {@link Image} inputs to {@link Classifications} outputs
     * @throws ActionException if the configured model path is blank
     * @throws NullPointerException if {@code config} or its model enum is {@code null}
     */
    public static Criteria<Image, Classifications> createCriteria(ClsModelConfig config) {
        Objects.requireNonNull(config, "config");
        String modelPath = config.getModelPath();
        // Validate the path before doing any other setup work (fail fast).
        if (StringUtils.isBlank(modelPath)) {
            // Message: "Please specify the model path".
            throw new ActionException("\u8bf7\u6307\u5b9a\u6a21\u578b\u8def\u5f84");
        }
        // The engine name comes from the model enum, so it is required.
        ClsModelEnum modelEnum = Objects.requireNonNull(config.getModelEnum(), "config.modelEnum");

        Device device = resolveDevice(config);
        Translator<Image, Classifications> translator = getTranslator(config);
        boolean isUrl = DJLCommonUtils.hasSupportedProtocol(modelPath);

        // Exactly one of modelUrls / modelPath is set; the other stays null.
        return Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelUrls(isUrl ? modelPath : null)
                .optModelPath(isUrl ? null : Paths.get(modelPath))
                .optTranslator(translator)
                .optDevice(device)
                .optProgress(new ProgressBar())
                .optEngine(modelEnum.getEngine())
                .build();
    }

    /**
     * Returns the translator matching the configured model type.
     *
     * @param config model configuration
     * @return a {@link YoloClsTranslator} configured with synset artifact name {@code synset.txt}
     *     for {@link ClsModelEnum#YOLOV8} and {@link ClsModelEnum#YOLOV11}, or {@code null} for any
     *     other (or missing) model type
     */
    public static Translator<Image, Classifications> getTranslator(ClsModelConfig config) {
        ClsModelEnum modelEnum = config.getModelEnum();
        if (modelEnum == ClsModelEnum.YOLOV11 || modelEnum == ClsModelEnum.YOLOV8) {
            return YoloClsTranslator.builder().optSynsetArtifactName("synset.txt").build();
        }
        // NOTE(review): no translator for other model types; confirm DJL can resolve one itself
        // for every ClsModelEnum value that reaches createCriteria.
        return null;
    }

    /**
     * Maps the configured {@link DeviceEnum} to a DJL {@link Device}.
     *
     * <p>{@code CPU} maps to {@link Device#cpu()}. Any other non-null value maps to
     * {@link Device#gpu(int)} using the configured GPU id. When no device is configured, this
     * returns {@code null}, leaving the choice to DJL's default (presumably its default device —
     * confirm).
     */
    private static Device resolveDevice(ClsModelConfig config) {
        if (Objects.isNull(config.getDevice())) {
            return null;
        }
        return config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu((int) config.getGpuId());
    }
}

