package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.FrameSplitter;
import hex.Model;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.deepwater.DeepWaterParameters;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import hex.splitframe.ShuffleSplitFrame;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.AutoBuffer;
import water.DKV;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.Log;
import water.util.StringUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/deepwater/DeepWaterAbstractIntegrationTest.class */
public abstract class DeepWaterAbstractIntegrationTest extends TestUtil {
    /** Native Deep Water training backend; re-created before every test by {@link #createBackend()}. */
    protected BackendTrain backend;

    /** Returns the Deep Water backend implementation that the concrete subclass wants to exercise. */
    abstract DeepWaterParameters.Backend getBackend();

    /** Blocks until the H2O cloud reaches size 1 before any test in the class runs. */
    @BeforeClass
    public static void stall() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /** Skips every test in the class when {@link DeepWater#haveBackend()} reports no backend. */
    @BeforeClass
    public static void checkBackend() {
        boolean backendAvailable = DeepWater.haveBackend();
        Assume.assumeTrue(backendAvailable);
    }

    /** Creates a fresh native backend for {@link #getBackend()} before each test and checks it exists. */
    @Before
    public void createBackend() throws Exception {
        String backendName = getBackend().toString();
        backend = DeepwaterMojoModel.createDeepWaterBackend(backendName);
        Assert.assertNotNull(backend);
    }

    /**
     * Trains a small VGG network three times in a row on the cat/dog/mouse image set, deleting the
     * model and removing the parsed frame after every round. Named as a memory-leak check:
     * repeated train/delete cycles are expected not to accumulate state.
     */
    @Test
    public void memoryLeakTest() {
        final int rounds = 3;
        for (int round = 0; round < rounds; round++) {
            // Scoped per round so that cleanup only ever touches this round's objects; a failure in
            // one round must not re-delete the previous round's (already deleted) model and frame.
            DeepWaterModel model = null;
            Frame train = null;
            try {
                DeepWaterParameters p = new DeepWaterParameters();
                p._backend = getBackend();
                train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
                p._train = train._key;
                p._response_column = "C2";
                p._network = DeepWaterParameters.Network.vgg;
                p._learning_rate = 1e-4;
                p._mini_batch_size = 8;
                p._train_samples_per_iteration = 8;
                p._epochs = 1e-3;
                model = (DeepWaterModel) new DeepWater(p).trainModel().get();
                Log.info(model);
            } finally {
                // Single cleanup path for success and failure alike.
                if (model != null) {
                    model.delete();
                }
                if (train != null) {
                    train.remove();
                }
            }
        }
    }

