/*
 * Decompiled with CFR 0.152.
 */
package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.deepwater.DeepWater;
import hex.deepwater.DeepWaterAbstractIntegrationTest;
import hex.deepwater.DeepWaterModel;
import hex.deepwater.DeepWaterParameters;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Paths;
import javax.imageio.ImageIO;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.fvec.Frame;
import water.parser.BufferedString;
import water.util.FileUtils;
import water.util.StringUtils;

/**
 * Integration tests for the MXNet Deep Water backend.
 *
 * <p>All tests are skipped (via {@link Assume}) unless
 * {@code DeepWater.haveBackend(Backend.mxnet)} reports the backend as available.
 */
public class DeepWaterMXNetIntegrationTest
extends DeepWaterAbstractIntegrationTest {
    /**
     * Copies every byte of {@code in} to {@code out} through a 4 KiB buffer.
     * Neither stream is closed.
     *
     * @param in  source stream, read until end-of-stream
     * @param out destination stream
     * @return the number of bytes copied
     * @throws IOException if reading or writing fails
     */
    static long copy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[4096];
        long total = 0L;
        for (int n = in.read(buffer); n != -1; n = in.read(buffer)) {
            out.write(buffer, 0, n);
            total += n;
        }
        return total;
    }

    @Override
    DeepWaterParameters.Backend getBackend() {
        return DeepWaterParameters.Backend.mxnet;
    }

    /** Skips the whole class when the MXNet backend is not available. */
    @BeforeClass
    public static void checkBackend() {
        Assume.assumeTrue(DeepWater.haveBackend(DeepWaterParameters.Backend.mxnet));
    }

    /**
     * Copies the classpath resource {@code path/file} into {@code java.io.tmpdir}
     * under the same file name, overwriting any existing file there.
     *
     * @param path resource directory on the classpath, with or without a trailing '/'
     * @param file resource file name
     * @return absolute-or-relative path string of the extracted temp file
     * @throws IOException if the resource does not exist or copying fails
     */
    public static String extractFile(String path, String file) throws IOException {
        // Classpath resource names are always '/'-separated, independent of the
        // OS path separator that Paths.get(...) would use.
        String resource = path.endsWith("/") ? path + file : path + "/" + file;
        String target = Paths.get(System.getProperty("java.io.tmpdir"), file).toString();
        try (InputStream in = DeepWaterMXNetIntegrationTest.class.getClassLoader().getResourceAsStream(resource)) {
            if (in == null) {
                throw new IOException("Resource not found on classpath: " + resource);
            }
            try (OutputStream out = new FileOutputStream(target)) {
                copy(in, out);
            }
        }
        return target;
    }

    /**
     * Runs the bundled pre-trained Inception-BN network on
     * {@code smalldata/deepwater/imagenet/test2.jpg}, first with GPU enabled and then
     * with GPU disabled, and asserts that the top-1 label contains "Pembroke".
     */
    @Test
    public void inceptionPredictionMX() throws IOException {
        final int w = 224;
        final int h = 224;
        final int channels = 3;
        final int nclasses = 1000;
        final int K = 5;
        final String path = "deepwater/backends/mxnet/models/Inception/";
        for (boolean gpu : new boolean[]{true, false}) {
            ImageDataSet id = new ImageDataSet(w, h, channels, nclasses);
            RuntimeOptions opts = new RuntimeOptions();
            opts.setSeed(1234L);
            opts.setUseGPU(gpu);
            BackendParams bparm = new BackendParams();
            bparm.set("mini_batch_size", 1);
            BackendModel model = backend.buildNet(id, opts, bparm, nclasses,
                    StringUtils.expandPath(extractFile(path, "Inception_BN-symbol.json")));
            backend.loadParam(model, StringUtils.expandPath(extractFile(path, "Inception_BN-0039.params")));
            Frame labels = parse_test_file(extractFile(path, "synset.txt"));
            try {
                float[] mean = backend.loadMeanImage(model, extractFile(path, "mean_224.nd"));
                File imgFile = FileUtils.getFile("smalldata/deepwater/imagenet/test2.jpg");
                float[] pixels = loadMeanSubtractedImage(imgFile, w, h, mean);
                float[] preds = backend.predict(model, pixels);
                int[] topK = topKIndices(preds, K);

                StringBuilder sb = new StringBuilder();
                sb.append("\nTop " + K + " predictions:\n");
                BufferedString str = new BufferedString();
                for (int idx : topK) {
                    String label = labels.anyVec().atStr(str, (long) idx).toString();
                    sb.append(" Score: " + String.format("%.4f", preds[idx]) + "\t" + label + "\n");
                }
                System.out.println("\n\n" + sb + "\n\n");

                // Check the best-scoring label directly instead of slicing the
                // formatted report at fixed character offsets.
                Assert.assertTrue("Illegal predictions!", topK.length > 0);
                String best = labels.anyVec().atStr(str, (long) topK[0]).toString();
                Assert.assertTrue("Illegal predictions!", best.contains("Pembroke"));
            } finally {
                labels.remove();
            }
        }
    }

    /**
     * Reads an image, scales it to {@code w x h}, and returns its pixels in planar
     * order: all red values, then all green, then all blue, each plane row-major.
     * The entry of {@code mean} at the same index is subtracted from every value.
     *
     * @throws IOException if the file cannot be read or no ImageIO reader decodes it
     */
    private static float[] loadMeanSubtractedImage(File imgFile, int w, int h, float[] mean) throws IOException {
        BufferedImage img = ImageIO.read(imgFile);
        if (img == null) {
            throw new IOException("No ImageIO reader could decode " + imgFile);
        }
        // Use a fixed RGB type: img.getType() can be TYPE_CUSTOM, which the
        // BufferedImage constructor rejects.
        BufferedImage scaled = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2d = scaled.createGraphics();
        try {
            g2d.drawImage(img, 0, 0, w, h, null);
        } finally {
            g2d.dispose();
        }
        int plane = w * h;
        float[] pixels = new float[3 * plane];
        for (int y = 0; y < h; ++y) {
            for (int x = 0; x < w; ++x) {
                int p = y * w + x;
                Color c = new Color(scaled.getRGB(x, y));
                pixels[p] = c.getRed() - mean[p];
                pixels[plane + p] = c.getGreen() - mean[plane + p];
                pixels[2 * plane + p] = c.getBlue() - mean[2 * plane + p];
            }
        }
        return pixels;
    }

    /**
     * Returns the indices of the {@code min(k, values.length)} largest values,
     * ordered from largest to smallest. On ties, the lower index ranks first.
     */
    private static int[] topKIndices(float[] values, int k) {
        int n = Math.max(0, Math.min(k, values.length));
        int[] top = new int[n];
        int filled = 0;
        for (int i = 0; i < values.length; ++i) {
            // Walk up from the end of the filled prefix to find i's rank.
            int pos = filled;
            while (pos > 0 && values[i] > values[top[pos - 1]]) {
                --pos;
            }
            if (pos >= n) {
                continue; // not among the current top n
            }
            // Shift lower-ranked entries down one slot, dropping the last if full.
            int end = Math.min(filled, n - 1);
            System.arraycopy(top, pos, top, pos + 1, end - pos);
            top[pos] = i;
            if (filled < n) {
                ++filled;
            }
        }
        return top;
    }

    /**
     * Trains a DeepWater model with the pre-trained Inception-BN network (learning
     * rate 0, 0.1 epochs) on {@code cat_dog_mouse.csv}. It then checks that MOJO/Java
     * scoring matches in-H2O scoring within 0.001.
     *
     * <p>Disabled via {@code @Ignore}. NOTE(review): it reads {@code bigdata/laptop}
     * data and a model from a relative {@code ../deepwater} checkout, which may be
     * why it is disabled; confirm.
     */
    @Ignore
    @Test
    public void PreTrainedMOJO() {
        Frame tr = null;
        Frame preds = null;
        DeepWaterModel m = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            tr = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = tr._key;
            p._response_column = "C2";
            String path = "../deepwater/mxnet/src/main/resources/deepwater/backends/mxnet/models/Inception/";
            p._image_shape = new int[]{224, 224};
            p._channels = 3;
            p._network_definition_file = path + "Inception_BN-symbol.json";
            p._network_parameters_file = path + "Inception_BN-0039.params";
            p._mean_image_file = path + "mean_224.nd";
            p._epochs = 0.1;
            p._learning_rate = 0.0;
            DeepWater job = new DeepWater(p);
            m = (DeepWaterModel) job.trainModel().get();
            Frame train = (Frame) p._train.get();
            preds = m.score(train);
            Assert.assertTrue(m.testJavaScoring(train, preds, 0.001));
        } finally {
            if (tr != null) {
                tr.remove();
            }
            if (preds != null) {
                preds.remove();
            }
            if (m != null) {
                m.remove();
            }
        }
    }
}

