/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facedect.mtcnn;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import cn.smartjavaai.common.utils.NMSUtils;
import cn.smartjavaai.face.model.facedect.mtcnn.MtcnnProcess;
import cn.smartjavaai.face.model.facedect.mtcnn.MtcnnUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * First (P-Net) stage of the MTCNN face detector: builds an image pyramid, runs the
 * proposal network on every pyramid level and turns its outputs into candidate face boxes.
 */
public class PNetModel {
    /** Minimum face size (pixels) used by {@link #generateScales(Image)}. */
    private static final double DEFAULT_MIN_FACE_SIZE = 20.0;
    /** Ratio between consecutive pyramid scales used by {@link #generateScales(Image)}. */
    private static final double DEFAULT_SCALE_FACTOR = 0.709;
    /** Side length (pixels) that the smallest pyramid level must still cover. */
    private static final double PNET_WINDOW = 12.0;

    /** Face-probability threshold passed to {@link #generateBoundingBox}. */
    private static final float PROB_THRESHOLD = 0.6f;
    /** IoU threshold for the NMS run inside each pyramid level. */
    private static final float PER_SCALE_NMS_THRESHOLD = 0.5f;
    /** IoU threshold for the NMS run across all pyramid levels. */
    private static final float CROSS_SCALE_NMS_THRESHOLD = 0.7f;

    /**
     * Computes the image-pyramid scales using the default minimum face size (20 px)
     * and scale factor (0.709).
     *
     * @param image the input image
     * @return the scales, largest first; empty if the image is too small
     */
    public static List<Double> generateScales(Image image) {
        return generateScales(image, DEFAULT_MIN_FACE_SIZE, DEFAULT_SCALE_FACTOR);
    }

    /**
     * Computes the image-pyramid scales. The first scale maps a face of {@code minFaceSize}
     * pixels onto the 12-pixel window. Each following scale is multiplied by {@code factor}.
     * Scales are generated while the shorter image side, rescaled, stays at least 12 pixels.
     *
     * @param image the input image
     * @param minFaceSize smallest face size (pixels) to detect; must be positive
     * @param factor ratio between consecutive scales; must be in (0, 1)
     * @return the scales, largest first; empty if the image is too small
     * @throws IllegalArgumentException if {@code minFaceSize} or {@code factor} is out of range
     */
    public static List<Double> generateScales(Image image, double minFaceSize, double factor) {
        if (!(minFaceSize > 0.0)) {
            throw new IllegalArgumentException("minFaceSize must be positive: " + minFaceSize);
        }
        // factor >= 1 would never shrink minl below the window and loop forever
        if (!(factor > 0.0 && factor < 1.0)) {
            throw new IllegalArgumentException("factor must be in (0, 1): " + factor);
        }
        double m = PNET_WINDOW / minFaceSize;
        List<Double> scales = new ArrayList<>();
        double scale = m;
        for (double minl = Math.min(image.getHeight(), image.getWidth()) * m;
                minl >= PNET_WINDOW;
                minl *= factor) {
            scales.add(scale);
            scale *= factor;
        }
        return scales;
    }

