/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.ConfusionMatrix;
import hex.ConfusionMatrixTest;
import hex.Model;
import hex.ModelMetrics;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Random;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.rapids.Rapids;
import water.util.Log;

public class DeepLearningProstateTest
extends TestUtil {
    /**
     * One-time class fixture: waits, via TestUtil's stall_till_cloudsize, for an
     * H2O cloud of size 1 before any test in this class runs.
     */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * JUnit entry point for the randomized parameter sweep.
     *
     * <p>Delegates to {@code runFraction} with a very small sampling fraction.
     * Each generated configuration is skipped unless a random draw falls at or
     * below this value, so only a tiny subset of the sweep actually trains models.
     */
    @Test
    public void run() throws Exception {
        final float sampledFraction = 2.0E-5f;
        runFraction(sampledFraction);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void runFraction(float fraction) {
        long seed = 233615359L;
        Random rng = new Random(seed);
        String[] datasets = new String[2];
        int[][] responses = new int[datasets.length][];
        datasets[0] = "smalldata/logreg/prostate.csv";
        responses[0] = new int[]{1, 2, 8};
        datasets[1] = "smalldata/iris/iris.csv";
        responses[1] = new int[]{4};
        LinkedHashSet<Long> checkSums = new LinkedHashSet<Long>();
        int testcount = 0;
        int count = 0;
        for (int i = 0; i < datasets.length; ++i) {
            String dataset = datasets[i];
            for (int resp : responses[i]) {
                Frame frame = null;
                Frame vframe = null;
                try {
                    NFSFileVec nfs = TestUtil.makeNfsFileVec((String)dataset);
                    frame = ParseDataset.parse((Key)Key.make(), (Key[])new Key[]{nfs._key});
                    NFSFileVec vnfs = TestUtil.makeNfsFileVec((String)dataset);
                    vframe = ParseDataset.parse((Key)Key.make(), (Key[])new Key[]{vnfs._key});
                    boolean classification = i != 0 || resp != 2;
                    String respname = frame.name(resp);
                    if (classification && !frame.vec(resp).isCategorical()) {
                        Vec r = frame.vec(resp).toCategoricalVec();
                        frame.remove(resp).remove();
                        frame.add(respname, r);
                        DKV.put((Keyed)frame);
                        Vec vr = vframe.vec(respname).toCategoricalVec();
                        vframe.remove(respname).remove();
                        vframe.add(respname, vr);
                        DKV.put((Keyed)vframe);
                    }
                    if (classification) {
                        assert (frame.vec(respname).isCategorical());
                        assert (vframe.vec(respname).isCategorical());
                    }
                    for (DeepLearningModel.DeepLearningParameters.Loss loss : new DeepLearningModel.DeepLearningParameters.Loss[]{DeepLearningModel.DeepLearningParameters.Loss.Automatic, DeepLearningModel.DeepLearningParameters.Loss.CrossEntropy, DeepLearningModel.DeepLearningParameters.Loss.Huber, DeepLearningModel.DeepLearningParameters.Loss.Absolute, DeepLearningModel.DeepLearningParameters.Loss.Quadratic}) {
                        if (!classification && (loss == DeepLearningModel.DeepLearningParameters.Loss.CrossEntropy || loss == DeepLearningModel.DeepLearningParameters.Loss.ModifiedHuber)) continue;
                        for (DistributionFamily dist : new DistributionFamily[]{DistributionFamily.AUTO, DistributionFamily.laplace, DistributionFamily.huber, DistributionFamily.bernoulli, DistributionFamily.gaussian, DistributionFamily.poisson, DistributionFamily.tweedie, DistributionFamily.gamma}) {
                            if (classification && dist != DistributionFamily.multinomial && dist != DistributionFamily.bernoulli && dist != DistributionFamily.modified_huber || !classification && (dist == DistributionFamily.multinomial || dist == DistributionFamily.bernoulli || dist == DistributionFamily.modified_huber)) continue;
                            boolean cont = false;
                            switch (dist) {
                                case tweedie: 
                                case gamma: 
                                case poisson: {
                                    if (loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic) break;
                                    cont = true;
                                    break;
                                }
                                case huber: {
                                    if (loss == DeepLearningModel.DeepLearningParameters.Loss.Huber || loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic) break;
                                    cont = true;
                                    break;
                                }
                                case laplace: {
                                    if (loss == DeepLearningModel.DeepLearningParameters.Loss.Absolute || loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic) break;
                                    cont = true;
                                    break;
                                }
                                case modified_huber: {
                                    if (loss == DeepLearningModel.DeepLearningParameters.Loss.ModifiedHuber || loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic) break;
                                    cont = true;
                                    break;
                                }
                                case bernoulli: {
                                    if (loss == DeepLearningModel.DeepLearningParameters.Loss.CrossEntropy || loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic) break;
                                    cont = true;
                                }
                            }
                            if (cont) continue;
                            for (boolean elastic_averaging : new boolean[]{true, false}) {
                                for (boolean replicate : new boolean[]{true, false}) {
                                    for (DeepLearningModel.DeepLearningParameters.Activation activation : new DeepLearningModel.DeepLearningParameters.Activation[]{DeepLearningModel.DeepLearningParameters.Activation.Tanh, DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout, DeepLearningModel.DeepLearningParameters.Activation.Rectifier, DeepLearningModel.DeepLearningParameters.Activation.RectifierWithDropout, DeepLearningModel.DeepLearningParameters.Activation.Maxout, DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout}) {
                                        boolean reproducible = false;
                                        switch (dist) {
                                            case tweedie: 
                                            case gamma: 
                                            case poisson: {
                                                reproducible = true;
                                            }
                                        }
                                        for (boolean load_balance : new boolean[]{true, false}) {
                                            for (boolean shuffle : new boolean[]{true, false}) {
                                                for (boolean balance_classes : new boolean[]{true, false}) {
                                                    for (DeepLearningModel.DeepLearningParameters.ClassSamplingMethod csm : new DeepLearningModel.DeepLearningParameters.ClassSamplingMethod[]{DeepLearningModel.DeepLearningParameters.ClassSamplingMethod.Stratified, DeepLearningModel.DeepLearningParameters.ClassSamplingMethod.Uniform}) {
                                                        for (int scoretraining : new int[]{200, 20, 0}) {
                                                            for (int scorevalidation : new int[]{200, 20, 0}) {
                                                                for (int vf : new int[]{0, 1, -1}) {
                                                                    for (int n_folds : new int[]{0, 2}) {
                                                                        if (n_folds > 0 && balance_classes) continue;
                                                                        for (boolean overwrite_with_best_model : new boolean[]{false, true}) {
                                                                            for (int train_samples_per_iteration : new int[]{-2, -1, 0, rng.nextInt(200), 500}) {
                                                                                DeepLearningModel model1 = null;
                                                                                DeepLearningModel model2 = null;
                                                                                ++count;
                                                                                if (fraction < rng.nextFloat()) continue;
                                                                                try {
                                                                                    Frame pred;
                                                                                    double[] dArray;
                                                                                    Log.info((Object[])new Object[]{"**************************)"});
                                                                                    Log.info((Object[])new Object[]{"Starting test #" + count});
                                                                                    Log.info((Object[])new Object[]{"**************************)"});
                                                                                    double epochs = 7.0 + rng.nextDouble() + (double)rng.nextInt(4);
                                                                                    int[] hidden = new int[]{3 + rng.nextInt(4), 3 + rng.nextInt(6)};
                                                                                    if (activation.name().contains("Hidden")) {
                                                                                        double[] dArray2 = new double[2];
                                                                                        dArray2[0] = rng.nextFloat();
                                                                                        dArray = dArray2;
                                                                                        dArray2[1] = rng.nextFloat();
                                                                                    } else {
                                                                                        dArray = null;
                                                                                    }
                                                                                    double[] hidden_dropout_ratios = dArray;
                                                                                    Frame valid = null;
                                                                                    if (vf == 1) {
                                                                                        valid = frame;
                                                                                    } else if (vf == -1) {
                                                                                        valid = vframe;
                                                                                    }
                                                                                    long myseed = rng.nextLong();
                                                                                    boolean replicate2 = rng.nextBoolean();
                                                                                    boolean elastic_averaging2 = rng.nextBoolean();
                                                                                    DeepLearningModel.DeepLearningParameters p = new DeepLearningModel.DeepLearningParameters();
                                                                                    Log.info((Object[])new Object[]{"Using seed: " + myseed});
                                                                                    p._train = frame._key;
                                                                                    p._response_column = respname;
                                                                                    p._valid = valid == null ? null : valid._key;
                                                                                    p._hidden = hidden;
                                                                                    p._input_dropout_ratio = 0.1;
                                                                                    p._hidden_dropout_ratios = hidden_dropout_ratios;
                                                                                    p._activation = activation;
                                                                                    p._overwrite_with_best_model = overwrite_with_best_model;
                                                                                    p._epochs = epochs;
                                                                                    p._loss = loss;
                                                                                    p._distribution = dist;
                                                                                    p._nfolds = n_folds;
                                                                                    p._seed = myseed;
                                                                                    p._train_samples_per_iteration = train_samples_per_iteration;
                                                                                    p._force_load_balance = load_balance;
                                                                                    p._replicate_training_data = replicate;
                                                                                    p._reproducible = reproducible;
                                                                                    p._shuffle_training_data = shuffle;
                                                                                    p._score_training_samples = scoretraining;
                                                                                    p._score_validation_samples = scorevalidation;
                                                                                    p._classification_stop = -1.0;
                                                                                    p._regression_stop = -1.0;
                                                                                    p._stopping_rounds = 0;
                                                                                    p._balance_classes = classification && balance_classes;
                                                                                    p._quiet_mode = true;
                                                                                    p._score_validation_sampling = csm;
                                                                                    p._elastic_averaging = elastic_averaging;
                                                                                    DeepLearning dl = new DeepLearning(p, Key.make((String)(Key.make().toString() + "first")));
                                                                                    try {
                                                                                        model1 = (DeepLearningModel)dl.trainModel().get();
                                                                                        checkSums.add(model1.checksum());
                                                                                        ++testcount;
                                                                                    }
                                                                                    catch (Throwable t) {
                                                                                        model1 = (DeepLearningModel)DKV.getGet((Key)dl.dest());
                                                                                        if (model1 != null) {
                                                                                            Assert.assertTrue((boolean)((DeepLearningModel.DeepLearningModelOutput)model1._output)._job.isCrashed());
                                                                                        }
                                                                                        throw t;
                                                                                    }
                                                                                    Log.info((Object[])new Object[]{"Trained for " + model1.epoch_counter + " epochs."});
                                                                                    assert ((p._train_samples_per_iteration <= 0L || p._train_samples_per_iteration >= frame.numRows()) && model1.epoch_counter > epochs || Math.abs(model1.epoch_counter - epochs) / epochs < 0.2);
                                                                                    if (p._train_samples_per_iteration == 0L) {
                                                                                        if (!replicate) {
                                                                                            assert ((Double)((DeepLearningModel.DeepLearningModelOutput)model1._output)._scoring_history.get(1, 3) == 1.0);
                                                                                        } else assert ((Double)((DeepLearningModel.DeepLearningModelOutput)model1._output)._scoring_history.get(1, 3) > 0.7 && (Double)((DeepLearningModel.DeepLearningModelOutput)model1._output)._scoring_history.get(1, 3) < 1.3) : "First scoring at " + ((DeepLearningModel.DeepLearningModelOutput)model1._output)._scoring_history.get(1, 3) + " epochs, should be closer to 1!\n" + model1.toString();
                                                                                    } else if (p._train_samples_per_iteration == -1L && (!replicate ? !$assertionsDisabled && (Double)((DeepLearningModel.DeepLearningModelOutput)model1._output)._scoring_history.get(1, 3) != 1.0 : !reproducible && !$assertionsDisabled && (Double)((DeepLearningModel.DeepLearningModelOutput)model1._output)._scoring_history.get(1, 3) != (double)H2O.CLOUD.size())) {
                                                                                        throw new AssertionError();
                                                                                    }
                                                                                    if (n_folds != 0 ? !$assertionsDisabled && ((DeepLearningModel.DeepLearningModelOutput)model1._output)._cross_validation_metrics == null : !$assertionsDisabled && ((DeepLearningModel.DeepLearningModelOutput)model1._output)._cross_validation_metrics != null) {
                                                                                        throw new AssertionError();
                                                                                    }
                                                                                    assert (model1.model_info().get_params()._l1 == 0.0);
                                                                                    assert (model1.model_info().get_params()._l2 == 0.0);
                                                                                    Assert.assertFalse((boolean)((DeepLearningModel.DeepLearningModelOutput)model1._output)._job.isCrashed());
                                                                                    if (n_folds != 0) continue;
                                                                                    DeepLearningModel.DeepLearningParameters p2 = new DeepLearningModel.DeepLearningParameters();
                                                                                    Assert.assertTrue(((double)model1.model_info().get_processed_total() >= (double)frame.numRows() * epochs ? 1 : 0) != 0);
                                                                                    p2._checkpoint = model1._key;
                                                                                    p2._distribution = dist;
                                                                                    p2._loss = loss;
                                                                                    p2._nfolds = n_folds;
                                                                                    p2._train = frame._key;
                                                                                    p2._activation = activation;
                                                                                    p2._hidden = hidden;
                                                                                    p2._valid = valid == null ? null : valid._key;
                                                                                    p2._l1 = 0.001;
                                                                                    p2._l2 = 0.001;
                                                                                    p2._reproducible = reproducible;
                                                                                    p2._response_column = respname;
                                                                                    p2._overwrite_with_best_model = overwrite_with_best_model;
                                                                                    p2._quiet_mode = true;
                                                                                    p2._epochs = 2.0 * epochs;
                                                                                    p2._replicate_training_data = replicate2;
                                                                                    p2._stopping_rounds = 0;
                                                                                    p2._seed = myseed;
                                                                                    p2._train_samples_per_iteration = train_samples_per_iteration;
                                                                                    p2._balance_classes = classification && balance_classes;
                                                                                    p2._elastic_averaging = elastic_averaging2;
                                                                                    DeepLearning dl2 = new DeepLearning(p2);
                                                                                    try {
                                                                                        model2 = (DeepLearningModel)dl2.trainModel().get();
                                                                                    }
                                                                                    catch (Throwable t) {
                                                                                        model2 = (DeepLearningModel)DKV.getGet((Key)dl2.dest());
                                                                                        if (model2 != null) {
                                                                                            Assert.assertTrue((boolean)((DeepLearningModel.DeepLearningModelOutput)model2._output)._job.isCrashed());
                                                                                        }
                                                                                        throw t;
                                                                                    }
                                                                                    Assert.assertTrue((boolean)((DeepLearningModel.DeepLearningModelOutput)model1._output)._job.isDone());
                                                                                    Assert.assertTrue((boolean)((DeepLearningModel.DeepLearningModelOutput)model2._output)._job.isDone());
                                                                                    assert (model1._parms != p2);
                                                                                    assert (model1.model_info().get_params() != model2.model_info().get_params());
                                                                                    assert (model1.model_info().get_params()._l1 == 0.0);
                                                                                    assert (model1.model_info().get_params()._l2 == 0.0);
                                                                                    if (!overwrite_with_best_model) {
                                                                                        Assert.assertTrue(((double)model2.model_info().get_processed_total() >= (double)(frame.numRows() * 2L) * epochs ? 1 : 0) != 0);
                                                                                    }
                                                                                    assert (p != p2);
                                                                                    assert (p != model1.model_info().get_params());
                                                                                    assert (p2 != model2.model_info().get_params());
                                                                                    if (p._loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic) assert (p2._loss == DeepLearningModel.DeepLearningParameters.Loss.Automatic);
                                                                                    assert (p._hidden_dropout_ratios == null);
                                                                                    assert (p2._hidden_dropout_ratios == null);
                                                                                    if (p._activation.toString().contains("WithDropout")) {
                                                                                        assert (model1.model_info().get_params()._hidden_dropout_ratios != null);
                                                                                        assert (model2.model_info().get_params()._hidden_dropout_ratios != null);
                                                                                        assert (Arrays.equals(model1.model_info().get_params()._hidden_dropout_ratios, model2.model_info().get_params()._hidden_dropout_ratios));
                                                                                    }
                                                                                    assert (p._l1 == 0.0);
                                                                                    assert (p._l2 == 0.0);
                                                                                    assert (p2._l1 == 0.001);
                                                                                    assert (p2._l2 == 0.001);
                                                                                    assert (model1.model_info().get_params()._l1 == 0.0);
                                                                                    assert (model1.model_info().get_params()._l2 == 0.0);
                                                                                    assert (model2.model_info().get_params()._l1 == 0.001);
                                                                                    assert (model2.model_info().get_params()._l2 == 0.001);
                                                                                    if (valid == null) {
                                                                                        valid = frame;
                                                                                    }
                                                                                    if (((DeepLearningModel.DeepLearningModelOutput)model2._output).isClassifier()) {
                                                                                        pred = null;
                                                                                        try {
                                                                                            pred = model2.score(valid);
                                                                                            DKV.put((Key)Key.make((String)"pred"), (Iced)pred);
                                                                                            if (!model2.testJavaScoring(valid, pred, 1.0E-6)) {
                                                                                                model2.testJavaScoring(valid, pred, 1.0E-6);
                                                                                            }
                                                                                            Assert.assertTrue((boolean)model2.testJavaScoring(valid, pred, 1.0E-6));
                                                                                            ModelMetrics mm = ModelMetrics.getFromDKV((Model)model2, (Frame)valid);
                                                                                            if (((DeepLearningModel.DeepLearningModelOutput)model2._output).nclasses() == 2) {
                                                                                                assert (resp == 1);
                                                                                                double threshold = mm.auc_obj().defaultThreshold();
                                                                                                double error = mm.auc_obj().defaultErr();
                                                                                                Assert.assertEquals((double)new ConfusionMatrix(mm.auc_obj().defaultCM(), valid.vec(respname).domain()).err(), (double)error, (double)1.0E-15);
                                                                                                Assert.assertEquals((double)mm.cm().err(), (double)error, (double)1.0E-15);
                                                                                                Vec labels = valid.vec(respname);
                                                                                                Vec predlabels = pred.vecs()[0];
                                                                                                ConfusionMatrix cm = ConfusionMatrixTest.buildCM((Vec)labels, (Vec)predlabels);
                                                                                                Log.info((Object[])new Object[]{"CM from pre-made labels:"});
                                                                                                Log.info((Object[])new Object[]{cm.toASCII()});
                                                                                                if (Math.abs(cm.err() - error) > 0.02) {
                                                                                                    ConfusionMatrix cm2 = ConfusionMatrixTest.buildCM((Vec)labels, (Vec)predlabels);
                                                                                                    Log.info((Object[])new Object[]{cm2.toASCII()});
                                                                                                }
                                                                                                Assert.assertEquals((double)cm.err(), (double)error, (double)0.02);
                                                                                                String ast = "(as.factor (> (cols pred [2]) " + threshold + "))";
                                                                                                Frame tmp = Rapids.exec((String)ast).getFrame();
                                                                                                Vec pred2labels = tmp.vecs()[0];
                                                                                                cm = ConfusionMatrixTest.buildCM((Vec)labels, (Vec)pred2labels);
                                                                                                Log.info((Object[])new Object[]{"CM from self-made labels:"});
                                                                                                Log.info((Object[])new Object[]{cm.toASCII()});
                                                                                                Assert.assertEquals((double)cm.err(), (double)error, (double)0.02);
                                                                                                tmp.delete();
                                                                                            }
                                                                                            DKV.remove((Key)Key.make((String)"pred"));
                                                                                        }
                                                                                        finally {
                                                                                            if (pred != null) {
                                                                                                pred.delete();
                                                                                            }
                                                                                        }
                                                                                    }
                                                                                    pred = null;
                                                                                    try {
                                                                                        pred = model2.score(valid);
                                                                                        Assert.assertTrue((boolean)model2.testJavaScoring(frame, pred, 1.0E-6));
                                                                                    }
                                                                                    finally {
                                                                                        if (pred != null) {
                                                                                            pred.delete();
                                                                                        }
                                                                                    }
                                                                                    Log.info((Object[])new Object[]{"Parameters combination " + count + ": PASS"});
                                                                                }
                                                                                catch (IllegalArgumentException | H2OModelBuilderIllegalArgumentException ex) {
                                                                                    System.err.println(ex);
                                                                                    throw H2O.fail((String)"should not get here");
                                                                                }
                                                                                catch (RuntimeException t) {
                                                                                    String msg = "" + t.getMessage() + (t.getCause() == null ? "" : t.getCause().getMessage());
                                                                                    Assert.assertTrue((String)("Unexpected exception " + t + ": " + msg), (boolean)msg.contains("unstable"));
                                                                                }
                                                                                catch (AssertionError ae) {
                                                                                    throw ae;
                                                                                }
                                                                                catch (Throwable t) {
                                                                                    t.printStackTrace();
                                                                                    throw new RuntimeException(t);
                                                                                }
                                                                                finally {
                                                                                    if (model1 != null) {
                                                                                        model1.deleteCrossValidationModels();
                                                                                        model1.delete();
                                                                                    }
                                                                                    if (model2 != null) {
                                                                                        model2.deleteCrossValidationModels();
                                                                                        model2.delete();
                                                                                    }
                                                                                }
                                                                            }
                                                                        }
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                finally {
                    if (frame != null) {
                        frame.delete();
                    }
                    if (vframe != null) {
                        vframe.delete();
                    }
                }
            }
        }
        Log.info((Object[])new Object[]{"\n\n============================================="});
        Log.info((Object[])new Object[]{"Tested " + testcount + " out of " + count + " parameter combinations."});
        Log.info((Object[])new Object[]{"============================================="});
        if (checkSums.size() != testcount) {
            Log.info((Object[])new Object[]{"Only found " + checkSums.size() + " unique checksums."});
        }
        Assert.assertTrue((checkSums.size() == testcount ? 1 : 0) != 0);
    }

    /**
     * Variant of the enclosing test whose {@code run()} calls {@code runFraction} with
     * {@code 1e-3f} instead of the {@code 2e-5f} used by the base {@code run()}.
     * The override is annotated {@code @Ignore}, so JUnit skips it by default.
     * NOTE(review): the fraction presumably controls how much of the parameter
     * grid gets exercised; confirm against {@code runFraction}.
     */
    public static class Short extends DeepLearningProstateTest {
        /** Delegates to the inherited sweep with fraction {@code 1e-3f}. */
        @Test
        @Ignore
        @Override
        public void run() throws Exception {
            runFraction(1.0e-3f);
        }
    }

    /**
     * Variant of the enclosing test whose {@code run()} calls {@code runFraction} with
     * {@code 1e-2f} instead of the {@code 2e-5f} used by the base {@code run()}.
     * The override is annotated {@code @Ignore}, so JUnit skips it by default.
     * NOTE(review): the fraction presumably controls how much of the parameter
     * grid gets exercised; confirm against {@code runFraction}.
     */
    public static class Mid extends DeepLearningProstateTest {
        /** Delegates to the inherited sweep with fraction {@code 1e-2f}. */
        @Test
        @Ignore
        @Override
        public void run() throws Exception {
            runFraction(1.0e-2f);
        }
    }
}

