package hex.deeplearning;

import hex.deeplearning.DeepLearningModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.parser.ParseDataset;
import water.util.Log;

/* loaded from: input_file:hex/deeplearning/DeepLearningAutoEncoderTest.class */
/**
 * End-to-end test of Deep Learning in autoencoder mode on the ECG discord data set.
 *
 * <p>For each sparsity setting, two autoencoders are trained on the same parameters, the second
 * with {@code _standardize = false}. The test then checks that:
 * <ul>
 *   <li>the mean per-row test reconstruction error from {@code scoreAutoEncoder} matches
 *       {@code model.mse()} and is below 2.0;</li>
 *   <li>{@code testJavaScoring} passes on the training frame;</li>
 *   <li>the training MSE recomputed by hand from {@code score()} matches the reported training
 *       MSE;</li>
 *   <li>the deep features of both hidden layers have the expected shape;</li>
 *   <li>exactly test rows 20, 21 and 22 are flagged as outliers;</li>
 *   <li>MOJO scoring reproduces the in-H2O mean training reconstruction error for both models.</li>
 * </ul>
 */
public class DeepLearningAutoEncoderTest extends TestUtil {
    static final String PATH = "smalldata/anomaly/ecg_discord_train.csv";
    static final String PATH2 = "smalldata/anomaly/ecg_discord_test.csv";

