package hex.genmodel.algos.deepwater.caffe;

import com.google.protobuf.nano.CodedInputByteBufferNano;
import com.google.protobuf.nano.CodedOutputByteBufferNano;
import deepwater.backends.BackendModel;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.ProcessBuilder;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/* loaded from: input_file:hex/genmodel/algos/deepwater/caffe/DeepwaterCaffeModel.class */
/**
 * {@link BackendModel} that delegates all work to an external backend process
 * ({@code python3 backend.py}, started in {@link DeepwaterCaffeBackend#CAFFE_H2O_DIR} with
 * {@code PYTHONPATH=/opt/caffe/python}).
 *
 * <p>Requests and responses are {@code Deepwater.Cmd} protobuf messages exchanged over the child
 * process's stdin/stdout. Each message is preceded by its serialized size as a 4-byte big-endian
 * int. Every request sent by {@link #call} is followed by reading exactly one response.
 *
 * <p>NOTE(review): {@link #call} writes a request and reads the reply without any locking, so
 * sharing one instance between threads would interleave messages on the pipe. Instances are
 * presumably confined to a single thread; confirm with callers.
 */
public class DeepwaterCaffeModel implements BackendModel {
    // Values written to Deepwater.Cmd.type; each name reflects the method in this class that sends it.
    private static final int CMD_CREATE = 0;
    private static final int CMD_TRAIN = 1;
    private static final int CMD_PREDICT = 2;
    private static final int CMD_SAVE_MODEL = 3;
    private static final int CMD_SAVE_PARAM = 4;
    private static final int CMD_LOAD_PARAM = 5;

    /** Bytes per {@code float} in the little-endian payloads exchanged with the backend. */
    private static final int FLOAT_BYTES = 4;
    /** Bytes in the big-endian length prefix that frames every message. */
    private static final int HEADER_BYTES = 4;

    private final int[] _input_shape;
    private final int[] _sizes;
    private final String[] _types;
    private final double[] _dropout_ratios;
    private final long _seed;
    private final boolean _useGPU;
    private final String _graph;
    /** The backend child process; set by {@link #startRegular()}. */
    private Process _process;

    /** Per-thread scratch buffer for float[] <-> little-endian byte[] conversion. */
    private static final ThreadLocal<ByteBuffer> _buffer = new ThreadLocal<>();

    /**
     * Builds a model from a layer description and starts the backend.
     *
     * @param batchSize first dimension of the input shape {@code {batchSize, 1, 1, sizes[0]}}
     *     (presumably the mini-batch size; confirm with callers)
     * @param sizes layer sizes; {@code sizes[0]} becomes the last input-shape dimension
     * @param types layer types, forwarded to the backend unchanged
     * @param dropoutRatios per-layer dropout ratios, forwarded to the backend unchanged
     * @param seed random seed forwarded to the backend
     * @param useGPU whether the backend is asked to use the GPU
     * @throws RuntimeException if the backend cannot be started or rejects the create command
     */
    public DeepwaterCaffeModel(int batchSize, int[] sizes, String[] types, double[] dropoutRatios,
                               long seed, boolean useGPU) {
        this._input_shape = new int[]{batchSize, 1, 1, sizes[0]};
        this._sizes = sizes;
        this._types = types;
        this._dropout_ratios = dropoutRatios;
        this._seed = seed;
        this._useGPU = useGPU;
        this._graph = "";
        start();
    }

    /**
     * Builds a model from a graph description and starts the backend.
     *
     * @param graph graph description, forwarded to the backend unchanged
     * @param inputShape input shape sent with every train/predict command
     * @param seed random seed forwarded to the backend
     * @param useGPU whether the backend is asked to use the GPU
     * @throws RuntimeException if the backend cannot be started or rejects the create command
     */
    public DeepwaterCaffeModel(String graph, int[] inputShape, long seed, boolean useGPU) {
        this._input_shape = inputShape;
        this._sizes = new int[0];
        this._types = new String[0];
        this._dropout_ratios = new double[0];
        this._seed = seed;
        this._useGPU = useGPU;
        this._graph = graph;
        start();
    }

