/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facedect.mtcnn;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import cn.smartjavaai.common.utils.NMSUtils;
import cn.smartjavaai.face.model.facedect.mtcnn.MtcnnBatchResult;
import cn.smartjavaai.face.model.facedect.mtcnn.MtcnnUtils;
import java.util.ArrayList;
import java.util.List;

public class ONetModel {
    public static MtcnnBatchResult thirdStage(NDManager manager, Predictor<NDList, NDList> onetPredictor, NDArray imgs, NDArray boxes, int w, int h, NDArray scoresFiltered, NDArray image_indsFiltered) throws TranslateException {
        NDArray points = manager.zeros(new Shape(new long[]{0L, 5L, 2L}));
        NDList pad = MtcnnUtils.pad(boxes, w, h);
        NDArray y = (NDArray)pad.get(0);
        NDArray ey = (NDArray)pad.get(1);
        NDArray x = (NDArray)pad.get(2);
        NDArray ex = (NDArray)pad.get(3);
        ArrayList<NDArray> crops = new ArrayList<NDArray>();
        long numFaces = y.size(0);
        for (long k = 0L; k < numFaces; ++k) {
            if (ey.getInt(new long[]{k}) <= y.getInt(new long[]{k}) - 1 || ex.getInt(new long[]{k}) <= x.getInt(new long[]{k}) - 1) continue;
            NDArray imgK = imgs.get(image_indsFiltered.getLong(new long[]{k}) + ", :, " + (y.getInt(new long[]{k}) - 1) + ":" + ey.getInt(new long[]{k}) + ", " + (x.getInt(new long[]{k}) - 1) + ":" + ex.getInt(new long[]{k}), new Object[0]).expandDims(0);
            NDArray transposed = imgK.transpose(new int[]{0, 2, 3, 1});
            transposed = NDImageUtils.resize((NDArray)transposed, (int)48, (int)48, (Image.Interpolation)Image.Interpolation.AREA);
            transposed = transposed.transpose(new int[]{0, 3, 1, 2});
            crops.add(transposed);
        }
        NDArray im_data = NDArrays.concat((NDList)new NDList(crops), (int)0);
        im_data = im_data.sub((Number)127.5).mul((Number)0.0078125);
        NDList out = (NDList)onetPredictor.predict((Object)new NDList(new NDArray[]{im_data}));
        NDArray out0 = ((NDArray)out.get(0)).transpose(new int[]{1, 0});
        NDArray out1 = ((NDArray)out.get(1)).transpose(new int[]{1, 0});
        NDArray out2 = ((NDArray)out.get(2)).transpose(new int[]{1, 0});
        NDArray score = out2.get(new long[]{1L});
        points = out1.duplicate();
        NDArray ipass = score.gt((Number)0.7);
        NDArray ipassBool = ipass.toType(DataType.BOOLEAN, false);
        long[] colIdx = ipassBool.nonzero().toLongArray();
        NDArray moved = points.swapAxes(0, 1);
        NDArray selected = moved.get(points.getManager().create(colIdx));
        points = selected.swapAxes(0, 1);
        long[] validIndices = ipass.nonzero().toLongArray();
        NDArray boxesSelected = boxes.get(manager.create(validIndices));
        boxesSelected = boxesSelected.get(":, 0:4", new Object[0]);
        scoresFiltered = score.get(ipass).reshape(new long[]{-1L, 1L});
        boxes = NDArrays.concat((NDList)new NDList(new NDArray[]{boxesSelected, scoresFiltered}), (int)1);
        image_indsFiltered = image_indsFiltered.get(ipass);
        NDArray mv = out0.transpose().get(ipass);
        NDArray w_i = boxes.get(":,2", new Object[0]).sub(boxes.get(":,0", new Object[0])).add((Number)1);
        NDArray h_i = boxes.get(":,3", new Object[0]).sub(boxes.get(":,1", new Object[0])).add((Number)1);
        NDArray w_repeat = w_i.expandDims(0).repeat(0, 5L);
        NDArray p_x = points.get("0:5,:", new Object[0]).mul(w_repeat).add(boxes.get(":,0", new Object[0]).expandDims(0).repeat(0, 5L)).sub((Number)1);
        NDArray h_repeat = h_i.expandDims(0).repeat(0, 5L);
        NDArray p_y = points.get("5:10,:", new Object[0]).mul(h_repeat).add(boxes.get(":,1", new Object[0]).expandDims(0).repeat(0, 5L)).sub((Number)1);
        NDArray pointsStacked = NDArrays.stack((NDList)new NDList(new NDArray[]{p_x, p_y}));
        points = pointsStacked.transpose(new int[]{2, 1, 0});
        boxes = MtcnnUtils.bbreg(boxes, mv);
        NDArray pick = NMSUtils.batchedNms((NDArray)boxes.get(":, :4", new Object[0]), (NDArray)boxes.get(":, 4", new Object[0]), (NDArray)image_indsFiltered, (float)0.7f, (NDManager)manager);
        boxes = boxes.get(pick);
        image_indsFiltered = image_indsFiltered.get(pick);
        points = points.get(pick);
        ArrayList<NDArray> batchBoxes = new ArrayList<NDArray>();
        ArrayList<NDArray> batchPoints = new ArrayList<NDArray>();
        for (int b_i = 0; b_i < 1; ++b_i) {
            NDArray mask = image_indsFiltered.eq((Number)b_i);
            NDArray batchBox = boxes.get(mask);
            NDArray batchPoint = points.get(mask);
            batchBoxes.add(batchBox);
            batchPoints.add(batchPoint);
        }
        return ONetModel.processBatchBoxes(batchBoxes, batchPoints, true, manager);
    }

    public static MtcnnBatchResult processBatchBoxes(List<NDArray> batchBoxes, List<NDArray> batchPoints, boolean selectLargest, NDManager manager) {
        ArrayList<NDArray> boxesOut = new ArrayList<NDArray>();
        ArrayList<NDArray> probsOut = new ArrayList<NDArray>();
        ArrayList<NDArray> pointsOut = new ArrayList<NDArray>();
        for (int i = 0; i < batchBoxes.size(); ++i) {
            NDArray pointsSelected;
            NDArray boxesSelected;
            NDArray box = batchBoxes.get(i);
            NDArray point = batchPoints.get(i);
            if (box == null || box.isEmpty()) {
                boxesOut.add(null);
                probsOut.add(null);
                pointsOut.add(null);
                continue;
            }
            if (selectLargest) {
                NDArray w = box.get(":,2", new Object[0]).sub(box.get(":,0", new Object[0]));
                NDArray h = box.get(":,3", new Object[0]).sub(box.get(":,1", new Object[0]));
                NDArray areas = w.mul(h);
                NDArray order = areas.argSort().flip(new int[]{0});
                boxesSelected = box.get(order);
                pointsSelected = point.get(order);
            } else {
                boxesSelected = box;
                pointsSelected = point;
            }
            boxesSelected = boxesSelected.get(":,0:4", new Object[0]);
            NDArray probsSelected = box.get(":,4", new Object[0]);
            boxesOut.add(boxesSelected);
            probsOut.add(probsSelected);
            pointsOut.add(pointsSelected);
        }
        MtcnnBatchResult result = new MtcnnBatchResult(boxesOut, probsOut, pointsOut);
        result.boxes = boxesOut;
        result.probs = probsOut;
        result.points = pointsOut;
        return result;
    }
}

