/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.FrameSplitter;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters.MissingValuesHandling;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.Map;
import java.util.TreeMap;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.util.FrameUtils;
import water.util.Log;

/**
 * Trains Deep Learning models on smalldata/junit/weather.csv after injecting increasing
 * fractions of missing values into the training predictors, once per missing-value
 * handling strategy, and compares each strategy's summed loss against reference values.
 */
public class DeepLearningMissingTest extends TestUtil {
    /** Seed shared by the missing-value inserter and the model so runs are reproducible. */
    private static final long SEED = 1234L;

    /** Fractions of missing values injected into the training predictors, in increasing order. */
    private static final double[] MISSING_FRACTIONS = {0.0, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99};

    /** Loss recorded for a (strategy, fraction) configuration whose pipeline throws. */
    private static final double FAILURE_LOSS = 100.0;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Runs every missing fraction for MeanImputation and Skip, logs a per-strategy
     * report, and asserts each strategy's summed loss against recorded values.
     */
    @Test
    public void run() {
        Log.info("");
        Log.info("STARTING.");
        Log.info("Using seed " + SEED);
        Map<MissingValuesHandling, Double> sumErr =
                new EnumMap<MissingValuesHandling, Double>(MissingValuesHandling.class);
        StringBuilder sb = new StringBuilder();
        MissingValuesHandling[] strategies = {MissingValuesHandling.MeanImputation, MissingValuesHandling.Skip};
        for (MissingValuesHandling mvh : strategies) {
            // TreeMap keeps the report ordered by missing fraction.
            Map<Double, Double> lossByFraction = new TreeMap<Double, Double>();
            double sumloss = 0.0;
            for (double missingFraction : MISSING_FRACTIONS) {
                double loss = trainAndScore(mvh, missingFraction);
                lossByFraction.put(missingFraction, loss);
                sumloss += loss;
            }
            appendReport(sb, mvh, lossByFraction, sumloss);
            sumErr.put(mvh, sumloss);
        }
        Log.info(sb.toString());
        // NOTE(review): the Skip reference (~405.5) appears to include several FAILURE_LOSS
        // penalties from high missing fractions; confirm before changing FAILURE_LOSS.
        double skipSum = sumErr.get(MissingValuesHandling.Skip);
        double meanImputationSum = sumErr.get(MissingValuesHandling.MeanImputation);
        Assert.assertEquals(405.5017, skipSum, 0.01);
        Assert.assertEquals(3.914915, meanImputationSum, 0.001);
    }

    /**
     * Runs the full pipeline for one configuration. It parses the data and splits it
     * 75/25 into train/validation frames. It optionally injects missing values into the
     * training predictors, then trains a Deep Learning classifier.
     *
     * @param mvh             missing-value handling strategy given to the model
     * @param missingFraction fraction passed to the missing-value inserter; 0 disables injection
     * @return the trained model's {@code loss()}, or {@link #FAILURE_LOSS} if any step throws
     */
    private static double trainAndScore(MissingValuesHandling mvh, double missingFraction) {
        // Declared per configuration so the finally block only deletes objects created by
        // this run, never stale (already deleted) ones from a previous configuration.
        DeepLearningModel model = null;
        Frame data = null;
        Frame train = null;
        Frame test = null;
        double loss;
        try {
            Scope.enter();
            NFSFileVec nfs = NFSFileVec.make("smalldata/junit/weather.csv");
            data = ParseDataset.parse(Key.make("data.hex"), nfs._key);
            Log.info("FrameSplitting");
            FrameSplitter fs = new FrameSplitter(data, new double[]{0.75}, FrameUtils.generateNumKeys(data._key, 2), null);
            H2O.submitTask(fs);
            Frame[] trainTest = fs.getResult();
            train = trainTest[0];
            test = trainTest[1];
            Log.info("Done...");
            if (missingFraction > 0.0) {
                insertMissingPredictorValues(train, missingFraction);
            }
            DeepLearningParameters p = new DeepLearningParameters();
            p._train = train._key;
            p._valid = test._key;
            // The last column is the response.
            p._response_column = train._names[train.numCols() - 1];
            // NOTE(review): columns 1 and 22 are excluded by position; presumably an identifier
            // and/or a response-leaking column. Verify against the weather.csv header.
            p._ignored_columns = new String[]{train._names[1], train._names[22]};
            p._missing_values_handling = mvh;
            p._loss = DeepLearningParameters.Loss.CrossEntropy;
            p._activation = DeepLearningParameters.Activation.Rectifier;
            p._hidden = new int[]{50, 50};
            p._l1 = 1.0E-5;
            p._input_dropout_ratio = 0.2;
            p._epochs = 3.0;
            p._reproducible = true;
            p._seed = SEED;
            p._elastic_averaging = false;
            // Convert the response to categorical in both frames (CrossEntropy classification);
            // the replaced vecs are tracked so Scope.exit() frees them.
            int ri = train.numCols() - 1;
            int ci = test.find(p._response_column);
            Scope.track(train.replace(ri, train.vecs()[ri].toCategoricalVec()));
            Scope.track(test.replace(ci, test.vecs()[ci].toCategoricalVec()));
            DKV.put(train);
            DKV.put(test);
            DeepLearning dl = new DeepLearning(p);
            Log.info("Starting with " + missingFraction * 100.0 + "% missing values added.");
            model = (DeepLearningModel) dl.trainModel().get();
            loss = model.loss();
            Log.info("Missing " + missingFraction * 100.0 + "% -> logloss: " + loss);
        } catch (Throwable t) {
            // Deliberate best-effort: a failing configuration is penalized rather than
            // aborting the test, and the penalty is part of the asserted sums.
            Log.err(t);
            loss = FAILURE_LOSS;
        } finally {
            Scope.exit();
            if (model != null) {
                model.delete();
            }
            if (train != null) {
                train.delete();
            }
            if (test != null) {
                test.delete();
            }
            if (data != null) {
                data.delete();
            }
        }
        return loss;
    }

    /**
     * Runs {@code FrameUtils.MissingInserter} with {@code fraction} over a temporary frame
     * that shares all of {@code train}'s vecs except the last (response) column. The
     * response therefore never receives injected missing values. The temporary frame
     * header is removed from DKV even if insertion fails.
     */
    private static void insertMissingPredictorValues(Frame train, double fraction) {
        Frame predictors = new Frame(Key.make(), train.names(), train.vecs());
        predictors.remove(predictors.numCols() - 1);
        DKV.put(predictors._key, predictors);
        try {
            FrameUtils.MissingInserter inserter = new FrameUtils.MissingInserter(predictors._key, SEED, fraction);
            inserter.execImpl().get();
        } finally {
            DKV.remove(predictors._key);
        }
    }

    /** Appends a human-readable "fraction --> loss" table plus the summed loss for one strategy. */
    private static void appendReport(StringBuilder sb, MissingValuesHandling mvh,
                                     Map<Double, Double> lossByFraction, double sumloss) {
        sb.append("\nMethod: ").append(mvh.toString()).append("\n");
        sb.append("missing fraction --> loss\n");
        for (Map.Entry<Double, Double> entry : lossByFraction.entrySet()) {
            sb.append(entry.getKey()).append(" --> ").append(entry.getValue()).append("\n");
        }
        sb.append('\n');
        sb.append("sum loss: ").append(sumloss).append("\n");
    }
}

