/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.CenterFit;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Pad;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ResizeShort;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import cn.smartjavaai.common.utils.LetterBoxUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;

public class YoloV5FaceTranslator
implements Translator<Image, DetectedObjects> {
    private int maxBoxes;
    private YoloOutputType yoloOutputLayerType;
    private float nmsThreshold;
    protected float threshold;
    protected List<String> classes;
    protected boolean applyRatio;
    protected boolean removePadding;
    protected Pipeline pipeline;
    private Image.Flag flag;
    private Batchifier batchifier;
    protected int width;
    protected int height;

    protected YoloV5FaceTranslator(Builder builder) {
        this.yoloOutputLayerType = builder.outputType;
        this.nmsThreshold = builder.nmsThreshold;
        this.maxBoxes = builder.maxBox;
        this.threshold = builder.threshold;
        this.applyRatio = builder.applyRatio;
        this.removePadding = builder.removePadding;
        this.flag = builder.flag;
        this.pipeline = builder.pipeline;
        this.batchifier = builder.batchifier;
        this.width = builder.width;
        this.height = builder.height;
        this.classes = Arrays.asList("face");
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
        NDArray array = input.toNDArray(ctx.getNDManager(), this.flag);
        LetterBoxUtils.ResizeResult letterBoxResult = LetterBoxUtils.letterbox((NDManager)ctx.getNDManager(), (NDArray)array, (int)this.width, (int)this.height, (float)114.0f, (LetterBoxUtils.PaddingPosition)LetterBoxUtils.PaddingPosition.CENTER);
        array = letterBoxResult.image;
        ctx.setAttachment("width", (Object)input.getWidth());
        ctx.setAttachment("height", (Object)input.getHeight());
        ctx.setAttachment("processedWidth", (Object)this.width);
        ctx.setAttachment("processedHeight", (Object)this.height);
        ctx.setAttachment("scale", (Object)Float.valueOf(letterBoxResult.r));
        array = array.toType(DataType.FLOAT32, false).div((Number)Float.valueOf(255.0f));
        array = array.transpose(new int[]{2, 0, 1});
        return new NDList(new NDArray[]{array});
    }

    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws Exception {
        int imageWidth = (Integer)ctx.getAttachment("width");
        int imageHeight = (Integer)ctx.getAttachment("height");
        float scale = ((Float)ctx.getAttachment("scale")).floatValue();
        switch (this.yoloOutputLayerType) {
            case DETECT: {
                return this.processFromDetectOutput();
            }
            case AUTO: {
                if (((NDArray)list.get(0)).getShape().dimension() > 2) {
                    return this.processFromDetectOutput();
                }
                return this.processFromBoxOutput(imageWidth, imageHeight, list, scale);
            }
        }
        return this.processFromBoxOutput(imageWidth, imageHeight, list, scale);
    }

    protected DetectedObjects processFromBoxOutput(int imageWidth, int imageHeight, NDList list, float scale) {
        float[] flattened = ((NDArray)list.get(0)).toFloatArray();
        int sizeClasses = this.classes.size();
        int stride = 15 + sizeClasses;
        int size = flattened.length / stride;
        ArrayList<Landmark> boxes = new ArrayList<Landmark>();
        ArrayList<Float> scores = new ArrayList<Float>();
        ArrayList<Integer> classIds = new ArrayList<Integer>();
        for (int i = 0; i < size; ++i) {
            int indexBase = i * stride;
            float maxClass = 0.0f;
            int maxIndex = 0;
            float score = flattened[indexBase + 4];
            if (!(score > this.threshold)) continue;
            float xPos = flattened[indexBase];
            float yPos = flattened[indexBase + 1];
            float w = flattened[indexBase + 2];
            float h = flattened[indexBase + 3];
            ArrayList<Point> keypoints = new ArrayList<Point>();
            keypoints.add(new Point((double)flattened[indexBase + 5], (double)flattened[indexBase + 6]));
            keypoints.add(new Point((double)flattened[indexBase + 7], (double)flattened[indexBase + 8]));
            keypoints.add(new Point((double)flattened[indexBase + 9], (double)flattened[indexBase + 10]));
            keypoints.add(new Point((double)flattened[indexBase + 11], (double)flattened[indexBase + 12]));
            keypoints.add(new Point((double)flattened[indexBase + 13], (double)flattened[indexBase + 14]));
            Landmark rect = new Landmark((double)Math.max(0.0f, xPos - w / 2.0f), (double)Math.max(0.0f, yPos - h / 2.0f), (double)w, (double)h, keypoints);
            boxes.add(rect);
            scores.add(Float.valueOf(score));
            classIds.add(maxIndex);
        }
        return this.nms(imageWidth, imageHeight, boxes, classIds, scores, scale);
    }

    private DetectedObjects processFromDetectOutput() {
        throw new UnsupportedOperationException("detect layer output is not supported yet, check correct YoloV5 export format");
    }

    protected DetectedObjects nms(int imageWidth, int imageHeight, List<Landmark> boxes, List<Integer> classIds, List<Float> scores, float scale) {
        ArrayList<String> retClasses = new ArrayList<String>();
        ArrayList<Double> retProbs = new ArrayList<Double>();
        ArrayList<Landmark> retBB = new ArrayList<Landmark>();
        for (int classId = 0; classId < this.classes.size(); ++classId) {
            ArrayList<Landmark> r = new ArrayList<Landmark>();
            ArrayList<Double> s = new ArrayList<Double>();
            ArrayList<Integer> map = new ArrayList<Integer>();
            for (int j = 0; j < classIds.size(); ++j) {
                if (classIds.get(j) != classId) continue;
                r.add(boxes.get(j));
                s.add(scores.get(j).doubleValue());
                map.add(j);
            }
            if (r.isEmpty()) continue;
            List nms = Rectangle.nms(r, s, (float)this.nmsThreshold);
            Iterator iterator = nms.iterator();
            while (iterator.hasNext()) {
                int index = (Integer)iterator.next();
                int pos = (Integer)map.get(index);
                int id = classIds.get(pos);
                int percent = (int)Math.round(scores.get(pos).doubleValue() * 100.0);
                String className = "face " + percent + "%";
                retClasses.add(className);
                retProbs.add(scores.get(pos).doubleValue());
                Landmark rect = boxes.get(pos);
                ArrayList keypoints = new ArrayList();
                rect = LetterBoxUtils.restoreBox((Landmark)rect, (float)scale, (int)imageWidth, (int)imageHeight, (int)this.width, (int)this.height, (boolean)false);
                retBB.add(rect);
            }
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    public static enum YoloOutputType {
        BOX,
        DETECT,
        AUTO;

    }

    public static class Builder {
        private int maxBox = 8400;
        YoloOutputType outputType = YoloOutputType.AUTO;
        float nmsThreshold = 0.4f;
        protected float threshold = 0.2f;
        protected boolean applyRatio;
        protected boolean removePadding;
        protected int width = 224;
        protected int height = 224;
        protected Image.Flag flag;
        protected Pipeline pipeline;
        protected Batchifier batchifier;

        public Builder optOutputType(YoloOutputType outputType) {
            this.outputType = outputType;
            return this;
        }

        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return this;
        }

        public YoloV5FaceTranslator build() {
            if (this.pipeline == null) {
                this.addTransform(array -> array.transpose(new int[]{2, 0, 1}).toType(DataType.FLOAT32, false).div((Number)255));
            }
            return new YoloV5FaceTranslator(this);
        }

        protected Builder self() {
            return this;
        }

        public Builder addTransform(Transform transform) {
            if (this.pipeline == null) {
                this.pipeline = new Pipeline();
            }
            this.pipeline.add(transform);
            return this.self();
        }

        public Builder optApplyRatio(boolean value) {
            this.applyRatio = value;
            return this.self();
        }

        public Builder optFlag(Image.Flag flag) {
            this.flag = flag;
            return this.self();
        }

        public Builder setPipeline(Pipeline pipeline) {
            this.pipeline = pipeline;
            return this.self();
        }

        public Builder setImageSize(int width, int height) {
            this.width = width;
            this.height = height;
            return this.self();
        }

        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this.self();
        }

        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return this.self();
        }

        protected void configPostProcess(Map<String, ?> arguments) {
            if (ArgumentsUtil.booleanValue(arguments, (String)"optApplyRatio") || ArgumentsUtil.booleanValue(arguments, (String)"applyRatio")) {
                this.optApplyRatio(true);
            }
            this.threshold = ArgumentsUtil.floatValue(arguments, (String)"threshold", (float)0.2f);
            String centerFit = ArgumentsUtil.stringValue(arguments, (String)"centerFit", (String)"false");
            this.removePadding = "true".equals(centerFit);
            String type = ArgumentsUtil.stringValue(arguments, (String)"outputType", (String)"AUTO");
            this.outputType = YoloOutputType.valueOf(type.toUpperCase(Locale.ENGLISH));
            this.nmsThreshold = ArgumentsUtil.floatValue(arguments, (String)"nmsThreshold", (float)0.4f);
            this.maxBox = ArgumentsUtil.intValue(arguments, (String)"maxBox", (int)8400);
        }

        protected void configPreProcess(Map<String, ?> arguments) {
            String normalize;
            int shortEdge;
            int w;
            String pad;
            if (this.pipeline == null) {
                this.pipeline = new Pipeline();
            }
            this.width = ArgumentsUtil.intValue(arguments, (String)"width", (int)224);
            this.height = ArgumentsUtil.intValue(arguments, (String)"height", (int)224);
            if (arguments.containsKey("flag")) {
                this.flag = Image.Flag.valueOf((String)arguments.get("flag").toString());
            }
            if ("true".equals(pad = ArgumentsUtil.stringValue(arguments, (String)"pad", (String)"false"))) {
                this.addTransform((Transform)new Pad(0.0));
            } else if (!"false".equals(pad)) {
                double padding = Double.parseDouble(pad);
                this.addTransform((Transform)new Pad(padding));
            }
            String resize = ArgumentsUtil.stringValue(arguments, (String)"resize", (String)"false");
            if ("true".equals(resize)) {
                this.addTransform((Transform)new Resize(this.width, this.height));
            } else if (!"false".equals(resize)) {
                String[] tokens = resize.split("\\s*,\\s*");
                w = (int)Double.parseDouble(tokens[0]);
                shortEdge = tokens.length > 1 ? (int)Double.parseDouble(tokens[1]) : w;
                Image.Interpolation interpolation = tokens.length > 2 ? Image.Interpolation.valueOf((String)tokens[2]) : Image.Interpolation.BILINEAR;
                this.addTransform((Transform)new Resize(w, shortEdge, interpolation));
            }
            String resizeShort = ArgumentsUtil.stringValue(arguments, (String)"resizeShort", (String)"false");
            if ("true".equals(resizeShort)) {
                w = Math.max(this.width, this.height);
                this.addTransform((Transform)new ResizeShort(w));
            } else if (!"false".equals(resizeShort)) {
                String[] tokens = resizeShort.split("\\s*,\\s*");
                shortEdge = (int)Double.parseDouble(tokens[0]);
                int longEdge = tokens.length > 1 ? (int)Double.parseDouble(tokens[1]) : -1;
                Image.Interpolation interpolation = tokens.length > 2 ? Image.Interpolation.valueOf((String)tokens[2]) : Image.Interpolation.BILINEAR;
                this.addTransform((Transform)new ResizeShort(shortEdge, longEdge, interpolation));
            }
            if (ArgumentsUtil.booleanValue(arguments, (String)"centerCrop", (boolean)false)) {
                this.addTransform((Transform)new CenterCrop(this.width, this.height));
            }
            if (ArgumentsUtil.booleanValue(arguments, (String)"centerFit")) {
                this.addTransform((Transform)new CenterFit(this.width, this.height));
            }
            if (ArgumentsUtil.booleanValue(arguments, (String)"toTensor", (boolean)true)) {
                this.addTransform((Transform)new ToTensor());
            }
            if ("true".equals(normalize = ArgumentsUtil.stringValue(arguments, (String)"normalize", (String)"false"))) {
                float[] MEAN = new float[]{0.485f, 0.456f, 0.406f};
                float[] STD = new float[]{0.229f, 0.224f, 0.225f};
                this.addTransform((Transform)new Normalize(MEAN, STD));
            } else if (!"false".equals(normalize)) {
                String[] tokens = normalize.split("\\s*,\\s*");
                if (tokens.length != 6) {
                    throw new IllegalArgumentException("Invalid normalize value: " + normalize);
                }
                float[] mean = new float[]{Float.parseFloat(tokens[0]), Float.parseFloat(tokens[1]), Float.parseFloat(tokens[2])};
                float[] std = new float[]{Float.parseFloat(tokens[3]), Float.parseFloat(tokens[4]), Float.parseFloat(tokens[5])};
                this.addTransform((Transform)new Normalize(mean, std));
            }
            String range = (String)arguments.get("range");
            if ("0,1".equals(range)) {
                this.addTransform(a -> a.div((Number)Float.valueOf(255.0f)));
            } else if ("-1,1".equals(range)) {
                this.addTransform(a -> a.div((Number)Float.valueOf(128.0f)).sub((Number)1));
            }
            if (arguments.containsKey("batchifier")) {
                this.batchifier = Batchifier.fromString((String)((String)arguments.get("batchifier")));
            }
        }
    }
}