    /** Hidden layer sizes; the deep-feature shape checks are derived from these. */
    private static final int[] HIDDEN = {37, 12};
    /** Quantile of the per-row training reconstruction error used as the base outlier threshold. */
    private static final double QUANTILE = 0.95;
    /** Test rows whose error exceeds this multiple of the training threshold count as outliers. */
    private static final double MULTIPLIER = 10.0;
    /** Tolerance for comparing mean reconstruction errors computed in different ways. */
    private static final double TOLERANCE = 1e-7;
    /** Tolerance passed to {@code testJavaScoring}. */
    private static final double JAVA_SCORING_TOLERANCE = 1e-6;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Parses the train/test files and runs {@link #checkAutoEncoder} once per sparsity setting.
     * Both parsed frames are deleted afterwards, whether or not a check fails.
     */
    @Test
    public void run() {
        Frame train = null;
        Frame test = null;
        try {
            train = ParseDataset.parse(Key.make("train.hex"), new Key[]{TestUtil.makeNfsFileVec(PATH)._key});
            test = ParseDataset.parse(Key.make("test.hex"), new Key[]{TestUtil.makeNfsFileVec(PATH2)._key});
            for (float sparsityBeta : new float[]{0.0f, 0.1f}) {
                checkAutoEncoder(train, test, sparsityBeta);
            }
        } finally {
            if (train != null) {
                train.delete();
            }
            if (test != null) {
                test.delete();
            }
        }
    }

    /**
     * Trains a standardized and an unstandardized autoencoder with the given sparsity and runs all
     * checks on them. Every model and frame created here is deleted in {@code finally}, and the
     * accumulated report is logged even when an assertion fails.
     */
    private void checkAutoEncoder(Frame train, Frame test, float sparsityBeta) {
        DeepLearningModel model = null;
        DeepLearningModel unstandardizedModel = null;
        Frame l2FrameTrain = null;
        Frame l2FrameTest = null;
        StringBuilder sb = new StringBuilder();
        try {
            DeepLearningModel.DeepLearningParameters p = buildParameters(train, test, sparsityBeta);
            model = new DeepLearning(p).trainModel().get();
            // NOTE(review): this mutates the same parameters object that trained `model`;
            // confirm neither the builder nor the trained model relies on it afterwards.
            p._standardize = false;
            unstandardizedModel = new DeepLearning(p).trainModel().get();

            sb.append("Verifying results.\n");
            Frame perFeature = model.scoreAutoEncoder(test, Key.make(), true);
            try {
                sb.append("Reconstruction error per feature (test): ").append(perFeature.toString()).append("\n");
            } finally {
                perFeature.remove();
            }

            // Per-row test reconstruction error; its mean must agree with model.mse()
            // (presumably the validation MSE, since _valid is the test frame -- TODO confirm).
            l2FrameTest = model.scoreAutoEncoder(test, Key.make(), false);
            Vec l2Test = l2FrameTest.anyVec();
            sb.append("Mean reconstruction error (test): ").append(l2Test.mean()).append("\n");
            Assert.assertEquals(l2Test.mean(), model.mse(), TOLERANCE);
            Assert.assertTrue("too big a reconstruction error: " + l2Test.mean(), l2Test.mean() < 2.0);
            l2Test.remove();

            // Full reconstruction of the training frame: check Java scoring, then recompute the
            // training MSE by hand from it.
            Frame reconstruction = model.score(train);
            double manualTrainMse;
            try {
                Assert.assertTrue(model.testJavaScoring(train, reconstruction, JAVA_SCORING_TOLERANCE));
                l2FrameTrain = model.scoreAutoEncoder(train, Key.make(), false);
                manualTrainMse = meanReconstructionError(model, train, reconstruction);
            } finally {
                reconstruction.delete();
            }
            Vec l2Train = l2FrameTrain.anyVec();
            sb.append("Mean reconstruction error (train): ").append(l2Train.mean()).append("\n");
            Assert.assertEquals(model._output.errors.scored_train._mse, manualTrainMse, TOLERANCE);

            sb.append("The following training points are reconstructed with an error above the ").append(QUANTILE * 100.0).append("-th percentile - check for \"goodness\" of training data.\n");
            double threshTrain = model.calcOutlierThreshold(l2Train, QUANTILE);
            appendRowsAbove(sb, l2Train, threshTrain, "row %d : l2_train error = %5f\n");

            // The first test error vec was removed above; re-score to get a fresh one.
            l2FrameTest.remove();
            l2FrameTest = null; // already removed -- keep the finally block from deleting it again
            l2FrameTest = model.scoreAutoEncoder(test, Key.make(), false);
            Vec l2TestAgain = l2FrameTest.anyVec();
            double threshTest = MULTIPLIER * threshTrain;
            sb.append("\nFinding outliers.\n");
            sb.append("Mean reconstruction error (test): ").append(l2TestAgain.mean()).append("\n");

            checkDeepFeatures(model, test, 0, HIDDEN[0]);
            checkDeepFeatures(model, test, 1, HIDDEN[1]);

            sb.append("The following test points are reconstructed with an error greater than ").append(MULTIPLIER).append(" times the ").append(QUANTILE * 100.0).append("-th percentile reconstruction error of the training data:\n");
            Set<Long> outliers = appendRowsAbove(sb, l2TestAgain, threshTest, "row %d : l2 error = %5f\n");
            Assert.assertEquals("unexpected outlier rows", new HashSet<Long>(Arrays.asList(20L, 21L, 22L)), outliers);

            try {
                double mojoMse = mojoMeanReconstructionError(model, train);
                sb.append("Mojo mean reconstruction error (train): ").append(mojoMse).append("\n");
                sb.append("Mean reconstruction error should be the same from model compare to mojo model reconstruction error: ");
                sb.append(manualTrainMse).append(" == ").append(mojoMse).append("\n");
                Assert.assertEquals(manualTrainMse, mojoMse, TOLERANCE);

                double mojoMse2 = mojoMeanReconstructionError(unstandardizedModel, train);
                double trainMse2 = unstandardizedModel._output.errors.scored_train._mse;
                sb.append("Mojo mean reconstruction error (train, no standardization): ").append(mojoMse2).append("\n");
                sb.append("Mean reconstruction error should be the same from model compare to mojo model reconstruction error: ");
                sb.append(trainMse2).append(" == ").append(mojoMse2).append("\n");
                Assert.assertEquals(trainMse2, mojoMse2, TOLERANCE);
            } catch (IOException | PredictException e) {
                // Keep the cause so the original stack trace is visible in the test report.
                throw new AssertionError(e.getClass().getSimpleName() + " when testing mojo mean reconstruction error: " + e, e);
            }
        } finally {
            Log.info(sb);
            if (unstandardizedModel != null) {
                unstandardizedModel.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (l2FrameTrain != null) {
                l2FrameTrain.delete();
            }
            if (l2FrameTest != null) {
                l2FrameTest.delete();
            }
        }
    }

    /**
     * Builds the autoencoder parameters shared by both models of one sparsity setting.
     * The last column of {@code train} is named as the response column.
     */
    private static DeepLearningModel.DeepLearningParameters buildParameters(Frame train, Frame valid, float sparsityBeta) {
        DeepLearningModel.DeepLearningParameters p = new DeepLearningModel.DeepLearningParameters();
        p._train = train._key;
        p._valid = valid._key;
        p._autoencoder = true;
        p._response_column = train.names()[train.names().length - 1];
        p._seed = 912559L;
        p._hidden = HIDDEN.clone(); // copy so the shared constant cannot be mutated through the params
        p._adaptive_rate = true;
        p._train_samples_per_iteration = -1L;
        p._sparsity_beta = sparsityBeta;
        p._average_activation = -0.7;
        p._l1 = 1e-4;
        p._activation = DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout;
        p._loss = DeepLearningModel.DeepLearningParameters.Loss.Absolute;
        p._epochs = 13.3;
        p._force_load_balance = true;
        p._elastic_averaging = false;
        return p;
    }

    /**
     * Recomputes the mean reconstruction error by hand. For each row, it takes the mean over
     * columns of the squared difference between the reconstruction and the input. Each difference
     * is first scaled by the model's {@code _normMul} for that column. The per-row values are then
     * averaged over all rows.
     */
    private static double meanReconstructionError(DeepLearningModel model, Frame input, Frame reconstruction) {
        double[] normMul = model.model_info().data_info()._normMul;
        int cols = reconstruction.numCols();
        long rows = reconstruction.numRows();
        double total = 0.0;
        for (long row = 0; row < rows; row++) {
            double rowSum = 0.0;
            for (int col = 0; col < cols; col++) {
                double diff = (reconstruction.vec(col).at(row) - input.vec(col).at(row)) * normMul[col];
                rowSum += diff * diff;
            }
            total += rowSum / cols;
        }
        return total / rows;
    }

    /**
     * Asserts that the deep features of hidden layer {@code layer} have {@code expectedCols}
     * columns and one row per input row. The feature frame is always deleted.
     */
    private static void checkDeepFeatures(DeepLearningModel model, Frame frame, int layer, int expectedCols) {
        Frame features = model.scoreDeepFeatures(frame, layer);
        try {
            Assert.assertEquals(expectedCols, features.numCols());
            Assert.assertEquals(frame.numRows(), features.numRows());
        } finally {
            features.delete();
        }
    }

    /**
     * Appends one {@code format}ted line (row index, error) to {@code sb} for every row of
     * {@code errors} whose value is strictly greater than {@code threshold}.
     *
     * @return the indices of those rows
     */
    private static Set<Long> appendRowsAbove(StringBuilder sb, Vec errors, double threshold, String format) {
        Set<Long> rows = new HashSet<>();
        for (long row = 0; row < errors.length(); row++) {
            double err = errors.at(row);
            if (err > threshold) {
                rows.add(row);
                sb.append(String.format(format, row, err));
            }
        }
        return rows;
    }

    /**
     * Returns the mean, over all rows of {@code frame}, of the MOJO-computed autoencoder
     * reconstruction MSE.
     */
    private static double mojoMeanReconstructionError(DeepLearningModel model, Frame frame) throws IOException, PredictException {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(model.toMojo());
        double sum = 0.0;
        for (long row = 0; row < frame.numRows(); row++) {
            sum += wrapper.predictAutoEncoder(toRowData(frame, row)).mse;
        }
        return sum / frame.numRows();
    }

    /**
     * Converts one frame row into a {@link RowData}. Categorical columns become their level
     * strings; all other columns become {@code Double}s.
     */
    private static RowData toRowData(Frame frame, long row) {
        RowData rowData = new RowData();
        BufferedString tmp = new BufferedString();
        String[] names = frame.names();
        for (int col = 0; col < frame.numCols(); col++) {
            Vec vec = frame.vec(col);
            if (vec.isCategorical()) {
                rowData.put(names[col], vec.atStr(tmp, row).toString());
            } else {
                rowData.put(names[col], vec.at(row));
            }
        }
        return rowData;
    }
}
