package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.examples.inference.benchmark.util.Arguments;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.Properties;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/examples/inference/LoadModel.class */
/**
 * Example that loads an image-classification model from a model zoo and classifies a single image.
 *
 * <p>The image file, artifact id and model filters are taken from {@link Arguments}.
 */
public final class LoadModel {
    private static final Logger logger = LoggerFactory.getLogger(LoadModel.class);

    private LoadModel() {
    }

    /**
     * Command-line entry point.
     *
     * <p>Parses {@code args} against {@code Arguments.getOptions()}, runs {@link #predict(Arguments)}
     * and logs the resulting classifications at INFO level. If the command line cannot be parsed,
     * usage help for the options is printed instead, headed by the parse error message.
     *
     * @param args the command-line arguments
     * @throws IOException propagated from {@link #predict(Arguments)}
     * @throws ModelException propagated from {@link #predict(Arguments)}
     * @throws TranslateException propagated from {@link #predict(Arguments)}
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        Options options = Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            // No default property values; stopAtNonOption=false, so an unrecognized argument is
            // reported as a ParseException rather than being passed through.
            Arguments arguments = new Arguments(parser.parse(options, args, (Properties) null, false));
            Classifications classifications = predict(arguments);
            logger.info("{}", classifications);
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
        }
    }

    /**
     * Loads the image named by {@code arguments}, loads a matching image-classification model and
     * returns its prediction for that image.
     *
     * <p>The model and predictor are closed before this method returns, whether or not prediction
     * succeeds.
     *
     * @param arguments parsed example arguments supplying the image file, artifact id and filters
     * @return the classifications produced by the model for the image
     * @throws IOException if the image or the model cannot be read
     * @throws ModelException if no suitable model can be loaded
     * @throws TranslateException if pre- or post-processing of the prediction fails
     */
    public static Classifications predict(Arguments arguments)
            throws IOException, ModelException, TranslateException {
        BufferedImage img = BufferedImageUtils.fromFile(arguments.getImageFile());

        String artifactId = arguments.getArtifactId();
        Criteria.Builder<BufferedImage, Classifications> builder =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(BufferedImage.class, Classifications.class)
                        .optArtifactId(artifactId)
                        .optFilters(arguments.getCriteria())
                        .optProgress(new ProgressBar());
        if (artifactId.startsWith("ai.djl.localmodelzoo")) {
            // NOTE(review): presumably models from the local model zoo do not ship a translator,
            // so one is supplied explicitly here — confirm against the local zoo's behavior.
            builder.optTranslator(getTranslator());
        }
        Criteria<BufferedImage, Classifications> criteria = builder.build();

        // try-with-resources closes the predictor, then the model, even when prediction fails;
        // an exception thrown by close() is attached as suppressed instead of masking the
        // original failure.
        try (ZooModel<BufferedImage, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<BufferedImage, Classifications> predictor = model.newPredictor()) {
            return predictor.predict(img);
        }
    }

    /**
     * Builds the translator used for models from the local model zoo.
     *
     * <p>Images are center-cropped, resized to 224x224 and converted to a tensor. Class labels are
     * read from the model artifact named {@code synset.txt}.
     *
     * @return an image-classification translator with the pipeline described above
     */
    private static Translator<BufferedImage, Classifications> getTranslator() {
        Pipeline pipeline = new Pipeline();
        pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
        return ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .setSynsetArtifactName("synset.txt")
                .build();
    }
}
