/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.GenModel;
import hex.genmodel.algos.deeplearning.DeeplearningMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import java.io.IOException;
import java.util.HashSet;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.parser.ParseDataset;
import water.util.Log;

/**
 * Test for {@link DeepLearning} trained as an autoencoder on the ECG discord anomaly data.
 *
 * <p>For each tested {@code _sparsity_beta} value, two models are trained: one with the
 * default parameters built by {@link #buildParameters} and one with {@code _standardize = false}.
 * The test then checks:
 * <ul>
 *   <li>the reconstruction error reported by the model against the mean of
 *       {@code scoreAutoEncoder};</li>
 *   <li>the error against a manual recomputation from {@code score(train)};</li>
 *   <li>the error against MOJO-based scoring for both models;</li>
 *   <li>the shapes of the deep features extracted from both hidden layers;</li>
 *   <li>that exactly test rows 20, 21 and 22 exceed the outlier threshold.</li>
 * </ul>
 */
public class DeepLearningAutoEncoderTest
extends TestUtil {
    /** Training CSV for the autoencoder. */
    static final String PATH = "smalldata/anomaly/ecg_discord_train.csv";
    /** Test CSV, used as the validation frame and for outlier detection. */
    static final String PATH2 = "smalldata/anomaly/ecg_discord_test.csv";

    @BeforeClass
    public static void setup() {
        DeepLearningAutoEncoderTest.stall_till_cloudsize(1);
    }

    /**
     * Parses the train/test frames once and runs the full check for every sparsity value.
     * The parsed frames are always deleted, even when a check fails.
     */
    @Test
    public void run() {
        long seed = 912559L;
        Frame train = null;
        Frame test = null;
        try {
            NFSFileVec nfs = TestUtil.makeNfsFileVec(PATH);
            train = ParseDataset.parse(Key.make("train.hex"), new Key[]{nfs._key});
            NFSFileVec nfs2 = TestUtil.makeNfsFileVec(PATH2);
            test = ParseDataset.parse(Key.make("test.hex"), new Key[]{nfs2._key});
            for (float sparsityBeta : new float[]{0.0f, 0.1f}) {
                runForSparsity(train, test, seed, sparsityBeta);
            }
        } finally {
            if (train != null) {
                train.delete();
            }
            if (test != null) {
                test.delete();
            }
        }
    }

    /** Builds the autoencoder parameters shared by both models of one sparsity setting. */
    private static DeepLearningModel.DeepLearningParameters buildParameters(Frame train, Frame test, long seed, float sparsityBeta) {
        DeepLearningModel.DeepLearningParameters p = new DeepLearningModel.DeepLearningParameters();
        String[] names = train.names();
        p._train = train._key;
        p._valid = test._key;
        p._autoencoder = true;
        p._response_column = names[names.length - 1];
        p._seed = seed;
        p._hidden = new int[]{37, 12};
        p._adaptive_rate = true;
        p._train_samples_per_iteration = -1L;
        p._sparsity_beta = sparsityBeta;
        p._average_activation = -0.7;
        p._l1 = 1.0E-4;
        p._activation = DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout;
        p._loss = DeepLearningModel.DeepLearningParameters.Loss.Absolute;
        p._epochs = 13.3;
        p._force_load_balance = true;
        p._elastic_averaging = false;
        return p;
    }

    /**
     * Trains both models for one sparsity value and runs every check.
     * The models and scoring frames are always deleted, and the collected log is always
     * written, even when training or an assertion fails.
     */
    private static void runForSparsity(Frame train, Frame test, long seed, float sparsityBeta) {
        DeepLearningModel.DeepLearningParameters p = buildParameters(train, test, seed, sparsityBeta);
        DeepLearningModel model = null;
        DeepLearningModel modelNoStand = null;
        Frame l2FrameTrain = null;
        Frame l2FrameTest = null;
        StringBuilder sb = new StringBuilder();
        try {
            // Training happens inside the try so that a failure in the second training
            // still deletes the first model.
            model = (DeepLearningModel) new DeepLearning(p).trainModel().get();
            // The same parameter object is reused for the second model, with standardization disabled.
            p._standardize = false;
            modelNoStand = (DeepLearningModel) new DeepLearning(p).trainModel().get();

            sb.append("Verifying results.\n");
            double quantile = 0.95;
            l2FrameTest = model.scoreAutoEncoder(test, Key.make(), true);
            sb.append("Reconstruction error per feature (test): ").append(l2FrameTest.toString()).append("\n");
            l2FrameTest.remove();
            l2FrameTest = model.scoreAutoEncoder(test, Key.make(), false);
            Vec l2Test = l2FrameTest.anyVec();
            sb.append("Mean reconstruction error (test): ").append(l2Test.mean()).append("\n");
            Assert.assertEquals(l2Test.mean(), model.mse(), 1.0E-7);
            Assert.assertTrue("too big a reconstruction error: " + l2Test.mean(), l2Test.mean() < 2.0);
            l2Test.remove();

            // Score the training frame and recompute the mean reconstruction error by hand.
            Frame reconstr = model.score(train);
            double meanL2;
            try {
                Assert.assertTrue(model.testJavaScoring(train, reconstr, 1.0E-6));
                l2FrameTrain = model.scoreAutoEncoder(train, Key.make(), false);
                // Must run before reconstr.delete(): it reads reconstr's values AND its row count.
                meanL2 = manualMeanReconstructionError(reconstr, train, model.model_info().data_info()._normMul);
            } finally {
                reconstr.delete();
            }
            Vec l2Train = l2FrameTrain.anyVec();
            sb.append("Mean reconstruction error (train): ").append(l2Train.mean()).append("\n");
            Assert.assertEquals(((DeepLearningModel.DeepLearningModelOutput) model._output).errors.scored_train._mse, meanL2, 1.0E-7);

            sb.append("The following training points are reconstructed with an error above the ").append(quantile * 100.0).append("-th percentile - check for \"goodness\" of training data.\n");
            double threshTrain = model.calcOutlierThreshold(l2Train, quantile);
            for (long i = 0L; i < l2Train.length(); ++i) {
                if (l2Train.at(i) > threshTrain) {
                    sb.append(String.format("row %d : l2_train error = %5f\n", i, l2Train.at(i)));
                }
            }

            l2FrameTest.remove();
            l2FrameTest = model.scoreAutoEncoder(test, Key.make(), false);
            l2Test = l2FrameTest.anyVec();
            double mult = 10.0;
            // NOTE(review): the threshold is mult * calcOutlierThreshold(...), not mult * the mean
            // training error as the log line below says — confirm which is intended.
            double threshTest = mult * threshTrain;
            sb.append("\nFinding outliers.\n");
            sb.append("Mean reconstruction error (test): ").append(l2Test.mean()).append("\n");

            // One deep-feature column per unit of the requested hidden layer.
            assertDeepFeatures(model, test, 0, p._hidden[0]);
            assertDeepFeatures(model, test, 1, p._hidden[1]);

            sb.append("The following test points are reconstructed with an error greater than ").append(mult).append(" times the mean reconstruction error of the training data:\n");
            HashSet<Long> outliers = new HashSet<Long>();
            for (long i = 0L; i < l2Test.length(); ++i) {
                if (l2Test.at(i) > threshTest) {
                    outliers.add(i);
                    sb.append(String.format("row %d : l2 error = %5f\n", i, l2Test.at(i)));
                }
            }
            Assert.assertTrue(outliers.contains(20L));
            Assert.assertTrue(outliers.contains(21L));
            Assert.assertTrue(outliers.contains(22L));
            Assert.assertEquals(3, outliers.size());

            try {
                DeeplearningMojoModel mojoModel = (DeeplearningMojoModel) model.toMojo();
                EasyPredictModelWrapper mojo = new EasyPredictModelWrapper((GenModel) mojoModel);
                DeeplearningMojoModel mojoModelNoStand = (DeeplearningMojoModel) modelNoStand.toMojo();
                EasyPredictModelWrapper mojoNoStand = new EasyPredictModelWrapper((GenModel) mojoModelNoStand);

                double mojoMeanError = mojoMeanReconstructionError(mojo, train);
                sb.append("Mojo mean reconstruction error (train): ").append(mojoMeanError).append("\n");
                sb.append("Mean reconstruction error should be the same from model compare to mojo model reconstruction error: ");
                sb.append(meanL2).append(" == ").append(mojoMeanError).append("\n");
                Assert.assertEquals(meanL2, mojoMeanError, 1.0E-7);

                double mojoMeanErrorNoStand = mojoMeanReconstructionError(mojoNoStand, train);
                double noStandMse = ((DeepLearningModel.DeepLearningModelOutput) modelNoStand._output).errors.scored_train._mse;
                sb.append("Mojo mean reconstruction error (train): ").append(mojoMeanErrorNoStand).append("\n");
                sb.append("Mean reconstruction error should be the same from model compare to mojo model reconstruction error: ");
                sb.append(noStandMse).append(" == ").append(mojoMeanErrorNoStand).append("\n");
                Assert.assertEquals(noStandMse, mojoMeanErrorNoStand, 1.0E-7);
            } catch (IOException error) {
                Assert.fail("IOException when testing mojo mean reconstruction error: " + error.toString());
            } catch (PredictException error) {
                Assert.fail("PredictException when testing mojo mean reconstruction error: " + error.toString());
            }
        } finally {
            Log.info(sb);
            if (modelNoStand != null) {
                modelNoStand.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (l2FrameTrain != null) {
                l2FrameTrain.delete();
            }
            if (l2FrameTest != null) {
                l2FrameTest.delete();
            }
        }
    }

    /**
     * Computes the mean reconstruction error by hand.
     *
     * <p>For each row, the squared differences between {@code reconstr} and {@code train} are
     * scaled by {@code normMul[c]} and averaged over the columns. The per-row values are then
     * averaged over all rows of {@code reconstr}.
     *
     * @param reconstr frame produced by {@code model.score(train)}; must not be deleted yet
     * @param train    original training frame, column-aligned with {@code reconstr}
     * @param normMul  per-column multipliers taken from the model's DataInfo
     * @return the mean over rows of the per-row mean squared scaled error
     */
    private static double manualMeanReconstructionError(Frame reconstr, Frame train, double[] normMul) {
        long rows = reconstr.numRows();
        int cols = reconstr.numCols();
        double sum = 0.0;
        for (long r = 0L; r < rows; ++r) {
            double rowError = 0.0;
            for (int c = 0; c < cols; ++c) {
                rowError += Math.pow((reconstr.vec(c).at(r) - train.vec(c).at(r)) * normMul[c], 2.0);
            }
            sum += rowError / (double) cols;
        }
        return sum / (double) rows;
    }

    /**
     * Averages the MOJO-reported autoencoder {@code mse} over every row of {@code frame}.
     *
     * @throws PredictException if the wrapper rejects a row
     */
    private static double mojoMeanReconstructionError(EasyPredictModelWrapper wrapper, Frame frame) throws PredictException {
        BufferedString scratch = new BufferedString();
        long rows = frame.numRows();
        double sum = 0.0;
        for (long r = 0L; r < rows; ++r) {
            AutoEncoderModelPrediction prediction = wrapper.predictAutoEncoder(toRowData(frame, r, scratch));
            sum += prediction.mse;
        }
        return sum / (double) rows;
    }

    /**
     * Converts one frame row into a {@link RowData}, keyed by column name.
     * Categorical columns are passed as their string level and all other columns as doubles.
     */
    private static RowData toRowData(Frame frame, long row, BufferedString scratch) {
        RowData rowData = new RowData();
        String[] names = frame.names();
        for (int c = 0; c < frame.numCols(); ++c) {
            Vec vec = frame.vec(c);
            if (vec.isCategorical()) {
                // toString() copies the level out of the reused scratch buffer immediately.
                rowData.put(names[c], vec.atStr(scratch, row).toString());
            } else {
                rowData.put(names[c], vec.at(row));
            }
        }
        return rowData;
    }

    /**
     * Asserts that the deep features of {@code layer} have {@code expectedCols} columns and one
     * row per input row. The feature frame is always deleted.
     */
    private static void assertDeepFeatures(DeepLearningModel model, Frame frame, int layer, int expectedCols) {
        Frame features = model.scoreDeepFeatures(frame, layer);
        try {
            Assert.assertEquals(expectedCols, features.numCols());
            Assert.assertEquals(frame.numRows(), features.numRows());
        } finally {
            features.delete();
        }
    }
}

