package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.deepwater.DeepWaterParameters;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Paths;
import javax.imageio.ImageIO;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.fvec.Frame;
import water.parser.BufferedString;
import water.util.FileUtils;
import water.util.StringUtils;

/* loaded from: input_file:hex/deepwater/DeepWaterMXNetIntegrationTest.class */
/**
 * Integration tests for the Deep Water MXNet backend.
 *
 * <p>All tests are skipped (via {@link Assume} in {@link #checkBackend()}) when the MXNet backend
 * is not available.
 */
public class DeepWaterMXNetIntegrationTest extends DeepWaterAbstractIntegrationTest {

    /** Side length, in pixels, of the square images fed to the bundled Inception model. */
    private static final int IMAGE_SIZE = 224;

    /** Number of colour channels fed to the network (R, G, B planes). */
    private static final int CHANNELS = 3;

    /** Value passed as the class count to ImageDataSet/buildNet (presumably the ImageNet synsets). */
    private static final int NUM_CLASSES = 1000;

    /** Number of highest-scoring classes reported by {@link #inceptionPredictionMX()}. */
    private static final int TOP_K = 5;

    /** Classpath directory holding the bundled pre-trained Inception model files. */
    private static final String INCEPTION_DIR = "deepwater/backends/mxnet/models/Inception/";