    /**
     * Trains on the cat/dog/mouse image set for 3 epochs (network left at its default) with the
     * given {@code _train_samples_per_iteration} and asserts how many iterations the model reports.
     *
     * @param samplesPerIteration value for {@code _train_samples_per_iteration}; callers in this
     *     class also pass the special values 0, -1 and -2
     * @param expectedIterations expected value of {@code model.iterations}
     */
    void trainSamplesPerIteration(int samplesPerIteration, int expectedIterations) {
        DeepWaterModel model = null;
        Frame train = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";
            p._learning_rate = 1e-3;
            p._epochs = 3;
            p._train_samples_per_iteration = samplesPerIteration;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertEquals(expectedIterations, model.iterations);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (model != null) {
                model.delete();
            }
            if (train != null) {
                train.remove();
            }
        }
    }

    /** {@code train_samples_per_iteration = 0}: expects 3 iterations over 3 epochs. */
    @Test
    public void trainSamplesPerIteration0() {
        int samplesPerIteration = 0;
        int expectedIterations = 3;
        trainSamplesPerIteration(samplesPerIteration, expectedIterations);
    }

    /** {@code train_samples_per_iteration = -2}: expects a single iteration. */
    @Test
    public void trainSamplesPerIteration_auto() {
        int samplesPerIteration = -2;
        int expectedIterations = 1;
        trainSamplesPerIteration(samplesPerIteration, expectedIterations);
    }

    /** {@code train_samples_per_iteration = -1}: expects 3 iterations over 3 epochs. */
    @Test
    public void trainSamplesPerIteration_neg1() {
        int samplesPerIteration = -1;
        int expectedIterations = 3;
        trainSamplesPerIteration(samplesPerIteration, expectedIterations);
    }

    /** {@code train_samples_per_iteration = 32}: expects 26 iterations. */
    @Test
    public void trainSamplesPerIteration_32() {
        int samplesPerIteration = 32;
        int expectedIterations = 26;
        trainSamplesPerIteration(samplesPerIteration, expectedIterations);
    }

    /** {@code train_samples_per_iteration = 1000}: expects a single iteration. */
    @Test
    public void trainSamplesPerIteration_1000() {
        int samplesPerIteration = 1000;
        int expectedIterations = 1;
        trainSamplesPerIteration(samplesPerIteration, expectedIterations);
    }

    /**
     * Trains LeNet on 28x28 images for 50 epochs, scoring after every iteration with
     * {@code _overwrite_with_best_model} enabled, and asserts the final training logloss is below 2.
     */
    @Test
    public void overWriteWithBestModel() {
        DeepWaterModel model = null;
        Frame train = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";
            p._epochs = 50;
            p._learning_rate = 0.01;
            p._momentum_start = 0.5;
            p._momentum_stable = 0.5;
            p._stopping_rounds = 0;
            p._image_shape = new int[]{28, 28};
            p._network = DeepWaterParameters.Network.lenet;
            p._problem_type = DeepWaterParameters.ProblemType.image;
            // One mini-batch per iteration, scored every time, so there are many candidate "best" models.
            p._train_samples_per_iteration = p._mini_batch_size;
            p._score_duty_cycle = 1;
            p._score_interval = 0;
            p._overwrite_with_best_model = true;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Log.info(model);
            Assert.assertTrue(model._output._training_metrics.logloss() < 2.0);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (model != null) {
                model.remove();
            }
            if (train != null) {
                train.remove();
            }
        }
    }

    /**
     * Trains {@code network} on the cat/dog/mouse image set and asserts that the training
     * confusion-matrix accuracy exceeds 0.9.
     *
     * @param channels number of input channels (callers pass 3 for color and 1 for grayscale)
     * @param network architecture to train; must not be {@code null}
     * @param epochs number of training epochs
     */
    void checkConvergence(int channels, DeepWaterParameters.Network network, int epochs) {
        DeepWaterModel model = null;
        Frame train = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";
            p._network = network;
            p._learning_rate = 1e-3;
            p._epochs = epochs;
            p._channels = channels;
            // Per-architecture mini-batch size; resnet and alexnet also get a smaller learning rate.
            switch (network) {
                case vgg:
                    p._mini_batch_size = 8;
                    break;
                case resnet:
                    p._mini_batch_size = 16;
                    p._learning_rate = 1e-4;
                    break;
                case alexnet:
                    p._mini_batch_size = 128;
                    p._learning_rate = 1e-4;
                    break;
                default:
                    p._mini_batch_size = 32;
                    break;
            }
            p._problem_type = DeepWaterParameters.ProblemType.image;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Log.info(model);
            double accuracy = model._output._training_metrics.cm().accuracy();
            Log.info("Accuracy " + accuracy);
            Assert.assertTrue(accuracy > 0.9);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (model != null) {
                model.delete();
            }
            if (train != null) {
                train.remove();
            }
        }
    }

    /** Channel count used by the color variants of the convergence tests. */
    private static final int CONVERGENCE_COLOR_CHANNELS = 3;
    /** Channel count used by the grayscale variants of the convergence tests. */
    private static final int CONVERGENCE_GRAY_CHANNELS = 1;
    /** Epoch budget shared by every convergence test except LeNet on color input. */
    private static final int CONVERGENCE_EPOCHS = 150;

    @Test
    public void convergenceInceptionColor() {
        checkConvergence(CONVERGENCE_COLOR_CHANNELS, DeepWaterParameters.Network.inception_bn, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceInceptionGrayScale() {
        checkConvergence(CONVERGENCE_GRAY_CHANNELS, DeepWaterParameters.Network.inception_bn, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceGoogleNetColor() {
        checkConvergence(CONVERGENCE_COLOR_CHANNELS, DeepWaterParameters.Network.googlenet, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceGoogleNetGrayScale() {
        checkConvergence(CONVERGENCE_GRAY_CHANNELS, DeepWaterParameters.Network.googlenet, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceLenetColor() {
        // LeNet on color input gets twice the usual epoch budget (300).
        checkConvergence(CONVERGENCE_COLOR_CHANNELS, DeepWaterParameters.Network.lenet, 2 * CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceLenetGrayScale() {
        checkConvergence(CONVERGENCE_GRAY_CHANNELS, DeepWaterParameters.Network.lenet, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceVGGColor() {
        checkConvergence(CONVERGENCE_COLOR_CHANNELS, DeepWaterParameters.Network.vgg, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceVGGGrayScale() {
        checkConvergence(CONVERGENCE_GRAY_CHANNELS, DeepWaterParameters.Network.vgg, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceResnetColor() {
        checkConvergence(CONVERGENCE_COLOR_CHANNELS, DeepWaterParameters.Network.resnet, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceResnetGrayScale() {
        checkConvergence(CONVERGENCE_GRAY_CHANNELS, DeepWaterParameters.Network.resnet, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceAlexnetColor() {
        checkConvergence(CONVERGENCE_COLOR_CHANNELS, DeepWaterParameters.Network.alexnet, CONVERGENCE_EPOCHS);
    }

    @Test
    public void convergenceAlexnetGrayScale() {
        checkConvergence(CONVERGENCE_GRAY_CHANNELS, DeepWaterParameters.Network.alexnet, CONVERGENCE_EPOCHS);
    }

    /**
     * Trains three times with learning rate 0 and the same seed (1234) and asserts that all runs
     * report the same training logloss (relative tolerance 1e-5). With a zero learning rate the
     * logloss presumably reflects only the seeded initial weights.
     */
    @Test
    @Ignore
    public void reproInitialDistribution() {
        final int runs = 3;
        double[] logloss = new double[runs];
        for (int run = 0; run < runs; run++) {
            DeepWaterModel model = null;
            Frame train = null;
            try {
                DeepWaterParameters p = new DeepWaterParameters();
                p._backend = getBackend();
                train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
                p._train = train._key;
                p._response_column = "C2";
                p._learning_rate = 0;
                p._seed = 1234;
                p._epochs = 1;
                p._channels = 1;
                p._train_samples_per_iteration = 0;
                model = (DeepWaterModel) new DeepWater(p).trainModel().get();
                Log.info(model);
                logloss[run] = model._output._training_metrics.logloss();
            } finally {
                // Single cleanup path; cleanup is never re-run if it throws itself.
                if (model != null) {
                    model.delete();
                }
                if (train != null) {
                    train.remove();
                }
            }
        }
        for (int run = 1; run < runs; run++) {
            Assert.assertEquals(logloss[0], logloss[run], 1e-5 * logloss[0]);
        }
    }

    /**
     * Trains three times with learning rate 0 but seeds 0, 1 and 2, and asserts that runs 1 and 2
     * report a training logloss different from run 0 (beyond a relative tolerance of 1e-5).
     */
    @Test
    public void reproInitialDistributionNegativeTest() {
        final int runs = 3;
        double[] logloss = new double[runs];
        for (int run = 0; run < runs; run++) {
            DeepWaterModel model = null;
            Frame train = null;
            try {
                DeepWaterParameters p = new DeepWaterParameters();
                p._backend = getBackend();
                train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
                p._train = train._key;
                p._response_column = "C2";
                p._learning_rate = 0;
                p._seed = run; // a different seed for every run
                p._epochs = 1;
                p._channels = 1;
                p._train_samples_per_iteration = 0;
                model = (DeepWaterModel) new DeepWater(p).trainModel().get();
                Log.info(model);
                logloss[run] = model._output._training_metrics.logloss();
            } finally {
                // Single cleanup path; cleanup is never re-run if it throws itself.
                if (model != null) {
                    model.delete();
                }
                if (train != null) {
                    train.remove();
                }
            }
        }
        for (int run = 1; run < runs; run++) {
            Assert.assertNotEquals(logloss[0], logloss[run], 1e-5 * logloss[0]);
        }
    }

    /** Runs {@link #settingModelInfo} for every network except {@code user} and {@code auto}. */
    @Test
    @Ignore
    public void settingModelInfoAll() {
        DeepWaterParameters.Network[] networks = DeepWaterParameters.Network.values();
        for (int idx = 0; idx < networks.length; idx++) {
            DeepWaterParameters.Network candidate = networks[idx];
            // user/auto presumably do not name a concrete built-in architecture, so they are skipped.
            boolean skip = candidate == DeepWaterParameters.Network.user
                    || candidate == DeepWaterParameters.Network.auto;
            if (skip) {
                continue;
            }
            settingModelInfo(candidate);
        }
    }

    @Test
    public void settingModelInfoAlexnet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.alexnet;
        settingModelInfo(net);
    }

    @Test
    public void settingModelInfoLenet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.lenet;
        settingModelInfo(net);
    }

    @Test
    public void settingModelInfoVGG() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.vgg;
        settingModelInfo(net);
    }

    @Test
    public void settingModelInfoInception() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.inception_bn;
        settingModelInfo(net);
    }

    @Test
    public void settingModelInfoResnet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.resnet;
        settingModelInfo(net);
    }

    /**
     * Trains {@code network} twice with seeds 1234 and 4321, then installs a deep copy of the first
     * model's {@code model_info} into the second model. Asserts that the two trainings produced
     * different parameters, that the second model now carries the first model's parameters, and
     * that both score to the same loss (relative tolerance 1e-5).
     *
     * @param network architecture to train
     */
    void settingModelInfo(DeepWaterParameters.Network network) {
        DeepWaterModel first = null;
        DeepWaterModel second = null;
        Frame train = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";
            p._network = network;
            p._mini_batch_size = 2;
            p._epochs = 0.01;
            p._seed = 1234;
            p._score_training_samples = 0;
            p._train_samples_per_iteration = p._mini_batch_size;
            p._problem_type = DeepWaterParameters.ProblemType.image;

            Job firstJob = new DeepWater(p).trainModel();
            first = (DeepWaterModel) firstJob.get();
            int firstHash = Arrays.hashCode(first.model_info()._modelparams);
            first.doScoring(train, (Frame) null, firstJob._key, first.iterations, true);
            double firstLoss = first.loss();

            // Same parameters object, different seed: the second model must start from other weights.
            p._seed = 4321;
            Job secondJob = new DeepWater(p).trainModel();
            second = (DeepWaterModel) secondJob.get();
            int secondHash = Arrays.hashCode(second.model_info()._modelparams);

            // Drop the second model's native state and replace its model_info with a copy of the first.
            second.removeNativeState();
            second.set_model_info(IcedUtils.deepCopy(first.model_info()));
            second.doScoring(train, (Frame) null, secondJob._key, second.iterations, true);
            double copiedLoss = second.loss();
            int copiedHash = Arrays.hashCode(second.model_info()._modelparams);

            Log.info("Checking assertions for network: " + network);
            Assert.assertNotEquals(firstHash, secondHash);
            Assert.assertEquals(firstHash, copiedHash);
            Assert.assertEquals(firstLoss, copiedLoss, 1e-5 * firstLoss);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (first != null) {
                first.delete();
            }
            if (second != null) {
                second.delete();
            }
            if (train != null) {
                train.remove();
            }
        }
    }

    /**
     * Trains three times for one epoch with learning rate 1e-4 and seed 1234, and asserts that all
     * runs report the same training logloss (relative tolerance 1e-5).
     */
    @Test
    @Ignore
    public void reproTraining() {
        final int runs = 3;
        double[] logloss = new double[runs];
        for (int run = 0; run < runs; run++) {
            DeepWaterModel model = null;
            Frame train = null;
            try {
                DeepWaterParameters p = new DeepWaterParameters();
                p._backend = getBackend();
                train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
                p._train = train._key;
                p._response_column = "C2";
                p._learning_rate = 1e-4;
                p._seed = 1234;
                p._epochs = 1;
                p._channels = 1;
                p._train_samples_per_iteration = 0;
                model = (DeepWaterModel) new DeepWater(p).trainModel().get();
                Log.info(model);
                logloss[run] = model._output._training_metrics.logloss();
            } finally {
                // Single cleanup path; cleanup is never re-run if it throws itself.
                if (model != null) {
                    model.delete();
                }
                if (train != null) {
                    train.remove();
                }
            }
        }
        for (int run = 1; run < runs; run++) {
            Assert.assertEquals(logloss[0], logloss[run], 1e-5 * logloss[0]);
        }
    }

    /** Runs {@link #deepWaterLoadSaveTest} for every network except {@code auto} and {@code user}. */
    @Test
    @Ignore
    public void deepWaterLoadSaveTestAll() {
        DeepWaterParameters.Network[] networks = DeepWaterParameters.Network.values();
        for (int idx = 0; idx < networks.length; idx++) {
            DeepWaterParameters.Network candidate = networks[idx];
            boolean skip = candidate == DeepWaterParameters.Network.auto
                    || candidate == DeepWaterParameters.Network.user;
            if (skip) {
                continue;
            }
            deepWaterLoadSaveTest(candidate);
        }
    }

    @Test
    public void deepWaterLoadSaveTestAlexnet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.alexnet;
        deepWaterLoadSaveTest(net);
    }

    @Test
    public void deepWaterLoadSaveTestLenet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.lenet;
        deepWaterLoadSaveTest(net);
    }

    @Test
    public void deepWaterLoadSaveTestVGG() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.vgg;
        deepWaterLoadSaveTest(net);
    }

    @Test
    public void deepWaterLoadSaveTestInception() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.inception_bn;
        deepWaterLoadSaveTest(net);
    }

    @Test
    public void deepWaterLoadSaveTestResnet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.resnet;
        deepWaterLoadSaveTest(net);
    }

    /**
     * Trains {@code network} briefly, then drops the native state and rebuilds it via
     * {@code javaToNative()} / {@code nativeToJava()}. Asserts that the Java-side network and
     * parameter byte arrays hash the same before and after the round trip.
     *
     * @param network architecture to train
     */
    void deepWaterLoadSaveTest(DeepWaterParameters.Network network) {
        DeepWaterModel model = null;
        Frame train = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";
            p._network = network;
            p._mini_batch_size = 2;
            p._epochs = 0.01;
            p._seed = 1234;
            p._score_training_samples = 0;
            p._train_samples_per_iteration = p._mini_batch_size;
            p._problem_type = DeepWaterParameters.ProblemType.image;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Log.info(model);
            Assert.assertNull(model.model_info()._backend);

            int networkHash = Arrays.hashCode(model.model_info()._network);
            int paramsHash = Arrays.hashCode(model.model_info()._modelparams);
            Log.info("Hash code for original network: " + networkHash);
            Log.info("Hash code for original parameters: " + paramsHash);

            model.removeNativeState();
            model.model_info().javaToNative();
            model.model_info().nativeToJava();

            int restoredNetworkHash = Arrays.hashCode(model.model_info()._network);
            int restoredParamsHash = Arrays.hashCode(model.model_info()._modelparams);
            Log.info("Hash code for restored network: " + restoredNetworkHash);
            Log.info("Hash code for restored parameters: " + restoredParamsHash);

            Assert.assertEquals(networkHash, restoredNetworkHash);
            Assert.assertEquals(paramsHash, restoredParamsHash);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (model != null) {
                model.delete();
            }
            if (train != null) {
                train.remove();
            }
        }
    }

    /**
     * Trains LeNet with 3-fold cross-validation for 2 epochs, scores the training frame, and asserts
     * that {@code testJavaScoring} agrees with those predictions (tolerance 1e-3).
     */
    @Test
    public void deepWaterCV() {
        DeepWaterModel model = null;
        Frame train = null;
        Frame preds = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";
            p._network = DeepWaterParameters.Network.lenet;
            p._nfolds = 3;
            p._epochs = 2;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Frame trainFrame = p._train.get();
            preds = model.score(trainFrame);
            Assert.assertTrue(model.testJavaScoring(trainFrame, preds, 1e-3));
            Log.info(model);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (model != null) {
                model.deleteCrossValidationModels();
                model.delete();
            }
            if (train != null) {
                train.remove();
            }
            if (preds != null) {
                preds.remove();
            }
        }
    }

    /**
     * Same as the cross-validation test, but first replaces the response column with its numeric
     * version ({@code toNumericVec()}) so that LeNet is trained as a regression.
     */
    @Test
    public void deepWaterCVRegression() {
        DeepWaterModel model = null;
        Frame train = null;
        Frame preds = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._response_column = "C2";

            // Swap the response for its numeric version and publish the updated frame to the DKV.
            String response = p._response_column;
            Vec oldResponse = train.remove(response);
            train.add(response, oldResponse.toNumericVec());
            oldResponse.remove();
            DKV.put(train);

            p._network = DeepWaterParameters.Network.lenet;
            p._nfolds = 3;
            p._epochs = 2;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Frame trainFrame = p._train.get();
            preds = model.score(trainFrame);
            Assert.assertTrue(model.testJavaScoring(trainFrame, preds, 1e-3));
            Log.info(model);
        } finally {
            // Single cleanup path; cleanup is never re-run if it throws itself.
            if (model != null) {
                model.deleteCrossValidationModels();
                model.delete();
            }
            if (train != null) {
                train.remove();
            }
            if (preds != null) {
                preds.remove();
            }
        }
    }

    /** Runs {@link #restoreState} for every network except {@code user} and {@code auto}. */
    @Test
    @Ignore
    public void restoreStateAll() {
        DeepWaterParameters.Network[] networks = DeepWaterParameters.Network.values();
        for (int idx = 0; idx < networks.length; idx++) {
            DeepWaterParameters.Network candidate = networks[idx];
            boolean skip = candidate == DeepWaterParameters.Network.user
                    || candidate == DeepWaterParameters.Network.auto;
            if (skip) {
                continue;
            }
            restoreState(candidate);
        }
    }

    @Test
    public void restoreStateAlexnet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.alexnet;
        restoreState(net);
    }

    @Test
    public void restoreStateLenet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.lenet;
        restoreState(net);
    }

    @Test
    public void restoreStateVGG() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.vgg;
        restoreState(net);
    }

    @Test
    public void restoreStateInception() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.inception_bn;
        restoreState(net);
    }

    @Test
    public void restoreStateResnet() {
        DeepWaterParameters.Network net = DeepWaterParameters.Network.resnet;
        restoreState(net);
    }

    /**
     * Trains {@code network} briefly, serializes the model to a byte array with {@link AutoBuffer},
     * removes it from the DKV, and deserializes it again. Asserts that:
     * <ul>
     *   <li>the original model's training logloss matches its own scoring of the training frame, and</li>
     *   <li>the restored model scores to the same logloss as the original (relative tolerance 1e-5).</li>
     * </ul>
     *
     * @param network architecture to train
     */
    public void restoreState(DeepWaterParameters.Network network) {
        DeepWaterModel original = null;
        DeepWaterModel restored = null;
        Frame train = null;
        // Tracks whichever prediction frame is live, so it is cleaned up on any failure path.
        Frame preds = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            train = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = train._key;
            p._network = network;
            p._response_column = "C2";
            p._mini_batch_size = 2;
            p._train_samples_per_iteration = p._mini_batch_size;
            p._learning_rate = 0;
            p._seed = 12345;
            p._epochs = 0.01;
            p._quiet_mode = true;
            p._problem_type = DeepWaterParameters.ProblemType.image;
            original = (DeepWaterModel) new DeepWater(p).trainModel().get();

            Log.info("Scoring the original model.");
            preds = original.score(train);
            preds.remove(0).remove(); // drop the first prediction column before computing metrics
            ModelMetricsMultinomial originalMetrics =
                    ModelMetricsMultinomial.make(preds, train.vec(p._response_column));
            Log.info("Original LL: " + original._output._training_metrics.logloss());
            Log.info("Scored   LL: " + originalMetrics.logloss());
            preds.remove();
            preds = null; // already removed; prevents a second remove() in the finally block

            Log.info("Keeping the raw byte[] of the model.");
            byte[] raw = new AutoBuffer().put(original).buf();
            Log.info("Removing the model from the DKV.");
            original.remove();
            Log.info("Restoring the model from the raw byte[].");
            restored = (DeepWaterModel) new AutoBuffer(raw).get();

            Log.info("Scoring the restored model.");
            preds = restored.score(train);
            preds.remove(0).remove();
            ModelMetricsMultinomial restoredMetrics =
                    ModelMetricsMultinomial.make(preds, train.vec(p._response_column));
            Log.info("Restored LL: " + restoredMetrics.logloss());

            double tolerance = 1e-5 * originalMetrics.logloss();
            Assert.assertEquals(original._output._training_metrics.logloss(), originalMetrics.logloss(), tolerance);
            Assert.assertEquals(originalMetrics.logloss(), restoredMetrics.logloss(), tolerance);
        } finally {
            // NOTE(review): `original` was already removed from the DKV above, and `restored` is a
            // deserialized copy of it; both are still delete()d here, as before — confirm intended.
            if (original != null) {
                original.delete();
            }
            if (restored != null) {
                restored.delete();
            }
            if (train != null) {
                train.remove();
            }
            if (preds != null) {
                preds.remove();
            }
        }
    }

    /**
     * Stress test: runs {@code backend.train} 1000 times on a LeNet model built by
     * {@link #buildLENET()}, feeding the same all-zero batch of 784 * 64 inputs and 64 labels.
     * Logs a 1-based iteration counter before each step.
     */
    @Test
    public void trainLoop() throws InterruptedException {
        final int batchSize = 64;
        final int iterations = 1000;
        BackendModel lenet = buildLENET();
        float[] inputs = new float[784 * batchSize];
        float[] labels = new float[batchSize];
        for (int iter = 1; iter <= iterations; iter++) {
            Log.info(new Object[]{"Iteration: " + iter});
            this.backend.train(lenet, inputs, labels);
        }
    }

    /**
     * Builds a "lenet" network on the current backend.
     *
     * <p>The dataset descriptor is {@code ImageDataSet(28, 28, 1, 10)} (presumably 28x28 pixels,
     * 1 channel, 10 classes — confirm against the deepwater API). The runtime requests the GPU
     * on device 0 with seed 1234, and the backend parameters set a mini-batch size of 64.
     */
    private BackendModel buildLENET() {
        final int classes = 10;
        ImageDataSet dataset = new ImageDataSet(28, 28, 1, classes);

        RuntimeOptions opts = new RuntimeOptions();
        opts.setUseGPU(true);
        opts.setSeed(1234L);
        opts.setDeviceID(new int[]{0});

        BackendParams params = new BackendParams();
        params.set("mini_batch_size", 64);

        return this.backend.buildNet(dataset, opts, params, classes, "lenet");
    }

    /**
     * Saves the parameters of a freshly built LeNet model three times to the same temporary
     * path, then deletes the saved parameters.
     *
     * @throws IOException if the temporary file cannot be created
     */
    @Test
    public void saveLoop() throws IOException {
        BackendModel lenet = buildLENET();
        String path = File.createTempFile("saveLoop", ".tmp").getAbsolutePath();
        try {
            for (int i = 0; i < 3; i++) {
                Log.info(new Object[]{"Iteration: " + i});
                this.backend.saveParam(lenet, path);
            }
        } finally {
            // Always clean up, even if a save fails, so failed runs do not leave files behind.
            this.backend.deleteSavedParam(path);
        }
    }

    /**
     * Runs {@code backend.predict} three times on a LeNet model built by {@link #buildLENET()},
     * using an all-zero input batch of 784 * 64 floats. Logs a 1-based iteration counter.
     */
    @Test
    public void predictLoop() {
        final int iterations = 3;
        BackendModel lenet = buildLENET();
        float[] inputs = new float[784 * 64];
        for (int iter = 1; iter <= iterations; iter++) {
            Log.info(new Object[]{"Iteration: " + iter});
            this.backend.predict(lenet, inputs);
        }
    }

    /**
     * Alternates {@code backend.train} and {@code backend.predict} 1000 times on the same LeNet
     * model, using an all-zero batch of 784 * 64 inputs and 64 labels. Logs a 1-based counter.
     */
    @Test
    public void trainPredictLoop() {
        final int batchSize = 64;
        final int iterations = 1000;
        BackendModel lenet = buildLENET();
        float[] inputs = new float[784 * batchSize];
        float[] labels = new float[batchSize];
        for (int iter = 1; iter <= iterations; iter++) {
            Log.info(new Object[]{"Iteration: " + iter});
            this.backend.train(lenet, inputs, labels);
            this.backend.predict(lenet, inputs);
        }
    }

    /**
     * Trains a tiny LeNet model (learning rate 0, 0.01 epochs) on the cat/dog/mouse image CSV,
     * then calls {@code doScoring} on the training frame 100 times.
     *
     * <p>The training frame and model are removed in a {@code finally} block, so they do not
     * leak into the DKV when training or scoring throws.
     */
    @Test
    public void scoreLoop() {
        Frame frame = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            frame = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = frame._key;
            p._network = DeepWaterParameters.Network.lenet;
            p._response_column = "C2";
            p._mini_batch_size = 4;
            p._train_samples_per_iteration = p._mini_batch_size;
            p._learning_rate = 0.0d;
            p._seed = 12345L;
            p._epochs = 0.01d;
            p._quiet_mode = true;
            DeepWater deepWater = new DeepWater(p);
            model = deepWater.trainModel().get();
            for (int i = 1; i <= 100; i++) {
                Log.info(new Object[]{"Iteration: " + i});
                model.doScoring(frame, (Frame) null, deepWater._job._key, model.iterations, true);
            }
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains a binomial DeepWater model on prostate data with CAPSULE as the response. Asserts
     * that the training AUC exceeds 0.9. Resources are released in a single {@code finally}.
     */
    @Test
    public void prostateClassification() {
        Frame frame = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            frame = parse_test_file("smalldata/prostate/prostate.csv");
            p._train = frame._key;
            p._response_column = "CAPSULE";
            p._ignored_columns = new String[]{"ID"};
            // Replace each listed column with its categorical version and free the old vec.
            for (String col : new String[]{"RACE", "DPROS", "DCAPS", "CAPSULE", "GLEASON"}) {
                Vec numeric = frame.remove(col);
                frame.add(col, numeric.toCategoricalVec());
                numeric.remove();
            }
            // Publish the modified frame so training sees the categorical columns.
            DKV.put(frame);
            p._seed = 1234L;
            p._epochs = 500.0d;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._training_metrics.auc_obj()._auc > 0.9d);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains a regression DeepWater model on prostate data with AGE as the response. Asserts
     * that the training RMSE is below 5 and that Java (POJO/MOJO) scoring agrees with in-H2O
     * scoring within 1e-3. Resources are released in a single {@code finally}.
     */
    @Test
    public void prostateRegression() {
        Frame frame = null;
        Frame preds = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            frame = parse_test_file("smalldata/prostate/prostate.csv");
            p._train = frame._key;
            p._response_column = "AGE";
            p._ignored_columns = new String[]{"ID"};
            // Replace each listed column with its categorical version and free the old vec.
            for (String col : new String[]{"RACE", "DPROS", "DCAPS", "CAPSULE", "GLEASON"}) {
                Vec numeric = frame.remove(col);
                frame.add(col, numeric.toCategoricalVec());
                numeric.remove();
            }
            DKV.put(frame);
            p._seed = 1234L;
            p._epochs = 1000.0d;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._training_metrics.rmse() < 5.0d);
            preds = model.score(p._train.get());
            Assert.assertTrue(model.testJavaScoring(p._train.get(), preds, 0.001d));
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (preds != null) {
                preds.remove();
            }
        }
    }

    /**
     * Trains a LeNet model on a CSV of image URLs with a binomial response C2. Asserts that the
     * training AUC exceeds 0.85 and that Java scoring matches in-H2O scoring. Resources are
     * released in a single {@code finally}.
     */
    @Test
    public void imageURLs() {
        Frame frame = null;
        Frame preds = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            frame = parse_test_file("smalldata/deepwater/imagenet/binomial_image_urls.csv");
            p._train = frame._key;
            p._response_column = "C2";
            p._network = DeepWaterParameters.Network.lenet;
            p._epochs = 500.0d;
            p._seed = 1234L;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._training_metrics.auc_obj()._auc > 0.85d);
            preds = model.score(p._train.get());
            Assert.assertTrue(model.testJavaScoring(p._train.get(), preds, 0.001d, 1.0E-5d, 1.0d));
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (preds != null) {
                preds.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains on the alphabet categorical test set with {@code y} converted to a categorical
     * response. Asserts that the training AUC exceeds 0.9. Resources are released in a single
     * {@code finally}.
     */
    @Test
    public void categorical() {
        Frame frame = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            frame = parse_test_file("smalldata/gbm_test/alphabet_cattest.csv");
            p._train = frame._key;
            p._response_column = "y";
            // Replace the response with its categorical version and free the old vec.
            Vec numeric = frame.remove("y");
            frame.add("y", numeric.toCategoricalVec());
            numeric.remove();
            DKV.put(frame);
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._training_metrics.auc_obj()._auc > 0.9d);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains LeNet on MNIST (28x28, 1 channel), validating on the MNIST test set. Asserts that
     * the validation mean per-class error is below 5%.
     *
     * <p>The test is skipped via {@link Assume} when either bigdata MNIST file is missing.
     * Previously it silently passed when the train file was missing and NPE'd when only the
     * test file was missing.
     */
    @Test
    public void MNISTLenet() {
        File trainFile = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File testFile = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        Assume.assumeTrue(trainFile != null && testFile != null);
        Frame train = null;
        Frame valid = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._response_column = "C785";
            train = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(trainFile)._key});
            valid = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(testFile)._key});
            // Make the label column categorical in both frames, freeing the numeric originals.
            for (Frame f : new Frame[]{train, valid}) {
                Vec numeric = f.remove(p._response_column);
                f.add(p._response_column, numeric.toCategoricalVec());
                numeric.remove();
            }
            DKV.put(train);
            DKV.put(valid);
            p._backend = getBackend();
            p._train = train._key;
            p._valid = valid._key;
            p._image_shape = new int[]{28, 28};
            p._ignore_const_cols = false;
            p._channels = 1;
            p._network = DeepWaterParameters.Network.lenet;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._validation_metrics.mean_per_class_error() < 0.05d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (valid != null) {
                valid.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains a sparse 500x500 MLP on MNIST, validating on the MNIST test set. Asserts that the
     * validation mean per-class error is below 5%.
     *
     * <p>The test is skipped via {@link Assume} when either bigdata MNIST file is missing,
     * instead of silently passing or NPE-ing on a missing test file.
     */
    @Test
    public void MNISTSparse() {
        File trainFile = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File testFile = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        Assume.assumeTrue(trainFile != null && testFile != null);
        Frame train = null;
        Frame valid = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._response_column = "C785";
            train = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(trainFile)._key});
            valid = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(testFile)._key});
            // Make the label column categorical in both frames, freeing the numeric originals.
            for (Frame f : new Frame[]{train, valid}) {
                Vec numeric = f.remove(p._response_column);
                f.add(p._response_column, numeric.toCategoricalVec());
                numeric.remove();
            }
            DKV.put(train);
            DKV.put(valid);
            p._backend = getBackend();
            p._train = train._key;
            p._valid = valid._key;
            p._learning_rate = 0.005d;
            p._hidden = new int[]{500, 500};
            p._sparse = true;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._validation_metrics.mean_per_class_error() < 0.05d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (valid != null) {
                valid.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains a 1024-1024-2048 dropout MLP on MNIST (20 epochs, batch 32), validating on the
     * MNIST test set. Asserts that the validation mean per-class error is below 5%.
     *
     * <p>The test is skipped via {@link Assume} when either bigdata MNIST file is missing,
     * instead of silently passing or NPE-ing on a missing test file.
     */
    @Test
    public void MNISTHinton() {
        File trainFile = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
        File testFile = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
        Assume.assumeTrue(trainFile != null && testFile != null);
        Frame train = null;
        Frame valid = null;
        DeepWaterModel model = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._response_column = "C785";
            train = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(trainFile)._key});
            valid = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(testFile)._key});
            // Make the label column categorical in both frames, freeing the numeric originals.
            for (Frame f : new Frame[]{train, valid}) {
                Vec numeric = f.remove(p._response_column);
                f.add(p._response_column, numeric.toCategoricalVec());
                numeric.remove();
            }
            DKV.put(train);
            DKV.put(valid);
            p._backend = getBackend();
            p._hidden = new int[]{1024, 1024, 2048};
            p._input_dropout_ratio = 0.1d;
            p._hidden_dropout_ratios = new double[]{0.5d, 0.5d, 0.5d};
            p._stopping_rounds = 0;
            p._learning_rate = 0.001d;
            p._mini_batch_size = 32;
            p._epochs = 20.0d;
            p._train = train._key;
            p._valid = valid._key;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._validation_metrics.mean_per_class_error() < 0.05d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (valid != null) {
                valid.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains on a 50/50 shuffled split of the airlines data to predict IsDepDelayed. Asserts that
     * the validation AUC exceeds 0.65.
     *
     * <p>Changes from the original:
     * <ul>
     *   <li>The test is skipped via {@link Assume} when the data file is missing, instead of
     *       silently passing.</li>
     *   <li>The split keys are named after their use. The training split was previously stored
     *       under "test.hex" and the validation split under "train.hex".</li>
     * </ul>
     */
    @Test
    public void Airlines() {
        File dataFile = FileUtils.locateFile("smalldata/airlines/allyears2k_headers.zip");
        Assume.assumeTrue(dataFile != null);
        Frame frame = null;
        DeepWaterModel model = null;
        Frame[] splits = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._response_column = "IsDepDelayed";
            // Drop columns that are unknown at departure time or leak the response.
            p._ignored_columns = new String[]{"DepTime", "ArrTime", "Cancelled", "CancellationCode", "Diverted", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"};
            frame = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(dataFile)._key});
            for (String col : new String[]{p._response_column, "UniqueCarrier", "Origin", "Dest"}) {
                Vec numeric = frame.remove(col);
                frame.add(col, numeric.toCategoricalVec());
                numeric.remove();
            }
            DKV.put(frame);
            // splits[0] (train.hex) is used for training, splits[1] (test.hex) for validation.
            double[] ratios = ard(new double[]{0.5d, 0.5d});
            Key[] keys = (Key[]) aro(new Key[]{Key.make("train.hex"), Key.make("test.hex")});
            splits = ShuffleSplitFrame.shuffleSplitFrame(frame, keys, ratios, 42L);
            p._backend = getBackend();
            p._train = keys[0];
            p._valid = keys[1];
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            Assert.assertTrue(model._output._validation_metrics.auc() > 0.65d);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (splits != null) {
                for (Frame split : splits) {
                    split.remove();
                }
            }
        }
    }

    /**
     * Trains a very short run of the given image network on the cat/dog/mouse CSV and checks
     * scoring consistency:
     * <ul>
     *   <li>Java scoring ({@code testJavaScoring}) matches in-H2O scoring within 1e-3.</li>
     *   <li>The multinomial logloss recomputed from the scored frame matches the training
     *       logloss within 1e-3.</li>
     * </ul>
     *
     * @param network the DeepWater network architecture to train
     */
    private void MOJOTestImage(DeepWaterParameters.Network network) {
        Frame frame = null;
        DeepWaterModel model = null;
        Frame preds = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            frame = parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv");
            p._train = frame._key;
            p._response_column = "C2";
            p._learning_rate = 1.0E-4d;
            p._network = network;
            p._mini_batch_size = 4;
            p._train_samples_per_iteration = 8L;
            p._epochs = 0.001d;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            preds = model.score(frame);
            Assert.assertTrue(model.testJavaScoring(frame, preds, 0.001d));
            // Drop and free the first prediction column, presumably the predicted label, so the
            // frame holds only what ModelMetricsMultinomial.make expects — confirm.
            preds.remove(0).remove();
            double scoredLogloss = ModelMetricsMultinomial.make(preds, frame.vec(p._response_column)).logloss();
            double trainingLogloss = model._output._training_metrics.logloss();
            Assert.assertTrue(Math.abs(scoredLogloss - trainingLogloss) < 0.001d);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (preds != null) {
                preds.remove();
            }
        }
    }

    /** Runs {@link #MOJOTestImage} with the LeNet network. */
    @Test
    public void MOJOTestImageLenet() {
        final DeepWaterParameters.Network net = DeepWaterParameters.Network.lenet;
        MOJOTestImage(net);
    }

    /** Runs {@link #MOJOTestImage} with the Inception-BN network. */
    @Test
    public void MOJOTestImageInception() {
        final DeepWaterParameters.Network net = DeepWaterParameters.Network.inception_bn;
        MOJOTestImage(net);
    }

    /** Runs {@link #MOJOTestImage} with the AlexNet network. */
    @Test
    public void MOJOTestImageAlexnet() {
        final DeepWaterParameters.Network net = DeepWaterParameters.Network.alexnet;
        MOJOTestImage(net);
    }

    /** Runs {@link #MOJOTestImage} with the ResNet network (currently ignored). */
    @Test
    @Ignore
    public void MOJOTestImageResnet() {
        final DeepWaterParameters.Network net = DeepWaterParameters.Network.resnet;
        MOJOTestImage(net);
    }

    /** Runs {@link #MOJOTestImage} with the VGG network. */
    @Test
    public void MOJOTestImageVGG() {
        final DeepWaterParameters.Network net = DeepWaterParameters.Network.vgg;
        MOJOTestImage(net);
    }

    /** Runs {@link #MOJOTestImage} with the GoogLeNet network (currently ignored). */
    @Test
    @Ignore
    public void MOJOTestImageGooglenet() {
        final DeepWaterParameters.Network net = DeepWaterParameters.Network.googlenet;
        MOJOTestImage(net);
    }

    /**
     * Trains a small binomial MLP (50x50, 50 epochs) on prostate data with CAPSULE as a
     * categorical response, then scores three frames. For each, the AUC from column 2 of the
     * predictions must match the model's training AUC within 1e-3:
     * <ol>
     *   <li>The training frame itself. Java scoring must also match in-H2O scoring.</li>
     *   <li>A fresh parse with the response (and, if {@code z}, the predictor columns) moved to
     *       the end, without categorical conversion.</li>
     *   <li>A fresh, unmodified parse.</li>
     * </ol>
     * All frames and the model are released in a single {@code finally}.
     *
     * @param categoricalEncodingScheme categorical encoding used for training
     * @param z if true, RACE/DPROS/DCAPS/GLEASON are converted to categoricals for training
     * @param z2 value of {@code _standardize}; when true, every AUC must also exceed 0.7
     */
    private void MOJOTest(Model.Parameters.CategoricalEncodingScheme categoricalEncodingScheme, boolean z, boolean z2) {
        final String dataPath = "smalldata/prostate/prostate.csv";
        final String[] predictors = {"RACE", "DPROS", "DCAPS", "GLEASON"};
        Frame train = null;
        Frame reordered = null;
        Frame raw = null;
        DeepWaterModel model = null;
        Frame trainPreds = null;
        Frame reorderedPreds = null;
        Frame rawPreds = null;
        try {
            DeepWaterParameters p = new DeepWaterParameters();
            train = parse_test_file(dataPath);
            p._response_column = "CAPSULE";
            String response = p._response_column;
            Vec numericResponse = train.remove(response);
            train.add(response, numericResponse.toCategoricalVec());
            numericResponse.remove();
            if (z) {
                for (String col : predictors) {
                    Vec numeric = train.remove(col);
                    train.add(col, numeric.toCategoricalVec());
                    numeric.remove();
                }
            }
            DKV.put(train);
            p._train = train._key;
            p._ignored_columns = new String[]{"ID"};
            p._backend = getBackend();
            p._seed = 12345L;
            p._epochs = 50.0d;
            p._categorical_encoding = categoricalEncodingScheme;
            p._standardize = z2;
            p._hidden = new int[]{50, 50};
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();

            trainPreds = model.score(train);
            Assert.assertTrue(model.testJavaScoring(train, trainPreds, 0.001d));
            assertAucMatchesTraining(model, trainPreds, train.vec(response), z2);

            // Same columns moved in the same way as the training frame, but left non-categorical.
            reordered = parse_test_file(dataPath);
            reordered.add(response, reordered.remove(response));
            if (z) {
                for (String col : predictors) {
                    reordered.add(col, reordered.remove(col));
                }
            }
            reorderedPreds = model.score(reordered);
            assertAucMatchesTraining(model, reorderedPreds, reordered.vec(response), z2);

            raw = parse_test_file(dataPath);
            rawPreds = model.score(raw);
            assertAucMatchesTraining(model, rawPreds, raw.vec(response), z2);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (reordered != null) {
                reordered.remove();
            }
            if (raw != null) {
                raw.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (trainPreds != null) {
                trainPreds.remove();
            }
            if (reorderedPreds != null) {
                reorderedPreds.remove();
            }
            if (rawPreds != null) {
                rawPreds.remove();
            }
        }
    }

    /**
     * Asserts that the binomial AUC computed from column 2 of {@code preds} against
     * {@code actual} is within 1e-3 of {@code model}'s training AUC. When
     * {@code requireGoodAuc} is set, the AUC must also exceed 0.7.
     */
    private static void assertAucMatchesTraining(DeepWaterModel model, Frame preds, Vec actual, boolean requireGoodAuc) {
        double auc = ModelMetricsBinomial.make(preds.vec(2), actual).auc();
        Assert.assertTrue(Math.abs(auc - model._output._training_metrics.auc()) < 0.001d);
        if (requireGoodAuc) {
            Assert.assertTrue(auc > 0.7d);
        }
    }

    /** Numeric predictors only, AUTO encoding, no standardization. */
    @Test
    public void MOJOTestNumericNonStandardized() {
        final boolean categoricalPredictors = false;
        final boolean standardize = false;
        MOJOTest(Model.Parameters.CategoricalEncodingScheme.AUTO, categoricalPredictors, standardize);
    }

    /** Numeric predictors only, AUTO encoding, standardized. */
    @Test
    public void MOJOTestNumeric() {
        final boolean categoricalPredictors = false;
        final boolean standardize = true;
        MOJOTest(Model.Parameters.CategoricalEncodingScheme.AUTO, categoricalPredictors, standardize);
    }

    /** Categorical predictors with OneHotInternal encoding, standardized. */
    @Test
    public void MOJOTestCatInternal() {
        final boolean categoricalPredictors = true;
        final boolean standardize = true;
        MOJOTest(Model.Parameters.CategoricalEncodingScheme.OneHotInternal, categoricalPredictors, standardize);
    }

    /** Categorical predictors with OneHotExplicit encoding, standardized. */
    @Test
    public void MOJOTestCatExplicit() {
        final boolean categoricalPredictors = true;
        final boolean standardize = true;
        MOJOTest(Model.Parameters.CategoricalEncodingScheme.OneHotExplicit, categoricalPredictors, standardize);
    }

    /** Categorical predictors with Eigen encoding, standardized. */
    @Test
    public void MOJOTestCatEigen() {
        final boolean categoricalPredictors = true;
        final boolean standardize = true;
        MOJOTest(Model.Parameters.CategoricalEncodingScheme.Eigen, categoricalPredictors, standardize);
    }

    /** Categorical predictors with Binary encoding, standardized. */
    @Test
    public void MOJOTestCatBinary() {
        final boolean categoricalPredictors = true;
        final boolean standardize = true;
        MOJOTest(Model.Parameters.CategoricalEncodingScheme.Binary, categoricalPredictors, standardize);
    }

    /**
     * Trains on iris for 10 epochs, then resumes from that model as a checkpoint with a
     * 20-epoch budget. Asserts that the resumed model's epoch counter exceeds 20. Frames and
     * models are deleted in a single {@code finally}.
     */
    @Test
    public void testCheckpointForwards() {
        Frame frame = null;
        DeepWaterModel initial = null;
        DeepWaterModel resumed = null;
        try {
            frame = parse_test_file("./smalldata/iris/iris.csv");
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            p._train = frame._key;
            p._epochs = 10.0d;
            p._response_column = "C5";
            p._hidden = new int[]{2, 2};
            p._seed = 912559L;
            p._stopping_rounds = 0;
            initial = (DeepWaterModel) new DeepWater(p).trainModel().get();
            DeepWaterParameters resumedParams = p.clone();
            resumedParams._epochs = 20.0d;
            resumedParams._checkpoint = initial._key;
            resumed = (DeepWaterModel) new DeepWater(resumedParams).trainModel().get();
            Assert.assertTrue(resumed.epoch_counter > 20.0d);
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (initial != null) {
                initial.delete();
            }
            if (resumed != null) {
                resumed.delete();
            }
        }
    }

    /**
     * Trains on iris for 10 epochs, then tries to resume from that checkpoint with only 9 epochs.
     * Expects an {@link H2OIllegalArgumentException}. Frames and models are deleted in a single
     * {@code finally}.
     */
    @Test
    public void testCheckpointBackwards() {
        Frame frame = null;
        DeepWaterModel initial = null;
        DeepWaterModel resumed = null;
        try {
            frame = parse_test_file("./smalldata/iris/iris.csv");
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            p._train = frame._key;
            p._epochs = 10.0d;
            p._response_column = "C5";
            p._hidden = new int[]{2, 2};
            p._seed = 912559L;
            initial = (DeepWaterModel) new DeepWater(p).trainModel().get();
            DeepWaterParameters resumedParams = p.clone();
            resumedParams._epochs = 9.0d;
            resumedParams._checkpoint = initial._key;
            try {
                resumed = (DeepWaterModel) new DeepWater(resumedParams).trainModel().get();
                Assert.fail("Should toss exception instead of reaching here");
            } catch (H2OIllegalArgumentException expected) {
                // Expected: an epoch budget (9) below the checkpoint's (10) must be rejected.
                // Assert.fail throws AssertionError, which is deliberately not caught here.
            }
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (initial != null) {
                initial.delete();
            }
            if (resumed != null) {
                resumed.delete();
            }
        }
    }

    /**
     * Trains a model, pauses, resumes training from its checkpoint, and validates the scoring
     * history of the resumed model.
     *
     * <p>The checks are:
     * <ul>
     *   <li>timestamps fall inside an outside wall-clock window and never decrease;</li>
     *   <li>parseable durations are non-negative, non-decreasing and bounded by that window;</li>
     *   <li>the epoch and iteration counters equal the row index;</li>
     *   <li>across the checkpoint boundary the pause appears in the timestamps but not in the
     *       reported duration or training speed.</li>
     * </ul>
     */
    @Test
    public void checkpointReporting() {
        // Pause (seconds) between the initial run and the checkpoint restart. The
        // "Duration must be smooth" and "Time stamp must experience a delay" thresholds
        // below are derived from this single value.
        final long sleepSeconds = 5;
        Scope.enter();
        Frame frame = null;
        try {
            frame = ParseDataset.parse(Key.make(), new Key[]{NFSFileVec.make(FileUtils.locateFile("smalldata/logreg/prostate.csv"))._key});
            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            p._train = frame._key;
            p._response_column = "CAPSULE";
            p._activation = DeepWaterParameters.Activation.Rectifier;
            p._epochs = 4.0d;
            p._train_samples_per_iteration = -1L;
            p._mini_batch_size = 1;
            p._score_duty_cycle = 1.0d;
            p._score_interval = 0.0d;
            p._overwrite_with_best_model = false;
            p._seed = 1234L;

            // Convert the response to categorical. The replaced Vec is tracked so that
            // Scope.exit() removes it.
            int responseIdx = frame.find("CAPSULE");
            Scope.track(frame.replace(responseIdx, frame.vecs()[responseIdx].toCategoricalVec()));
            DKV.put(frame);

            DeepWaterModel model = null;
            DeepWaterModel model2 = null;
            try {
                long start = System.currentTimeMillis();
                // History timestamps are parsed at 1-second resolution ("yyyy-MM-dd HH:mm:ss").
                // Waiting one second keeps their truncated values from landing before 'start'.
                try {
                    Thread.sleep(1000L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                model = (DeepWaterModel) new DeepWater(p).trainModel().get();
                try {
                    Thread.sleep(sleepSeconds * 1000L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }

                // Resume from the checkpoint with a doubled epoch target.
                DeepWaterParameters p2 = p.clone();
                p2._checkpoint = model._key;
                p2._epochs *= 2.0d;
                model2 = (DeepWaterModel) new DeepWater(p2).trainModel().get();
                long end = System.currentTimeMillis();

                TwoDimTable history = model2._output._scoring_history;
                DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
                double priorDuration = 0.0d;
                long priorTimestamp = 0L;
                for (int row = 0; row < history.getRowDim(); row++) {
                    long timestamp = fmt.parseMillis((String) history.get(row, 0));
                    Assert.assertTrue("Timestamp must be later than outside timer start", timestamp >= start);
                    Assert.assertTrue("Timestamp must be earlier than outside timer end", timestamp <= end);
                    Assert.assertTrue("Timestamp must increase", timestamp >= priorTimestamp);
                    priorTimestamp = timestamp;

                    String durationText = (String) history.get(row, 1);
                    try {
                        // Drop the last 4 characters (presumably a " sec" unit suffix) and parse.
                        double duration = Double.parseDouble(durationText.substring(0, durationText.length() - 4));
                        Assert.assertTrue("Duration must be >=0: " + duration, duration >= 0.0d);
                        Assert.assertTrue("Duration must increase: " + priorDuration + " -> " + duration, duration >= priorDuration);
                        Assert.assertTrue("Duration cannot be more than outside timer delta", duration <= ((double) (end - start)) / 1000.0d);
                        priorDuration = duration;
                    } catch (NumberFormatException ignored) {
                        // Best effort: durations whose text is not a plain number after stripping
                        // the suffix are not checked.
                    }

                    Assert.assertTrue("Epoch counter must be contiguous", ((Double) history.get(row, 3)).doubleValue() == ((double) row));
                    Assert.assertTrue("Iteration counter must match epochs", ((Integer) history.get(row, 4)).intValue() == row);
                }

                // The loop above pins row index == epoch counter, so these two rows sit at epochs
                // p._epochs and p._epochs + 1, i.e. on either side of the checkpoint restart.
                int beforeRow = (int) p._epochs;
                int afterRow = (int) (p._epochs + 1.0d);
                try {
                    String durationBeforeText = (String) history.get(beforeRow, 1);
                    String durationAfterText = (String) history.get(afterRow, 1);
                    double durationBefore = Double.parseDouble(durationBeforeText.substring(0, durationBeforeText.length() - 4));
                    double durationAfter = Double.parseDouble(durationAfterText.substring(0, durationAfterText.length() - 4));
                    // The pause must not be counted as training time...
                    Assert.assertTrue("Duration must be smooth", durationAfter - durationBefore < ((double) (sleepSeconds + 1)));
                    // ...but it must be visible in the wall-clock timestamps.
                    long timestampGap = fmt.parseMillis((String) history.get(afterRow, 0)) - fmt.parseMillis((String) history.get(beforeRow, 0));
                    Assert.assertTrue("Time stamp must experience a delay", timestampGap >= (sleepSeconds - 1) * 1000L);

                    // Drop the last 9 characters (presumably a unit suffix) of the speed column.
                    String speedBeforeText = (String) history.get(beforeRow, 2);
                    double speedBefore = Double.parseDouble(speedBeforeText.substring(0, speedBeforeText.length() - 9));
                    String speedAfterText = (String) history.get(afterRow, 2);
                    double speedAfter = Double.parseDouble(speedAfterText.substring(0, speedAfterText.length() - 9));
                    Assert.assertTrue("Speed shouldn't change more than 50%", Math.abs(speedAfter - speedBefore) / speedBefore < 0.5d);
                } catch (NumberFormatException ignored) {
                    // Best effort: skip the boundary checks when the cells are not plain numbers.
                }
            } finally {
                if (model != null) {
                    model.delete();
                }
                if (model2 != null) {
                    model2.delete();
                }
            }
        } finally {
            if (frame != null) {
                frame.remove();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains a deep, narrow network with a learning rate of 1e10 and checks that the failure
     * is surfaced.
     *
     * <p>Expected outcome: training throws a {@link RuntimeException}, and the model left in
     * the DKV cannot score or produce a MOJO, is flagged unstable, and its job is reported as
     * crashed.
     */
    @Test
    public void testNumericalExplosion() {
        // The autoencoder flag is a scenario parameter; currently only the supervised case runs.
        for (boolean autoencoder : new boolean[]{false}) {
            Frame frame = null;
            DeepWaterModel model = null;
            Frame scored = null;
            try {
                frame = parse_test_file("./smalldata/junit/two_spiral.csv");
                // Replace each listed column with its categorical version, deleting the old Vec.
                for (String col : new String[]{"Class"}) {
                    Vec categorical = frame.vec(col).toCategoricalVec();
                    frame.remove(col).remove();
                    frame.add(col, categorical);
                    DKV.put(frame);
                }
                DeepWaterParameters p = new DeepWaterParameters();
                p._backend = getBackend();
                p._train = frame._key;
                p._epochs = 100.0d;
                p._response_column = "Class";
                p._autoencoder = autoencoder;
                p._train_samples_per_iteration = 10L;
                p._hidden = new int[]{10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10};
                p._learning_rate = 1.0E10d;
                p._standardize = false;
                DeepWater builder = new DeepWater(p);
                try {
                    builder.trainModel().get();
                    Assert.fail("Should toss exception instead of reaching here");
                } catch (RuntimeException expected) {
                    // Training is expected to fail; the model is asserted to be unstable below.
                }

                model = (DeepWaterModel) DKV.getGet(builder.dest());
                // Without this, a missing model would surface as NPEs swallowed by the
                // RuntimeException catches below.
                Assert.assertNotNull("Model must remain in the DKV after crashed training", model);
                try {
                    scored = model.score(frame);
                    Assert.fail("Should toss exception instead of reaching here");
                } catch (RuntimeException expected) {
                    // Scoring an unstable model is expected to be rejected.
                }
                try {
                    model.getMojo();
                    Assert.fail("Should toss exception instead of reaching here");
                } catch (RuntimeException expected) {
                    System.err.println(expected.getMessage());
                }
                Assert.assertTrue(model.model_info()._unstable);
                Assert.assertTrue(model._output._job.isCrashed());
            } finally {
                if (frame != null) {
                    frame.delete();
                }
                if (model != null) {
                    model.delete();
                }
                if (scored != null) {
                    scored.delete();
                }
            }
        }
    }

    /**
     * Encodes six short movie-review sentences with {@link StringUtils#texts2array}.
     *
     * <p>Checks that one row is produced per text and that the first row is 38 entries wide.
     */
    @Test
    public void textsToArrayTest() throws IOException {
        ArrayList<String> texts = new ArrayList<>();
        // Three positive reviews...
        texts.add("the rock is destined to be the 21st century's new \" conan \" and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .");
        texts.add("the gorgeously elaborate continuation of \" the lord of the rings \" trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .");
        texts.add("effective but too-tepid biopic");
        // ...followed by three negative ones.
        texts.add("simplistic , silly and tedious .");
        texts.add("it's so laddish and juvenile , only teenage boys could possibly find it funny .");
        texts.add("exploitative and largely devoid of the depth or sophistication that would make watching such a graphic treatment of the crimes bearable .");

        ArrayList<int[]> coded = StringUtils.texts2array(texts);
        for (int[] row : coded) {
            System.out.println(Arrays.toString(row));
        }
        System.out.println("rows " + coded.size() + " cols " + coded.get(0).length);
        Assert.assertEquals(6L, coded.size());
        Assert.assertEquals(38L, coded.get(0).length);
    }

    /**
     * Manual, ignored test: encodes tweets from a local file with
     * {@link StringUtils#texts2array} and checks the resulting shape.
     *
     * <p>Expected shape: 1390 rows, with a first row 35 entries wide. The input paths point at
     * one developer's home directory, so this only runs where those files exist.
     */
    @Test
    @Ignore
    public void tweetsToArrayTest() throws IOException {
        ArrayList<String> tweets = new ArrayList<>();
        // try-with-resources closes the stream even if reading fails part-way.
        try (FileInputStream in = new FileInputStream("/home/magnus/tweets.txt");
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                tweets.add(line);
            }
        }

        // Labels are read (so their file must exist) but are not used by any assertion yet.
        ArrayList<String> labels = new ArrayList<>();
        try (FileInputStream in = new FileInputStream("/home/magnus/labels.txt");
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                labels.add(line);
            }
        }

        ArrayList<int[]> coded = StringUtils.texts2array(tweets);
        System.out.println("rows " + coded.size() + " cols " + coded.get(0).length);
        Assert.assertEquals(1390L, coded.size());
        Assert.assertEquals(35L, coded.get(0).length);
    }

    /**
     * Checks that a checkpoint restart with overwrite-with-best-model does not worsen
     * validation logloss.
     *
     * <p>Trains iris for one epoch with {@code _overwrite_with_best_model} enabled, then
     * restarts from that checkpoint with {@code _epochs} raised to 10. The validation logloss
     * after the restart must not exceed the original model's.
     */
    @Test
    public void testCheckpointOverwriteWithBestModel() {
        Frame frame = null;
        DeepWaterModel model = null;
        DeepWaterModel model2 = null;
        Frame train = null;
        Frame valid = null;
        try {
            frame = parse_test_file("./smalldata/iris/iris.csv");
            // Split 0.8 / remainder into "train" and "valid" frames.
            FrameSplitter splitter = new FrameSplitter(frame, new double[]{0.8d}, new Key[]{Key.make("train"), Key.make("valid")}, (Key) null);
            splitter.compute2();
            train = splitter.getResult()[0];
            valid = splitter.getResult()[1];

            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            p._train = train._key;
            p._valid = valid._key;
            p._epochs = 1.0d;
            p._response_column = "C5";
            p._hidden = new int[]{50, 50};
            p._seed = 912559L;
            p._train_samples_per_iteration = 0L;
            p._score_duty_cycle = 1.0d;
            p._score_interval = 0.0d;
            p._stopping_rounds = 0;
            p._overwrite_with_best_model = true;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            double loglossBefore = model._output._validation_metrics.logloss();

            DeepWaterParameters p2 = p.clone();
            p2._epochs = 10.0d;
            p2._checkpoint = model._key;
            model2 = (DeepWaterModel) new DeepWater(p2).trainModel().get();
            double loglossAfter = model2._output._validation_metrics.logloss();
            Assert.assertTrue("Validation logloss must not increase after checkpoint restart: "
                    + loglossBefore + " -> " + loglossAfter, loglossAfter <= loglossBefore);
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (model2 != null) {
                model2.delete();
            }
            if (train != null) {
                train.delete();
            }
            if (valid != null) {
                valid.delete();
            }
        }
    }

    /**
     * Checks the same property as the one-epoch variant, starting from a longer initial run.
     *
     * <p>Trains iris for 10 epochs with {@code _overwrite_with_best_model} enabled, then
     * restarts from that checkpoint with {@code _epochs} raised to 20. The validation logloss
     * after the restart must not exceed the original model's.
     */
    @Test
    public void testCheckpointOverwriteWithBestModel2() {
        Frame frame = null;
        DeepWaterModel model = null;
        DeepWaterModel model2 = null;
        Frame train = null;
        Frame valid = null;
        try {
            frame = parse_test_file("./smalldata/iris/iris.csv");
            // Split 0.8 / remainder into "train" and "valid" frames.
            FrameSplitter splitter = new FrameSplitter(frame, new double[]{0.8d}, new Key[]{Key.make("train"), Key.make("valid")}, (Key) null);
            splitter.compute2();
            train = splitter.getResult()[0];
            valid = splitter.getResult()[1];

            DeepWaterParameters p = new DeepWaterParameters();
            p._backend = getBackend();
            p._train = train._key;
            p._valid = valid._key;
            p._epochs = 10.0d;
            p._response_column = "C5";
            p._hidden = new int[]{50, 50};
            p._seed = 912559L;
            p._train_samples_per_iteration = 0L;
            p._score_duty_cycle = 1.0d;
            p._score_interval = 0.0d;
            p._stopping_rounds = 0;
            p._overwrite_with_best_model = true;
            model = (DeepWaterModel) new DeepWater(p).trainModel().get();
            double loglossBefore = model._output._validation_metrics.logloss();

            DeepWaterParameters p2 = p.clone();
            p2._epochs = 20.0d;
            p2._checkpoint = model._key;
            model2 = (DeepWaterModel) new DeepWater(p2).trainModel().get();
            double loglossAfter = model2._output._validation_metrics.logloss();
            Assert.assertTrue("Validation logloss must not increase after checkpoint restart: "
                    + loglossBefore + " -> " + loglossAfter, loglossAfter <= loglossBefore);
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (model2 != null) {
                model2.delete();
            }
            if (train != null) {
                train.delete();
            }
            if (valid != null) {
                valid.delete();
            }
        }
    }
}