    /**
     * Converts an image into a batched, channel-first FLOAT32 tensor.
     * DJL's {@code Image.toNDArray} yields (H, W, C), so the result is presumably (1, C, H, W).
     *
     * @param input the image to convert
     * @param manager manager that owns the created arrays
     * @return the image as a (1, C, H, W) FLOAT32 array
     */
    public static NDArray pNetPre(Image input, NDManager manager) {
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR)
                .expandDims(0)
                .transpose(0, 3, 1, 2);
        return array.getDataType() == DataType.FLOAT32 ? array : array.toType(DataType.FLOAT32, false);
    }

    /**
     * Runs the full P-Net stage. For each pyramid scale it resizes the image, normalizes it
     * with {@code (x - 127.5) * 0.0078125} and runs the network. It then extracts candidate
     * boxes and runs NMS within that scale. Afterwards a second NMS runs across all scales,
     * the regression offsets are applied, and the boxes are passed to
     * {@code MtcnnUtils.rerec}.
     *
     * <p>Scales come from {@code MtcnnProcess.generateScales}, not from this class's
     * {@link #generateScales(Image)}.
     *
     * <p>NOTE(review): if the scale list is empty (very small image), the concatenations
     * below receive empty lists; the resulting behavior depends on DJL. Consider guarding.
     *
     * @param manager manager that owns the created arrays
     * @param pnetPredictor predictor whose output is [regression, probabilities]
     * @param image the input image
     * @return {@code [boxes, imageInds, imgs]}. {@code boxes} rows are
     *     {@code [x1, y1, x2, y2, score]} after {@code rerec}. {@code imageInds} holds each
     *     box's batch index. {@code imgs} is the tensor from {@link #pNetPre}.
     * @throws TranslateException if the predictor fails
     */
    public static NDList firstStage(NDManager manager, Predictor<NDList, NDList> pnetPredictor, Image image)
            throws TranslateException {
        List<Double> scales = MtcnnProcess.generateScales(image);
        NDArray imgs = pNetPre(image, manager);
        int h = image.getHeight();
        int w = image.getWidth();
        // NDImageUtils.resize is applied to channel-last data; this view is the same for
        // every scale, so compute it once instead of once per loop iteration.
        NDArray imgsChannelLast = imgs.transpose(0, 2, 3, 1);

        NDList boxesList = new NDList();
        NDList imageIndsList = new NDList();
        NDList scalePicksList = new NDList();
        int offset = 0;
        for (double scale : scales) {
            int newH = (int) (h * scale + 1.0);
            int newW = (int) (w * scale + 1.0);
            NDArray scaled = NDImageUtils.resize(imgsChannelLast, newW, newH, Image.Interpolation.AREA)
                    .transpose(0, 3, 1, 2)
                    .sub(127.5)
                    .mul(0.0078125f);
            NDList outputPnet = pnetPredictor.predict(new NDList(scaled));
            NDArray reg = outputPnet.get(0);
            NDArray probs = outputPnet.get(1);
            // column 1 of probs is used as the face score
            List<NDArray> boundingBox =
                    generateBoundingBox(reg, probs.get(":, 1"), (float) scale, PROB_THRESHOLD);
            NDArray boxesScale = boundingBox.get(0);
            NDArray imgInd = boundingBox.get(1);
            NDArray pick = NMSUtils.batchedNms(
                    boxesScale.get(":,:4"), boxesScale.get(":,4"), imgInd, PER_SCALE_NMS_THRESHOLD, manager);
            boxesList.add(boxesScale);
            imageIndsList.add(imgInd);
            // pick indexes rows of this scale only; shift it into the row space of the
            // concatenated array built after the loop
            scalePicksList.add(pick.add(offset));
            offset = (int) (offset + boxesScale.getShape().get(0));
        }

        NDArray scalePicks = NDArrays.concat(scalePicksList, 0);
        NDArray boxes = NDArrays.concat(boxesList, 0).get(scalePicks);
        NDArray imageInds = NDArrays.concat(imageIndsList, 0).get(scalePicks);
        NDArray pick = NMSUtils.batchedNms(
                boxes.get(":, :4"), boxes.get(":, 4"), imageInds, CROSS_SCALE_NMS_THRESHOLD, manager);
        boxes = boxes.get(pick);
        imageInds = imageInds.get(pick);
        boxes = MtcnnUtils.rerec(applyRegression(boxes));
        return new NDList(boxes, imageInds, imgs);
    }

    /**
     * Applies the regression offsets in columns 5..8 to the box corners.
     * Each offset is scaled by the box width (x) or height (y).
     *
     * @param boxes rows of {@code [x1, y1, x2, y2, score, r0, r1, r2, r3]}
     * @return rows of {@code [x1', y1', x2', y2', score]}
     */
    private static NDArray applyRegression(NDArray boxes) {
        NDArray regw = boxes.get(":, 2").sub(boxes.get(":, 0"));
        NDArray regh = boxes.get(":, 3").sub(boxes.get(":, 1"));
        NDArray qq1 = boxes.get(":, 0").add(boxes.get(":, 5").mul(regw));
        NDArray qq2 = boxes.get(":, 1").add(boxes.get(":, 6").mul(regh));
        NDArray qq3 = boxes.get(":, 2").add(boxes.get(":, 7").mul(regw));
        NDArray qq4 = boxes.get(":, 3").add(boxes.get(":, 8").mul(regh));
        return NDArrays.stack(new NDList(qq1, qq2, qq3, qq4, boxes.get(":, 4")), 1);
    }

    /**
     * Converts P-Net output for one pyramid scale into candidate boxes in original-image
     * coordinates. Every cell with {@code probs >= threshold} yields one box.
     *
     * @param reg regression output; NOTE(review): assumed (N, 4, H, W) to match {@code probs}
     *     (the reshape to (4, -1) below requires 4 channels)
     * @param probs face probabilities with shape (N, H, W)
     * @param scale pyramid scale the network input was resized by
     * @param threshold minimum probability for a cell to produce a box
     * @return {@code [boxes, imageInds]}. {@code boxes} rows are
     *     {@code [x1, y1, x2, y2, score, r0, r1, r2, r3]}; {@code imageInds} holds each box's
     *     batch index.
     */
    public static List<NDArray> generateBoundingBox(NDArray reg, NDArray probs, float scale, float threshold) {
        float stride = 2.0f;
        float cellSize = 12.0f;
        // rows of (batch, y, x) for every selected cell
        NDArray maskInds = probs.gte(threshold).nonzero();
        NDArray imageInds = maskInds.get(":,0");
        // flip (y, x) -> (x, y)
        NDArray bb = maskInds.get(":,1:").flip(1);
        NDArray q1 = bb.mul(stride).add(1).div(scale).floor();
        NDArray q2 = bb.mul(stride).add(cellSize).div(scale).floor();

        Shape probShape = probs.getShape();
        long height = probShape.get(1);
        long width = probShape.get(2);
        // row-major flat index of each selected (batch, y, x) cell
        NDArray linearIndex = imageInds.mul(height * width)
                .add(maskInds.get(":,1").mul(width))
                .add(maskInds.get(":,2"));
        NDArray scores = probs.reshape(-1L).gather(linearIndex, 0);

        // move the regression channel first and flatten the rest so the same flat index applies
        NDArray regFlat = reg.transpose(1, 0, 2, 3).reshape(4L, -1L);
        NDArray regPicked = regFlat.gather(linearIndex.expandDims(0).repeat(0, 4L), 1).transpose();

        NDArray x1 = q1.get(":, 0").expandDims(1);
        NDArray y1 = q1.get(":, 1").expandDims(1);
        NDArray x2 = q2.get(":, 0").expandDims(1);
        NDArray y2 = q2.get(":, 1").expandDims(1);
        NDArray boundingBoxes =
                NDArrays.concat(new NDList(x1, y1, x2, y2, scores.expandDims(1), regPicked), 1);
        return Arrays.asList(boundingBoxes, imageInds);
    }
}