    /**
     * Copies every remaining byte of {@code in} to {@code out}. Neither stream is closed.
     *
     * @param in source stream
     * @param out destination stream
     * @return number of bytes copied
     * @throws IOException if reading or writing fails
     */
    static long copy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[4096];
        long total = 0;
        int n;
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
            total += n;
        }
        return total;
    }

    @Override // hex.deepwater.DeepWaterAbstractIntegrationTest
    DeepWaterParameters.Backend getBackend() {
        return DeepWaterParameters.Backend.mxnet;
    }

    /** Skips the whole class when the MXNet backend cannot be loaded. */
    @BeforeClass
    public static void checkBackend() {
        Assume.assumeTrue(DeepWater.haveBackend(DeepWaterParameters.Backend.mxnet));
    }

    /**
     * Copies a classpath resource into {@code java.io.tmpdir}.
     *
     * @param dir classpath directory of the resource, with or without a trailing {@code /}
     * @param fileName resource file name; also used as the name of the extracted file
     * @return filesystem path of the extracted copy
     * @throws IOException if the resource does not exist or cannot be copied
     */
    public static String extractFile(String dir, String fileName) throws IOException {
        // ClassLoader resource names always use '/', independent of the platform separator
        // (Paths.get(...).toString() would yield '\' on Windows and miss the resource).
        String resourceName = (dir.isEmpty() || dir.endsWith("/")) ? dir + fileName : dir + "/" + fileName;
        String target = Paths.get(System.getProperty("java.io.tmpdir"), fileName).toString();
        try (InputStream in = DeepWaterMXNetIntegrationTest.class.getClassLoader().getResourceAsStream(resourceName)) {
            if (in == null) {
                throw new IOException("Resource not found on classpath: " + resourceName);
            }
            try (OutputStream out = new FileOutputStream(target)) {
                copy(in, out);
            }
        }
        return target;
    }

    /**
     * Runs the bundled pre-trained Inception model on a test image (first with GPU enabled,
     * then disabled) and checks that the best-scoring label contains "Pembroke".
     */
    @Test
    public void inceptionPredictionMX() throws IOException {
        // The model files do not depend on the CPU/GPU setting, so extract them only once.
        String symbolFile = StringUtils.expandPath(extractFile(INCEPTION_DIR, "Inception_BN-symbol.json"));
        String paramsFile = StringUtils.expandPath(extractFile(INCEPTION_DIR, "Inception_BN-0039.params"));
        String synsetFile = extractFile(INCEPTION_DIR, "synset.txt");
        String meanFile = extractFile(INCEPTION_DIR, "mean_224.nd");

        Frame labels = parse_test_file(synsetFile);
        try {
            for (boolean useGPU : new boolean[]{true, false}) {
                ImageDataSet dataSet = new ImageDataSet(IMAGE_SIZE, IMAGE_SIZE, CHANNELS, NUM_CLASSES);
                RuntimeOptions runtimeOptions = new RuntimeOptions();
                runtimeOptions.setSeed(1234L);
                runtimeOptions.setUseGPU(useGPU);
                BackendParams backendParams = new BackendParams();
                backendParams.set("mini_batch_size", 1);

                BackendModel model = this.backend.buildNet(dataSet, runtimeOptions, backendParams, NUM_CLASSES, symbolFile);
                this.backend.loadParam(model, paramsFile);
                float[] mean = this.backend.loadMeanImage(model, meanFile);

                float[] input = loadMeanSubtractedImage("smalldata/deepwater/imagenet/test2.jpg", mean);
                float[] scores = this.backend.predict(model, input);
                int[] top = topK(scores, TOP_K);

                StringBuilder report = new StringBuilder();
                report.append("\nTop ").append(TOP_K).append(" predictions:\n");
                BufferedString scratch = new BufferedString();
                String bestLabel = null;
                for (int rank = 0; rank < top.length; rank++) {
                    String label = labels.anyVec().atStr(scratch, top[rank]).toString();
                    if (rank == 0) {
                        bestLabel = label;
                    }
                    report.append(" Score: ")
                          .append(String.format("%.4f", scores[top[rank]]))
                          .append('\t')
                          .append(label)
                          .append('\n');
                }
                System.out.println("\n\n" + report + "\n\n");
                // Check the top-1 label itself rather than fixed character offsets of the report.
                Assert.assertTrue("Illegal predictions!", bestLabel != null && bestLabel.contains("Pembroke"));
            }
        } finally {
            labels.remove();
        }
    }

    /**
     * Reads an image, rescales it to {@code IMAGE_SIZE x IMAGE_SIZE}, and returns it as a planar
     * float array (all red values, then all green, then all blue, each in row-major order) with
     * {@code mean} subtracted element-wise using the same layout.
     */
    private static float[] loadMeanSubtractedImage(String path, float[] mean) throws IOException {
        BufferedImage original = ImageIO.read(FileUtils.getFile(path));
        Assert.assertNotNull("Could not decode image " + path, original);

        // TYPE_INT_RGB rather than original.getType(): the decoded type may be TYPE_CUSTOM, which
        // the BufferedImage constructor rejects. Only RGB values are read below anyway.
        BufferedImage resized = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_INT_RGB);
        Graphics2D graphics = resized.createGraphics();
        try {
            graphics.drawImage(original, 0, 0, IMAGE_SIZE, IMAGE_SIZE, (ImageObserver) null);
        } finally {
            graphics.dispose();
        }

        int plane = IMAGE_SIZE * IMAGE_SIZE;
        float[] pixels = new float[CHANNELS * plane];
        for (int y = 0; y < IMAGE_SIZE; y++) {
            for (int x = 0; x < IMAGE_SIZE; x++) {
                Color color = new Color(resized.getRGB(x, y));
                int red = y * IMAGE_SIZE + x;
                int green = plane + red;
                int blue = 2 * plane + red;
                pixels[red] = color.getRed() - mean[red];
                pixels[green] = color.getGreen() - mean[green];
                pixels[blue] = color.getBlue() - mean[blue];
            }
        }
        return pixels;
    }

    /**
     * Returns the indices of the {@code k} largest scores, best first. Ties keep the lower index
     * first. The result has {@code min(k, scores.length)} entries.
     */
    private static int[] topK(float[] scores, int k) {
        int size = Math.min(k, scores.length);
        int[] best = new int[size];
        int filled = 0;
        for (int idx = 0; idx < scores.length; idx++) {
            // Find where this score ranks among the entries collected so far.
            int pos = filled;
            while (pos > 0 && scores[idx] > scores[best[pos - 1]]) {
                pos--;
            }
            if (pos >= size) {
                continue; // not better than the current k-th best
            }
            // Shift lower-ranked entries down by one (dropping the last one if already full).
            int end = Math.min(filled, size - 1);
            System.arraycopy(best, pos, best, pos + 1, end - pos);
            best[pos] = idx;
            if (filled < size) {
                filled++;
            }
        }
        return best;
    }

    /**
     * Trains with zero learning rate on top of the pre-trained Inception model and checks that
     * Java (MOJO) scoring agrees with in-H2O scoring. Currently ignored.
     */
    @Test
    @Ignore
    public void PreTrainedMOJO() {
        Frame train = null;
        Frame scored = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters parms = new DeepWaterParameters();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            parms._train = train._key;
            parms._response_column = "C2";
            parms._image_shape = new int[]{224, 224};
            parms._channels = 3;
            parms._network_definition_file = "../deepwater/mxnet/src/main/resources/deepwater/backends/mxnet/models/Inception/Inception_BN-symbol.json";
            parms._network_parameters_file = "../deepwater/mxnet/src/main/resources/deepwater/backends/mxnet/models/Inception/Inception_BN-0039.params";
            parms._mean_image_file = "../deepwater/mxnet/src/main/resources/deepwater/backends/mxnet/models/Inception/mean_224.nd";
            parms._epochs = 0.1d;
            parms._learning_rate = 0.0d;
            model = (DeepWaterModel) new DeepWater(parms).trainModel().get();
            scored = model.score(parms._train.get());
            Assert.assertTrue(model.testJavaScoring(parms._train.get(), scored, 0.001d));
        } finally {
            // Single cleanup path for both success and failure.
            if (train != null) {
                train.remove();
            }
            if (scored != null) {
                scored.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }
}
