/*
 * Decompiled with CFR 0.152.
 */
package hex.deepwater;

import hex.deepwater.DeepWater;
import hex.deepwater.DeepWaterAbstractIntegrationTest;
import hex.deepwater.DeepWaterParameters;
import hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeModel;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Integration test for the Caffe backend of Deep Water.
 *
 * <p>Every test in this class is skipped when {@code DeepWater.haveBackend} reports that the
 * Caffe backend is unavailable.
 */
public class DeepWaterCaffeIntegrationTest
extends DeepWaterAbstractIntegrationTest {
    /** Pixels per image expected in the idx3 image file (784 = 28 x 28). */
    private static final int PIXELS = 784;
    /** Mini-batch size used both to build the model and to size the training buffers. */
    private static final int BATCH = 256;
    /** Number of train/predict rounds performed by {@link #run()}. */
    private static final int ITERATIONS = 10;
    /** Maps an unsigned byte (0..255) into [0, 1) by multiplying with 1/256 (== 0.00390625f). */
    private static final float PIXEL_SCALE = 1.0f / 256;

    /** Returns the backend under test: Caffe. */
    @Override
    DeepWaterParameters.Backend getBackend() {
        return DeepWaterParameters.Backend.caffe;
    }

    /** Skips every test in this class unless the Caffe backend is available. */
    @BeforeClass
    public static void checkBackend() {
        Assume.assumeTrue(DeepWater.haveBackend(DeepWaterParameters.Backend.caffe));
    }

    /**
     * Manual smoke test for {@link DeepwaterCaffeModel}.
     *
     * <p>The test runs these steps in order:
     * <ol>
     *   <li>Load idx-format images and labels from the user's home directory.</li>
     *   <li>Shuffle the samples.</li>
     *   <li>Run {@code ITERATIONS} train/predict rounds on one mini-batch.</li>
     *   <li>Save the graph and parameters under {@code /tmp} and reload the parameters.</li>
     * </ol>
     *
     * <p>Marked {@code @Ignore}. It requires {@code train-images-idx3-ubyte.gz} and
     * {@code train-labels-idx1-ubyte.gz} to be present in {@code user.home}.
     *
     * @throws IOException if either data file cannot be read
     * @throws IllegalStateException if the files are inconsistent with each other or with
     *     {@code PIXELS}/{@code BATCH}
     */
    @Ignore
    @Test
    public void run() throws Exception {
        String home = System.getProperty("user.home");
        byte[][] rawI = readImages(new File(home, "train-images-idx3-ubyte.gz"));
        byte[] rawL = readLabels(new File(home, "train-labels-idx1-ubyte.gz"));
        int count = rawI.length;
        if (rawL.length != count) {
            throw new IllegalStateException(
                    "Image/label count mismatch: " + count + " images vs " + rawL.length + " labels");
        }
        if (count < BATCH) {
            throw new IllegalStateException("Need at least " + BATCH + " samples, got " + count);
        }
        System.out.println("Read " + count + " samples");

        System.out.println("Randomize");
        shuffleTogether(rawI, rawL, new Random());

        System.out.println("Create model");
        DeepwaterCaffeModel model = new DeepwaterCaffeModel(
                BATCH,
                new int[]{PIXELS, 4024, 4024, 4048, 10},
                new String[]{"data", "relu", "relu", "relu", "loss"},
                new double[]{0.9, 0.5, 0.5, 0.5, 0.0},
                1234L,
                true);

        System.out.println("Train");
        float[] ps = new float[BATCH * PIXELS];
        float[] ls = new float[BATCH];
        for (int iter = 0; iter < ITERATIONS; ++iter) {
            // NOTE(review): every round refills the buffers from the same first BATCH (shuffled)
            // samples. Confirm whether cycling through the data (offset iter * BATCH) was intended.
            fillBatch(rawI, rawL, 0, ps, ls);
            model.train(ps, ls);
            model.predict(ps);
        }
        model.saveModel("/tmp/graph");
        model.saveParam("/tmp/params");
        model.loadParam("/tmp/params");
    }

    /**
     * Reads a gzipped idx3 image file.
     *
     * <p>The header (magic, count, rows, cols) is read first. The rows/cols must multiply to
     * {@code PIXELS}. One {@code PIXELS}-byte array is then read per image. The stream is always
     * closed.
     */
    private static byte[][] readImages(File file) throws IOException {
        try (FileInputStream raw = new FileInputStream(file);
             DataInputStream in = new DataInputStream(new GZIPInputStream(raw))) {
            in.readInt(); // magic number; not validated
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows * cols != PIXELS) {
                throw new IllegalStateException(
                        "Expected " + PIXELS + " pixels per image in " + file + ", got " + rows + "x" + cols);
            }
            byte[][] images = new byte[count][PIXELS];
            for (byte[] image : images) {
                in.readFully(image);
            }
            return images;
        }
    }

    /**
     * Reads a gzipped idx1 label file.
     *
     * <p>The header (magic, count) is read first, followed by {@code count} label bytes. The
     * stream is always closed.
     */
    private static byte[] readLabels(File file) throws IOException {
        try (FileInputStream raw = new FileInputStream(file);
             DataInputStream in = new DataInputStream(new GZIPInputStream(raw))) {
            in.readInt(); // magic number; not validated
            int count = in.readInt();
            byte[] labels = new byte[count];
            in.readFully(labels);
            return labels;
        }
    }

    /** Fisher-Yates shuffle that applies the same permutation to {@code images} and {@code labels}. */
    private static void shuffleTogether(byte[][] images, byte[] labels, Random rand) {
        for (int i = images.length - 1; i > 0; --i) {
            int j = rand.nextInt(i + 1);
            byte[] image = images[j];
            images[j] = images[i];
            images[i] = image;
            byte label = labels[j];
            labels[j] = labels[i];
            labels[i] = label;
        }
    }

    /**
     * Copies {@code BATCH} samples starting at {@code offset} into the training buffers.
     *
     * <p>Pixels go row-major into the flat buffer {@code ps}, scaled by {@code PIXEL_SCALE}.
     * Labels go into {@code ls}.
     */
    private static void fillBatch(byte[][] images, byte[] labels, int offset, float[] ps, float[] ls) {
        for (int b = 0; b < BATCH; ++b) {
            byte[] image = images[offset + b];
            for (int i = 0; i < PIXELS; ++i) {
                ps[b * PIXELS + i] = (image[i] & 0xFF) * PIXEL_SCALE;
            }
            ls[b] = labels[offset + b];
        }
    }
}