    /**
     * Launches the backend process (if it is not already running) and sends the create command
     * built from this instance's configuration. If the create command fails, the newly started
     * process is destroyed so it does not outlive the failed constructor.
     */
    private void start() {
        if (_process != null) {
            return;
        }
        try {
            startRegular();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_CREATE;
        cmd.graph = _graph;
        cmd.inputShape = _input_shape;
        cmd.sizes = _sizes;
        cmd.types = _types;
        cmd.dropoutRatios = _dropout_ratios;
        // Fixed initial solver settings; not configurable through this class.
        cmd.solverType = "Adam";
        cmd.learningRate = 0.01f;
        cmd.momentum = 0.99f;
        cmd.randomSeed = _seed;
        cmd.useGpu = _useGPU;
        try {
            call(cmd);
        } catch (RuntimeException e) {
            // Without this the spawned process would leak: the constructor is aborting.
            _process.destroy();
            throw e;
        }
    }

    /** Asks the backend to save the model to {@code path}. */
    public void saveModel(String path) {
        sendPathCommand(CMD_SAVE_MODEL, path);
    }

    /** Asks the backend to save the parameters to {@code path}. */
    public void saveParam(String path) {
        sendPathCommand(CMD_SAVE_PARAM, path);
    }

    /** Asks the backend to load parameters from {@code path}. */
    public void loadParam(String path) {
        sendPathCommand(CMD_LOAD_PARAM, path);
    }

    /** Sends a command that carries only a type and a path, and waits for its reply. */
    private void sendPathCommand(int type, String path) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = type;
        cmd.path = path;
        call(cmd);
    }

    /**
     * Returns this thread's cached little-endian direct buffer, cleared and holding at least
     * {@code size} bytes. A larger buffer is allocated and cached if the current one is too small.
     */
    private static ByteBuffer localBuffer(int size) {
        ByteBuffer buf = _buffer.get();
        if (buf == null || buf.capacity() < size) {
            buf = ByteBuffer.allocateDirect(size).order(ByteOrder.LITTLE_ENDIAN);
            _buffer.set(buf);
        }
        buf.clear();
        return buf;
    }

    /**
     * Encodes {@code src} as little-endian floats into {@code dst}.
     *
     * @throws IllegalArgumentException if {@code dst} is not exactly {@code 4 * src.length} bytes
     */
    private static void copy(float[] src, byte[] dst) {
        if (src.length * FLOAT_BYTES != dst.length) {
            throw new IllegalArgumentException(
                    "Expected " + src.length * FLOAT_BYTES + " bytes, got " + dst.length);
        }
        ByteBuffer buf = localBuffer(dst.length);
        // The float view writes through to buf without moving buf's position (still 0).
        buf.asFloatBuffer().put(src);
        buf.get(dst);
    }

    /** Encodes each array of {@code arrays} into the corresponding entry of {@code cmd.data}. */
    private static void copy(float[][] arrays, Deepwater.Cmd cmd) {
        cmd.data = new byte[arrays.length][];
        for (int i = 0; i < arrays.length; i++) {
            cmd.data[i] = new byte[arrays[i].length * FLOAT_BYTES];
            copy(arrays[i], cmd.data[i]);
        }
    }

