/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.cls.translator;

import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.CenterFit;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Pad;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ResizeShort;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class YoloClsTranslator
implements Translator<Image, Classifications> {
    protected float threshold;
    protected List<String> classes;
    protected boolean applyRatio;
    protected Pipeline pipeline;
    private Image.Flag flag;
    private Batchifier batchifier;
    protected int width;
    protected int height;
    protected int topk;
    private SynsetLoader synsetLoader;

    public void prepare(TranslatorContext ctx) throws IOException {
        if (this.classes == null) {
            this.classes = this.synsetLoader.load(ctx.getModel());
        }
    }

    protected YoloClsTranslator(Builder builder) {
        this.threshold = builder.threshold;
        this.synsetLoader = builder.synsetLoader;
        this.applyRatio = builder.applyRatio;
        this.flag = builder.flag;
        this.pipeline = builder.pipeline;
        this.batchifier = builder.batchifier;
        this.width = builder.width;
        this.height = builder.height;
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
        NDManager manager = ctx.getNDManager();
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
        array = NDImageUtils.centerCrop((NDArray)array);
        array = NDImageUtils.resize((NDArray)array, (int)this.width, (int)this.height);
        array = array.toType(DataType.FLOAT32, false).div((Number)Float.valueOf(255.0f));
        array = array.transpose(new int[]{2, 0, 1});
        return new NDList(new NDArray[]{array});
    }

    public Classifications processOutput(TranslatorContext ctx, NDList list) throws Exception {
        NDArray probabilitiesNd = list.singletonOrThrow();
        return new Classifications(this.classes, probabilitiesNd, 5);
    }

    protected static final class SynsetLoader {
        private String synsetFileName;
        private URL synsetUrl;
        private List<String> synset;

        public SynsetLoader(List<String> synset) {
            this.synset = synset;
        }

        public SynsetLoader(URL synsetUrl) {
            this.synsetUrl = synsetUrl;
        }

        public SynsetLoader(String synsetFileName) {
            this.synsetFileName = synsetFileName;
        }

        public List<String> load(Model model) throws IOException {
            if (this.synset != null) {
                return this.synset;
            }
            if (this.synsetUrl != null) {
                try (InputStream is = this.synsetUrl.openStream();){
                    List list = Utils.readLines((InputStream)is);
                    return list;
                }
            }
            return (List)model.getArtifact(this.synsetFileName, Utils::readLines);
        }
    }

    public static class Builder {
        protected float threshold = 0.2f;
        protected boolean applyRatio;
        protected boolean removePadding;
        protected int width = 224;
        protected int height = 224;
        protected Image.Flag flag;
        protected Pipeline pipeline;
        protected Batchifier batchifier;
        protected int topk = 5;
        protected SynsetLoader synsetLoader;

        public YoloClsTranslator build() {
            if (this.pipeline == null) {
                this.addTransform(array -> array.transpose(new int[]{2, 0, 1}).toType(DataType.FLOAT32, false).div((Number)255));
            }
            return new YoloClsTranslator(this);
        }

        protected Builder self() {
            return this;
        }

        public Builder addTransform(Transform transform) {
            if (this.pipeline == null) {
                this.pipeline = new Pipeline();
            }
            this.pipeline.add(transform);
            return this.self();
        }

        public Builder optApplyRatio(boolean value) {
            this.applyRatio = value;
            return this.self();
        }

        public Builder optFlag(Image.Flag flag) {
            this.flag = flag;
            return this.self();
        }

        public Builder setPipeline(Pipeline pipeline) {
            this.pipeline = pipeline;
            return this.self();
        }

        public Builder setImageSize(int width, int height) {
            this.width = width;
            this.height = height;
            return this.self();
        }

        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this.self();
        }

        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return this.self();
        }

        public Builder optTopk(int topk) {
            this.topk = topk;
            return this.self();
        }

        public Builder optSynsetArtifactName(String synsetArtifactName) {
            this.synsetLoader = new SynsetLoader(synsetArtifactName);
            return this.self();
        }

        public Builder optSynsetUrl(String synsetUrl) {
            try {
                this.synsetLoader = new SynsetLoader(new URL(synsetUrl));
            }
            catch (MalformedURLException e) {
                throw new IllegalArgumentException("Invalid synsetUrl: " + synsetUrl, e);
            }
            return this.self();
        }

        public Builder optSynset(List<String> synset) {
            this.synsetLoader = new SynsetLoader(synset);
            return this.self();
        }

        protected void configPostProcess(Map<String, ?> arguments) {
            if (ArgumentsUtil.booleanValue(arguments, (String)"optApplyRatio") || ArgumentsUtil.booleanValue(arguments, (String)"applyRatio")) {
                this.optApplyRatio(true);
            }
            this.threshold = ArgumentsUtil.floatValue(arguments, (String)"threshold", (float)0.2f);
            String centerFit = ArgumentsUtil.stringValue(arguments, (String)"centerFit", (String)"false");
            this.removePadding = "true".equals(centerFit);
            String type = ArgumentsUtil.stringValue(arguments, (String)"outputType", (String)"AUTO");
        }

        protected void configPreProcess(Map<String, ?> arguments) {
            String normalize;
            int shortEdge;
            int w;
            String pad;
            if (this.pipeline == null) {
                this.pipeline = new Pipeline();
            }
            this.width = ArgumentsUtil.intValue(arguments, (String)"width", (int)224);
            this.height = ArgumentsUtil.intValue(arguments, (String)"height", (int)224);
            if (arguments.containsKey("flag")) {
                this.flag = Image.Flag.valueOf((String)arguments.get("flag").toString());
            }
            if ("true".equals(pad = ArgumentsUtil.stringValue(arguments, (String)"pad", (String)"false"))) {
                this.addTransform((Transform)new Pad(0.0));
            } else if (!"false".equals(pad)) {
                double padding = Double.parseDouble(pad);
                this.addTransform((Transform)new Pad(padding));
            }
            String resize = ArgumentsUtil.stringValue(arguments, (String)"resize", (String)"false");
            if ("true".equals(resize)) {
                this.addTransform((Transform)new Resize(this.width, this.height));
            } else if (!"false".equals(resize)) {
                String[] tokens = resize.split("\\s*,\\s*");
                w = (int)Double.parseDouble(tokens[0]);
                shortEdge = tokens.length > 1 ? (int)Double.parseDouble(tokens[1]) : w;
                Image.Interpolation interpolation = tokens.length > 2 ? Image.Interpolation.valueOf((String)tokens[2]) : Image.Interpolation.BILINEAR;
                this.addTransform((Transform)new Resize(w, shortEdge, interpolation));
            }
            String resizeShort = ArgumentsUtil.stringValue(arguments, (String)"resizeShort", (String)"false");
            if ("true".equals(resizeShort)) {
                w = Math.max(this.width, this.height);
                this.addTransform((Transform)new ResizeShort(w));
            } else if (!"false".equals(resizeShort)) {
                String[] tokens = resizeShort.split("\\s*,\\s*");
                shortEdge = (int)Double.parseDouble(tokens[0]);
                int longEdge = tokens.length > 1 ? (int)Double.parseDouble(tokens[1]) : -1;
                Image.Interpolation interpolation = tokens.length > 2 ? Image.Interpolation.valueOf((String)tokens[2]) : Image.Interpolation.BILINEAR;
                this.addTransform((Transform)new ResizeShort(shortEdge, longEdge, interpolation));
            }
            if (ArgumentsUtil.booleanValue(arguments, (String)"centerCrop", (boolean)false)) {
                this.addTransform((Transform)new CenterCrop(this.width, this.height));
            }
            if (ArgumentsUtil.booleanValue(arguments, (String)"centerFit")) {
                this.addTransform((Transform)new CenterFit(this.width, this.height));
            }
            if (ArgumentsUtil.booleanValue(arguments, (String)"toTensor", (boolean)true)) {
                this.addTransform((Transform)new ToTensor());
            }
            if ("true".equals(normalize = ArgumentsUtil.stringValue(arguments, (String)"normalize", (String)"false"))) {
                float[] MEAN = new float[]{0.485f, 0.456f, 0.406f};
                float[] STD = new float[]{0.229f, 0.224f, 0.225f};
                this.addTransform((Transform)new Normalize(MEAN, STD));
            } else if (!"false".equals(normalize)) {
                String[] tokens = normalize.split("\\s*,\\s*");
                if (tokens.length != 6) {
                    throw new IllegalArgumentException("Invalid normalize value: " + normalize);
                }
                float[] mean = new float[]{Float.parseFloat(tokens[0]), Float.parseFloat(tokens[1]), Float.parseFloat(tokens[2])};
                float[] std = new float[]{Float.parseFloat(tokens[3]), Float.parseFloat(tokens[4]), Float.parseFloat(tokens[5])};
                this.addTransform((Transform)new Normalize(mean, std));
            }
            String range = (String)arguments.get("range");
            if ("0,1".equals(range)) {
                this.addTransform(a -> a.div((Number)Float.valueOf(255.0f)));
            } else if ("-1,1".equals(range)) {
                this.addTransform(a -> a.div((Number)Float.valueOf(128.0f)).sub((Number)1));
            }
            if (arguments.containsKey("batchifier")) {
                this.batchifier = Batchifier.fromString((String)((String)arguments.get("batchifier")));
            }
        }
    }
}

