package hex.deeplearning;

import hex.deeplearning.DeepLearningModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.util.Log;

/* loaded from: input_file:hex/deeplearning/DeepLearningAutoEncoderCategoricalTest.class */
/**
 * Trains a Deep Learning auto-encoder on the data set at {@link #PATH} and checks that
 * the model's self-reported reconstruction error agrees with the error obtained by
 * re-scoring the training frame, that the generated Java scoring code agrees with the
 * in-cluster scoring, and that the deep features extracted from each hidden layer have
 * the expected shape.
 */
public class DeepLearningAutoEncoderCategoricalTest extends TestUtil {
    static final String PATH = "smalldata/airlines/AirlinesTrain.csv.zip";

    /** Hidden layer sizes; also the expected deep-feature column count per layer. */
    private static final int[] HIDDEN = {10, 5, 3};

    /** Calls {@code stall_till_cloudsize(1)} once before any test in this class runs. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Runs the end-to-end auto-encoder check.
     *
     * <p>Every frame and the model created here are released in {@code finally} blocks so
     * that a failing assertion does not leave them behind.
     */
    @Test
    public void run() {
        Frame train = null;
        DeepLearningModel model = null;
        Frame l2 = null;
        Frame reconstruction = null;
        try {
            train = ParseDataset.parse(Key.make("train.hex"), new Key[]{TestUtil.makeNfsFileVec(PATH)._key});

            DeepLearningModel.DeepLearningParameters params = new DeepLearningModel.DeepLearningParameters();
            params._train = train._key;
            params._autoencoder = true;
            // The last column of the parsed frame is used as the response column.
            String[] columnNames = train.names();
            params._response_column = columnNames[columnNames.length - 1];
            params._seed = 912559L;
            // Clone so the shared constant cannot be modified through the parameters object.
            params._hidden = HIDDEN.clone();
            params._adaptive_rate = true;
            params._l1 = 1.0E-4d;
            params._activation = DeepLearningModel.DeepLearningParameters.Activation.Tanh;
            params._max_w2 = 10.0f;
            params._train_samples_per_iteration = -1L;
            params._loss = DeepLearningModel.DeepLearningParameters.Loss.Huber;
            params._epochs = 0.2d;
            params._force_load_balance = true;
            params._score_training_samples = 0L;
            params._score_validation_samples = 0L;
            params._reproducible = true;

            model = new DeepLearning(params).trainModel().get();

            StringBuilder report = new StringBuilder();
            report.append("Verifying results.\n");
            report.append("Reported mean reconstruction error: " + model.mse() + "\n");

            // This frame is only needed for the log output, so it is dropped right away.
            Frame perFeature = model.scoreAutoEncoder(train, Key.make(), true);
            try {
                report.append("Reconstruction error per feature: " + perFeature.toString() + "\n");
            } finally {
                perFeature.remove();
            }

            l2 = model.scoreAutoEncoder(train, Key.make(), false);
            Vec l2vec = l2.anyVec();
            report.append("Actual   mean reconstruction error: " + l2vec.mean() + "\n");

            // Quantile passed to calcOutlierThreshold; with N rows this is 1 - 5/N.
            double quantile = 1.0d - (5.0d / train.numRows());
            report.append("The following training points are reconstructed with an error above the " + (quantile * 100.0d) + "-th percentile - potential \"outliers\" in testing data.\n");
            double threshold = model.calcOutlierThreshold(l2vec, quantile);
            for (long row = 0; row < l2vec.length(); row++) {
                if (l2vec.at(row) > threshold) {
                    report.append(String.format("row %d : l2vec error = %5f\n", row, l2vec.at(row)));
                }
            }
            Log.info(new Object[]{report.toString()});

            // Re-scored mean error must match the reported MSE within a relative tolerance of 1e-8.
            Assert.assertEquals(l2vec.mean(), model.mse(), 1.0E-8d * model.mse());

            Log.info(new Object[]{"Creating full reconstruction."});
            reconstruction = model.score(train);
            Assert.assertTrue(model.testJavaScoring(train, reconstruction, 1.0E-5d));

            for (int layer = 0; layer < HIDDEN.length; layer++) {
                checkDeepFeatures(model, train, layer, HIDDEN[layer]);
            }
        } finally {
            if (reconstruction != null) {
                reconstruction.delete();
            }
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (l2 != null) {
                l2.delete();
            }
        }
    }

    /**
     * Extracts the deep features of one hidden layer and asserts their shape: one column
     * per hidden unit and one row per input row. The feature frame is always deleted.
     *
     * @param model        trained auto-encoder model
     * @param train        frame to extract deep features from
     * @param layer        hidden layer index passed to {@code scoreDeepFeatures}
     * @param expectedCols expected number of deep-feature columns for that layer
     */
    private static void checkDeepFeatures(DeepLearningModel model, Frame train, int layer, int expectedCols) {
        Frame features = model.scoreDeepFeatures(train, layer);
        try {
            Assert.assertEquals("deep feature columns for layer " + layer, expectedCols, features.numCols());
            Assert.assertEquals("deep feature rows for layer " + layer, train.numRows(), features.numRows());
        } finally {
            features.delete();
        }
    }
}
