/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facedect.mtcnn;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslateException;
import cn.smartjavaai.common.utils.NMSUtils;
import cn.smartjavaai.face.model.facedect.mtcnn.MtcnnUtils;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Second stage of the MTCNN face-detection cascade: crops every candidate box out of the
 * input batch, scores the crops with the refinement network and keeps the candidates that
 * pass the score threshold and per-image non-maximum suppression.
 */
public class RNetModel {
    private static final Logger log = LoggerFactory.getLogger(RNetModel.class);

    /** Side length (pixels) each candidate crop is resized to before being fed to the network. */
    private static final int INPUT_SIZE = 24;
    /** Offset subtracted from every pixel value before scaling. */
    private static final double PIXEL_OFFSET = 127.5;
    /** Factor applied after the offset (1 / 128). */
    private static final double PIXEL_SCALE = 0.0078125;
    /** Minimum network score a candidate needs to be kept. */
    private static final double SCORE_THRESHOLD = 0.7;
    /** Threshold handed to {@code NMSUtils.batchedNms}. */
    private static final float NMS_THRESHOLD = 0.7f;

    /**
     * Runs the refinement network over the candidate boxes.
     *
     * @param manager       manager used to create index arrays and passed on to NMS
     * @param rnetPredictor predictor wrapping the refinement network; its first output is used
     *                      as the box regression and column 1 of its second output as the score
     * @param imgs          image batch, indexed here as {@code [image, :, rows, columns]}
     * @param boxes         candidate boxes, one row per candidate; only columns 0..3 are read
     * @param pad           four 1-D int arrays in the order {@code y, ey, x, ex}; candidate
     *                      {@code k} is cropped from rows {@code y[k]-1 .. ey[k]} and columns
     *                      {@code x[k]-1 .. ex[k]} (end exclusive)
     * @param image_inds    for each candidate, the index of the image in {@code imgs} it
     *                      belongs to (read with {@code getLong})
     * @return an {@code NDList} of {@code [image indices, scores, boxes]} whose rows all refer
     *         to the same surviving candidates, or {@code null} when no candidate survives
     * @throws TranslateException if the predictor fails
     */
    public static NDList secondStage(NDManager manager, Predictor<NDList, NDList> rnetPredictor, NDArray imgs, NDArray boxes, NDList pad, NDArray image_inds) throws TranslateException {
        NDArray y = pad.get(0);
        NDArray ey = pad.get(1);
        NDArray x = pad.get(2);
        NDArray ex = pad.get(3);

        ArrayList<NDArray> crops = new ArrayList<>();
        // Row numbers of the candidates that actually produced a crop. Needed so that
        // boxes / image_inds can be kept in step with the network output below.
        ArrayList<Long> keptRows = new ArrayList<>();
        long numFaces = y.size(0);
        for (long k = 0L; k < numFaces; ++k) {
            // Read each coordinate once instead of re-fetching it for every use.
            int top = y.getInt(k) - 1;
            int bottom = ey.getInt(k);
            int left = x.getInt(k) - 1;
            int right = ex.getInt(k);
            if (bottom <= top || right <= left) {
                // Empty crop window: nothing to feed to the network for this candidate.
                continue;
            }
            NDArray crop = imgs.get(image_inds.getLong(k) + ", :, " + top + ":" + bottom + ", " + left + ":" + right).expandDims(0);
            // Move axis 1 to the end for the resize call, then restore the original layout.
            NDArray resized = NDImageUtils.resize(crop.transpose(0, 2, 3, 1), INPUT_SIZE, INPUT_SIZE, Image.Interpolation.AREA);
            crops.add(resized.transpose(0, 3, 1, 2));
            keptRows.add(k);
        }
        if (crops.isEmpty()) {
            log.debug("No face detected.");
            return null;
        }
        if (keptRows.size() < numFaces) {
            // Some candidates were skipped: drop the same rows from boxes and image_inds,
            // otherwise row i of the network output would no longer describe row i of them.
            NDArray keptIndex = manager.create(keptRows.stream().mapToLong(Long::longValue).toArray());
            boxes = boxes.get(keptIndex);
            image_inds = image_inds.get(keptIndex);
        }

        NDArray im_data = NDArrays.concat(new NDList(crops), 0);
        im_data = im_data.sub(PIXEL_OFFSET).mul(PIXEL_SCALE);
        NDList out = rnetPredictor.predict(new NDList(im_data));

        NDArray regression = out.get(0);
        // Row 1 of the transposed second output, i.e. column 1 of that output per candidate.
        NDArray score = out.get(1).transpose(1, 0).get(1L);
        NDArray ipass = score.gt(SCORE_THRESHOLD);
        long[] validIndices = ipass.nonzero().toLongArray();
        if (validIndices.length == 0) {
            // Nothing passed the threshold; avoid pushing empty arrays through concat / NMS.
            log.debug("No face detected.");
            return null;
        }

        NDArray boxesSelected = boxes.get(manager.create(validIndices)).get(":, 0:4");
        NDArray scoresFiltered = score.get(ipass).reshape(-1L, 1L);
        // Boxes become [x1, y1, x2, y2, score]-style rows: four coordinates plus the score column.
        boxes = NDArrays.concat(new NDList(boxesSelected, scoresFiltered), 1);
        NDArray image_indsFiltered = image_inds.get(ipass);
        NDArray mv = regression.get(ipass);

        NDArray pick = NMSUtils.batchedNms(boxes.get(":, :4"), boxes.get(":, 4"), image_indsFiltered, NMS_THRESHOLD, manager);
        boxes = boxes.get(pick);
        image_indsFiltered = image_indsFiltered.get(pick);
        // Keep the returned scores aligned with the boxes that survived NMS.
        scoresFiltered = scoresFiltered.get(pick);
        mv = mv.get(pick);

        boxes = MtcnnUtils.bbreg(boxes, mv);
        boxes = MtcnnUtils.rerec(boxes);
        if (boxes.size(0) == 0L) {
            log.debug("No face detected.");
            return null;
        }
        return new NDList(image_indsFiltered, scoresFiltered, boxes);
    }
}

