package hex.deepwater;

import hex.deepwater.DeepWaterParameters;
import hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeModel;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Integration tests that run the shared DeepWater test suite against the Caffe backend.
 *
 * <p>All tests in this class are skipped (via {@link Assume}) when
 * {@link DeepWater#haveBackend} reports that the Caffe backend is not available.
 */
public class DeepWaterCaffeIntegrationTest extends DeepWaterAbstractIntegrationTest {
    /** Number of bytes per image read from the images file (28 x 28). */
    private static final int IMAGE_PIXELS = 784;

    /** Samples per training call; also passed as the first model-constructor argument. */
    private static final int BATCH_SIZE = 256;

    /** Number of train/predict rounds performed by {@link #run()}. */
    private static final int ITERATIONS = 10;

    /** Scales an unsigned pixel byte (0..255) into [0, 1); exactly 1/256 = 0.00390625f. */
    private static final float PIXEL_SCALE = 1.0f / 256;

    @Override // hex.deepwater.DeepWaterAbstractIntegrationTest
    DeepWaterParameters.Backend getBackend() {
        return DeepWaterParameters.Backend.caffe;
    }

    /** Skips every test in this class when the Caffe backend is unavailable. */
    @BeforeClass
    public static void checkBackend() {
        Assume.assumeTrue(DeepWater.haveBackend(DeepWaterParameters.Backend.caffe));
    }

    /**
     * Manual smoke test: loads MNIST training data, shuffles it, trains a Caffe model for a few
     * rounds, then saves the graph/parameters and reloads the parameters.
     *
     * <p>Expects {@code train-images-idx3-ubyte.gz} and {@code train-labels-idx1-ubyte.gz} in the
     * {@code user.home} directory, and writes {@code /tmp/graph} and {@code /tmp/params}.
     *
     * @throws Exception on I/O failure, malformed input files, or model errors
     */
    @Test
    @Ignore
    public void run() throws Exception {
        String home = System.getProperty("user.home");
        byte[][] images;
        byte[] labels;
        // try-with-resources: the original never closed either stream.
        try (DataInputStream imageIn = openGzip(new File(home + "/train-images-idx3-ubyte.gz"));
             DataInputStream labelIn = openGzip(new File(home + "/train-labels-idx1-ubyte.gz"))) {
            imageIn.readInt(); // header magic number, not checked
            int sampleCount = imageIn.readInt();
            int rows = imageIn.readInt();
            int cols = imageIn.readInt();
            labelIn.readInt(); // header magic number, not checked
            int labelCount = labelIn.readInt();
            System.out.println("Read " + sampleCount + " samples");

            // Fail fast with a clear message instead of an ArrayIndexOutOfBoundsException later.
            if (rows * cols != IMAGE_PIXELS) {
                throw new IllegalStateException(
                        "Expected " + IMAGE_PIXELS + " pixels per image, got " + rows + "x" + cols);
            }
            if (labelCount != sampleCount) {
                throw new IllegalStateException(
                        "Image count " + sampleCount + " does not match label count " + labelCount);
            }
            if (sampleCount < BATCH_SIZE) {
                throw new IllegalStateException(
                        "Need at least " + BATCH_SIZE + " samples, got " + sampleCount);
            }

            images = new byte[sampleCount][IMAGE_PIXELS];
            labels = new byte[sampleCount];
            for (int i = 0; i < sampleCount; i++) {
                imageIn.readFully(images[i]);
                labels[i] = labelIn.readByte();
            }
        }

        System.out.println("Randomize");
        shuffle(images, labels, new Random());

        System.out.println("Create model");
        // NOTE(review): hidden sizes 4024/4024/4048 look unusual (typo for 4096?) -- confirm intent.
        DeepwaterCaffeModel model = new DeepwaterCaffeModel(
                BATCH_SIZE,
                new int[]{IMAGE_PIXELS, 4024, 4024, 4048, 10},
                new String[]{"data", "relu", "relu", "relu", "loss"},
                new double[]{0.9d, 0.5d, 0.5d, 0.5d, 0.0d},
                1234L,
                true);

        System.out.println("Train");
        float[] pixels = new float[BATCH_SIZE * IMAGE_PIXELS];
        float[] targets = new float[BATCH_SIZE];
        for (int iter = 0; iter < ITERATIONS; iter++) {
            // NOTE(review): every round refills from samples [0, BATCH_SIZE), so the model sees the
            // same batch each time. The refill stays inside the loop in case train() mutates the
            // buffers. Confirm whether stepping through the data set was intended.
            fillBatch(images, labels, pixels, targets);
            model.train(pixels, targets);
            model.predict(pixels);
        }
        model.saveModel("/tmp/graph");
        model.saveParam("/tmp/params");
        model.loadParam("/tmp/params");
    }

    /**
     * Opens a gzip-compressed file wrapped for big-endian binary reads.
     *
     * @param file the {@code .gz} file to open
     * @return a stream the caller must close
     * @throws IOException if the file cannot be opened or its gzip header is invalid
     */
    private static DataInputStream openGzip(File file) throws IOException {
        FileInputStream raw = new FileInputStream(file);
        try {
            return new DataInputStream(new GZIPInputStream(raw));
        } catch (IOException e) {
            // The GZIPInputStream constructor reads the header; don't leak the file on failure.
            raw.close();
            throw e;
        }
    }

    /**
     * Shuffles images and labels in place with one Fisher-Yates permutation so pairs stay aligned.
     *
     * @param images per-sample pixel rows; same length as {@code labels}
     * @param labels per-sample labels
     * @param random source of randomness
     */
    private static void shuffle(byte[][] images, byte[] labels, Random random) {
        for (int i = images.length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            byte[] image = images[j];
            images[j] = images[i];
            images[i] = image;
            byte label = labels[j];
            labels[j] = labels[i];
            labels[i] = label;
        }
    }

    /**
     * Copies the first {@link #BATCH_SIZE} samples into flat, row-major model buffers, scaling
     * pixels to [0, 1) and widening labels to float.
     *
     * @param images  source pixel rows, at least {@code BATCH_SIZE} of them
     * @param labels  source labels, at least {@code BATCH_SIZE} of them
     * @param pixels  destination of length {@code BATCH_SIZE * IMAGE_PIXELS}
     * @param targets destination of length {@code BATCH_SIZE}
     */
    private static void fillBatch(byte[][] images, byte[] labels, float[] pixels, float[] targets) {
        for (int row = 0; row < BATCH_SIZE; row++) {
            int offset = row * IMAGE_PIXELS;
            for (int p = 0; p < IMAGE_PIXELS; p++) {
                pixels[offset + p] = (images[row][p] & 0xff) * PIXEL_SCALE;
            }
            targets[row] = labels[row];
        }
    }
}