    /**
     * Sends one training step to the backend.
     *
     * @param data flattened input; must hold the product of all four input-shape dimensions
     * @param label one value per entry of the first input-shape dimension (presumably the labels)
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public void train(float[] data, float[] label) {
        int expected = _input_shape[0] * _input_shape[1] * _input_shape[2] * _input_shape[3];
        if (data.length != expected) {
            throw new IllegalArgumentException(
                    "Input has " + data.length + " values, expected " + expected);
        }
        if (label.length != _input_shape[0]) {
            throw new IllegalArgumentException(
                    "Label has " + label.length + " values, expected " + _input_shape[0]);
        }
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_TRAIN;
        cmd.inputShape = _input_shape;
        copy(new float[][]{data, label}, cmd);
        call(cmd);
    }

    /**
     * Sends {@code data} to the backend for prediction.
     *
     * @return the first data entry of the backend's reply, decoded as little-endian floats
     * @throws IllegalStateException if the reply carries no data
     */
    public float[] predict(float[] data) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_PREDICT;
        cmd.inputShape = _input_shape;
        copy(new float[][]{data}, cmd);
        Deepwater.Cmd reply = call(cmd);
        if (reply.data == null || reply.data.length == 0) {
            throw new IllegalStateException("Backend returned no prediction data");
        }
        byte[] raw = reply.data[0];
        ByteBuffer buf = localBuffer(raw.length);
        buf.put(raw);
        buf.flip();
        float[] result = new float[raw.length / FLOAT_BYTES];
        buf.asFloatBuffer().get(result);
        return result;
    }

    /**
     * Spawns the backend process. Its stderr is inherited from this JVM; stdin/stdout are
     * piped and used by {@link #call}.
     */
    private void startRegular() throws IOException {
        ProcessBuilder builder = new ProcessBuilder("python3", "backend.py");
        builder.environment().put("PYTHONPATH", "/opt/caffe/python");
        builder.redirectError(ProcessBuilder.Redirect.INHERIT);
        builder.directory(new File(DeepwaterCaffeBackend.CAFFE_H2O_DIR));
        _process = builder.start();
    }

    /**
     * Destroys the backend process and waits for it to exit. If the wait is interrupted, the
     * thread's interrupt flag is restored and the method returns without waiting further.
     */
    public void close() {
        _process.destroy();
        try {
            _process.waitFor();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Sends {@code cmd} to the backend as a length-prefixed message and reads one
     * length-prefixed reply.
     *
     * @return the decoded reply
     * @throws RuntimeException wrapping any I/O failure, including the backend closing its
     *     output before a complete reply was read
     */
    private Deepwater.Cmd call(Deepwater.Cmd cmd) {
        try {
            OutputStream out = _process.getOutputStream();
            int size = cmd.getSerializedSize();
            // Heap buffer: default big-endian order for the length prefix.
            ByteBuffer buf = ByteBuffer.allocate(HEADER_BYTES + size);
            buf.putInt(size);
            cmd.writeTo(CodedOutputByteBufferNano.newInstance(buf.array(), buf.position(), buf.remaining()));
            out.write(buf.array(), 0, HEADER_BYTES + size);
            out.flush();

            InputStream in = _process.getInputStream();
            readFully(in, buf.array(), HEADER_BYTES);
            int replySize = buf.getInt(0);
            if (replySize < 0) {
                throw new IOException("Backend sent negative message size " + replySize);
            }
            // Reuse the request buffer when it is large enough for the reply.
            byte[] reply = buf.capacity() >= replySize ? buf.array() : new byte[replySize];
            readFully(in, reply, replySize);

            Deepwater.Cmd result = new Deepwater.Cmd();
            // NOTE(review): m8mergeFrom looks like a decompiler-renamed mergeFrom; kept to match
            // the Deepwater.Cmd class in this code base.
            result.m8mergeFrom(CodedInputByteBufferNano.newInstance(reply, 0, replySize));
            return result;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads exactly {@code len} bytes into {@code dst[0..len)}, looping over short reads.
     *
     * @throws IOException if the stream ends before {@code len} bytes were read
     */
    private static void readFully(InputStream in, byte[] dst, int len) throws IOException {
        int off = 0;
        while (off < len) {
            int n = in.read(dst, off, len - off);
            if (n < 0) {
                throw new IOException(
                        "Backend closed its output after " + off + " of " + len + " bytes");
            }
            off += n;
        }
    }
}
