package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.ImageVisualization;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.List;
import javax.imageio.ImageIO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/examples/inference/PoseEstimation.class */
public final class PoseEstimation {
    private static final Logger logger = LoggerFactory.getLogger(PoseEstimation.class);

    private PoseEstimation() {
    }

    public static void main(String[] strArr) throws IOException, ModelException, TranslateException {
        logger.info("{}", predict());
    }

    public static Joints predict() throws IOException, ModelException, TranslateException {
        BufferedImage predictPersonInImage = predictPersonInImage(BufferedImageUtils.fromFile(Paths.get("src/test/resources/pose_soccer.png", new String[0])));
        if (predictPersonInImage != null) {
            return predictJointsInPerson(predictPersonInImage);
        }
        logger.warn("No person found in image.");
        return null;
    }

    private static BufferedImage predictPersonInImage(BufferedImage bufferedImage) throws MalformedModelException, ModelNotFoundException, IOException, TranslateException {
        ZooModel loadModel = ModelZoo.loadModel(Criteria.builder().optApplication(Application.CV.OBJECT_DETECTION).setTypes(BufferedImage.class, DetectedObjects.class).optFilter("size", "512").optFilter("backbone", "resnet50").optFilter("flavor", "v1").optFilter("dataset", "voc").optProgress(new ProgressBar()).build());
        Throwable th = null;
        try {
            Predictor newPredictor = loadModel.newPredictor();
            Throwable th2 = null;
            try {
                try {
                    DetectedObjects detectedObjects = (DetectedObjects) newPredictor.predict(bufferedImage);
                    if (newPredictor != null) {
                        if (0 != 0) {
                            try {
                                newPredictor.close();
                            } catch (Throwable th3) {
                                th2.addSuppressed(th3);
                            }
                        } else {
                            newPredictor.close();
                        }
                    }
                    for (DetectedObjects.DetectedObject detectedObject : detectedObjects.items()) {
                        if ("person".equals(detectedObject.getClassName())) {
                            Rectangle bounds = detectedObject.getBoundingBox().getBounds();
                            int width = bufferedImage.getWidth();
                            int height = bufferedImage.getHeight();
                            return bufferedImage.getSubimage((int) (bounds.getX() * width), (int) (bounds.getY() * height), (int) (bounds.getWidth() * width), (int) (bounds.getHeight() * height));
                        }
                    }
                    return null;
                } finally {
                }
            } catch (Throwable th4) {
                if (newPredictor != null) {
                    if (th2 != null) {
                        try {
                            newPredictor.close();
                        } catch (Throwable th5) {
                            th2.addSuppressed(th5);
                        }
                    } else {
                        newPredictor.close();
                    }
                }
                throw th4;
            }
        } finally {
            if (loadModel != null) {
                if (0 != 0) {
                    try {
                        loadModel.close();
                    } catch (Throwable th6) {
                        th.addSuppressed(th6);
                    }
                } else {
                    loadModel.close();
                }
            }
        }
    }

    private static Joints predictJointsInPerson(BufferedImage bufferedImage) throws MalformedModelException, ModelNotFoundException, IOException, TranslateException {
        ZooModel loadModel = ModelZoo.loadModel(Criteria.builder().optApplication(Application.CV.POSE_ESTIMATION).setTypes(BufferedImage.class, Joints.class).optFilter("backbone", "resnet18").optFilter("flavor", "v1b").optFilter("dataset", "imagenet").build());
        Throwable th = null;
        try {
            Predictor newPredictor = loadModel.newPredictor();
            Throwable th2 = null;
            try {
                try {
                    Joints joints = (Joints) newPredictor.predict(bufferedImage);
                    saveJointsImage(bufferedImage, joints);
                    if (newPredictor != null) {
                        if (0 != 0) {
                            try {
                                newPredictor.close();
                            } catch (Throwable th3) {
                                th2.addSuppressed(th3);
                            }
                        } else {
                            newPredictor.close();
                        }
                    }
                    return joints;
                } finally {
                }
            } catch (Throwable th4) {
                if (newPredictor != null) {
                    if (th2 != null) {
                        try {
                            newPredictor.close();
                        } catch (Throwable th5) {
                            th2.addSuppressed(th5);
                        }
                    } else {
                        newPredictor.close();
                    }
                }
                throw th4;
            }
        } finally {
            if (loadModel != null) {
                if (0 != 0) {
                    try {
                        loadModel.close();
                    } catch (Throwable th6) {
                        th.addSuppressed(th6);
                    }
                } else {
                    loadModel.close();
                }
            }
        }
    }

    private static void saveJointsImage(BufferedImage bufferedImage, Joints joints) throws IOException {
        Path path = Paths.get("build/output", new String[0]);
        Files.createDirectories(path, new FileAttribute[0]);
        ImageVisualization.drawJoints(bufferedImage, joints);
        Path resolve = path.resolve("joints.png");
        ImageIO.write(bufferedImage, "png", resolve.toFile());
        logger.info("Pose image has been saved in: {}", resolve);
    }
}
