package ai.kognition.pilecv4j.tf;

import ai.kognition.pilecv4j.image.CvMat;
import ai.kognition.pilecv4j.image.CvRaster;
import com.google.protobuf.InvalidProtocolBufferException;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.stream.LongStream;
import net.dempsy.util.QuietCloseable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:ai/kognition/pilecv4j/tf/TensorUtils.class */
/**
 * Static helpers for moving data between pilecv4j image types and TensorFlow
 * tensors, and for copying float tensor contents into plain Java arrays.
 */
public class TensorUtils {
    public static final Logger LOGGER = LoggerFactory.getLogger(TensorUtils.class);

    /**
     * Builds a tensor of shape {@code [1, rows, cols, channels]} from the bytes
     * backing the given raster.
     *
     * <p>The raster's underlying buffer is rewound before the tensor is created
     * and rewound again afterwards (also when tensor creation fails).
     *
     * @param cvRaster raster whose underlying byte buffer supplies the tensor data
     * @param cls      TensorFlow element type of the tensor to create
     * @return the newly created tensor
     */
    public static Tensor toTensor(CvRaster cvRaster, Class<? extends TType> cls) {
        final Shape shape = Shape.of(new long[]{1, cvRaster.rows(), cvRaster.cols(), cvRaster.channels()});
        final ByteBuffer underlying = cvRaster.underlying();
        underlying.rewind();
        // The resource exists only to rewind the buffer again on the way out,
        // whether or not Tensor.of succeeds.
        try (QuietCloseable rewinder = () -> underlying.rewind()) {
            return Tensor.of(cls, shape, DataBuffers.of(underlying));
        }
    }

    /**
     * Builds a tensor from the given mat by running
     * {@link #toTensor(CvRaster, Class)} inside {@code cvMat.rasterOp}.
     *
     * @param cvMat image to convert
     * @param cls   TensorFlow element type of the tensor to create
     * @return the newly created tensor
     */
    public static Tensor toTensor(CvMat cvMat, Class<? extends TType> cls) {
        return (Tensor) cvMat.rasterOp(cvRaster -> toTensor(cvRaster, cls));
    }

    /**
     * Parses a serialized {@code GraphDef} and imports it into a new {@link Graph}.
     *
     * <p>The bytes are parsed before the graph is allocated, and the graph is
     * closed again if the import fails, so a failure does not leave an
     * unreferenced open graph behind. On success the caller owns the returned
     * graph and is responsible for closing it.
     *
     * @param bArr serialized {@code GraphDef} protobuf
     * @return a new graph containing the imported definition
     * @throws InvalidProtocolBufferException if {@code bArr} is not a valid {@code GraphDef}
     */
    public static Graph inflate(byte[] bArr) throws InvalidProtocolBufferException {
        final GraphDef graphDef = GraphDef.parseFrom(bArr);
        final Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
            return graph;
        } catch (RuntimeException | Error e) {
            graph.close();
            throw e;
        }
    }

    /**
     * Reads the single float value of a scalar tensor.
     *
     * @param tensor tensor to read; it is cast to {@link TFloat32}
     * @return the value read with no coordinates
     * @throws ClassCastException if {@code tensor} is not a {@link TFloat32}
     */
    public static float getScalar(Tensor tensor) {
        return ((TFloat32) tensor).getFloat(new long[0]);
    }

    /**
     * Copies the first row of a float tensor into a new array.
     *
     * <p>The length is taken from the second dimension of the tensor's shape and
     * the elements are read at coordinates {@code [0, j]}.
     * NOTE(review): this presumably expects a tensor of shape {@code [1, n]} - confirm with callers.
     *
     * @param tensor tensor to read; it is cast to {@link TFloat32}
     * @return a new array holding the elements {@code [0, 0] .. [0, n-1]}
     * @throws ClassCastException if {@code tensor} is not a {@link TFloat32}
     */
    public static float[] getVector(Tensor tensor) {
        final TFloat32 floats = (TFloat32) tensor;
        final int length = (int) tensor.shape().asArray()[1];
        final float[][] result = new float[1][length];
        for (int row = 0; row < result.length; row++) {
            for (int col = 0; col < length; col++) {
                result[row][col] = floats.getFloat(row, col);
            }
        }
        return result[0];
    }

    /**
     * Copies a three-dimensional float tensor into Java arrays and returns the
     * first two-dimensional slice.
     *
     * <p>An array with the tensor's full shape is allocated and every element is
     * copied, but only index 0 of the leading dimension is returned.
     * NOTE(review): this presumably expects a leading batch dimension of 1 - confirm with callers.
     *
     * @param tensor tensor to read; it is cast to {@link TFloat32}
     * @return the slice at index 0 of the first dimension
     * @throws ClassCastException if {@code tensor} is not a {@link TFloat32} or its
     *                            shape does not have exactly three dimensions
     */
    public static float[][] getMatrix(Tensor tensor) {
        final TFloat32 floats = (TFloat32) tensor;
        final int[] dims = LongStream.of(tensor.shape().asArray()).mapToInt(d -> (int) d).toArray();
        // Array.newInstance builds an array of rank dims.length; the cast below
        // therefore only succeeds for a three-dimensional shape.
        final float[][][] result = (float[][][]) Array.newInstance(Float.TYPE, dims);
        for (int i = 0; i < dims[0]; i++) {
            for (int j = 0; j < dims[1]; j++) {
                for (int k = 0; k < dims[2]; k++) {
                    result[i][j][k] = floats.getFloat(i, j, k);
                }
            }
        }
        return result[0];
    }
}
