/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facerec.criteria;

import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.face.config.FaceRecConfig;
import cn.smartjavaai.face.model.facerec.FaceRecPreprocessConfig;
import cn.smartjavaai.face.model.facerec.translator.CommonFaceRecTranslator;
import java.nio.file.Paths;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;

public class FaceRecCriteriaFactory {
    public static Criteria<Image, float[]> createCriteria(FaceRecConfig config) {
        Device device = null;
        if (!Objects.isNull(config.getDevice())) {
            device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu((int)config.getGpuId());
        }
        Translator<Image, float[]> translator = FaceRecCriteriaFactory.getFaceRecTranslator(config);
        Criteria criteria = Criteria.builder().setTypes(Image.class, float[].class).optModelUrls(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? null : "https://resources.djl.ai/test-models/pytorch/face_feature.zip").optModelPath(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? Paths.get(config.getModelPath(), new String[0]) : null).optTranslator(translator).optDevice(device).optEngine(config.getModelEnum().getEngine()).optProgress((Progress)new ProgressBar()).build();
        return criteria;
    }

    public static Translator<Image, float[]> getFaceRecTranslator(FaceRecConfig config) {
        FaceRecPreprocessConfig preprocessConfig = new FaceRecPreprocessConfig.Builder().inputSize(config.getModelEnum().getInputWidth(), config.getModelEnum().getInputHeight()).build();
        switch (config.getModelEnum()) {
            case VGG_FACE: {
                preprocessConfig = new FaceRecPreprocessConfig.Builder().inputSize(config.getModelEnum().getInputWidth(), config.getModelEnum().getInputHeight()).usePipeline(false).normalize(false).build();
            }
        }
        return new CommonFaceRecTranslator(preprocessConfig);
    }
}

