package org.apache.mahout.classifier.mlp;

import java.io.File;
import org.apache.mahout.common.MahoutTestCase;
import org.junit.Test;

/**
 * Tests for {@link TrainMultilayerPerceptron}. Each test writes the Iris dataset to a temporary
 * CSV file, runs the command-line trainer on it and inspects the model file the trainer writes.
 */
public class TrainMultilayerPerceptronTest extends MahoutTestCase {

    /** Tolerance for comparing floating-point hyper-parameters read back from a saved model. */
    private static final double PARAMETER_TOLERANCE = 1.0e-6;

    /** Trains a 4-8-3 network on Iris and checks that a model file is produced. */
    @Test
    public void testIrisDataset() throws Exception {
        File modelFile = getTestTempFile("mlp.model");
        File irisFile = writeIrisDataset();

        TrainMultilayerPerceptron.main(new String[] {
            "-i", irisFile.getAbsolutePath(),
            "-sh",
            "-labels", "setosa", "versicolor", "virginica",
            "-mo", modelFile.getAbsolutePath(),
            "-u",
            "-ls", "4", "8", "3"});

        assertTrue(modelFile.exists());
    }

    /**
     * Checks that hyper-parameters and layer sizes passed on the command line are stored in the
     * saved model. The check is done once with explicit -l/-m/-r values and once without them.
     */
    @Test
    public void initializeModelWithDifferentParameters() throws Exception {
        File modelFile = getTestTempFile("mlp.model");
        File irisFile = writeIrisDataset();

        // Explicit learning rate, momentum weight and regularization weight.
        MultilayerPerceptron explicitModel = trainModel(new String[] {
            "-i", irisFile.getAbsolutePath(),
            "-sh",
            "-labels", "setosa", "versicolor", "virginica",
            "-mo", modelFile.getAbsolutePath(),
            "-u",
            "-ls", "4", "8", "3",
            "-l", "0.2", "-m", "0.35", "-r", "0.0001"}, modelFile);

        assertEquals(0.2, explicitModel.getLearningRate(), PARAMETER_TOLERANCE);
        assertEquals(0.35, explicitModel.getMomentumWeight(), PARAMETER_TOLERANCE);
        assertEquals(1.0e-4, explicitModel.getRegularizationWeight(), PARAMETER_TOLERANCE);
        // Non-output layers are expected to report one unit more than requested via -ls
        // (presumably a bias unit); the output layer reports exactly the requested size.
        assertEquals(4, explicitModel.getLayerSize(0) - 1);
        assertEquals(8, explicitModel.getLayerSize(1) - 1);
        assertEquals(3, explicitModel.getLayerSize(2));

        // Write the second model to its own file so its assertions can never be satisfied by,
        // or tangled up with, the model produced by the first run above.
        File defaultsModelFile = getTestTempFile("mlp2.model");

        // No -l/-m/-r flags: the asserted values are the ones this test expects the trainer
        // to fall back to.
        MultilayerPerceptron defaultsModel = trainModel(new String[] {
            "-i", irisFile.getAbsolutePath(),
            "-sh",
            "-labels", "setosa", "versicolor", "virginica",
            "-mo", defaultsModelFile.getAbsolutePath(),
            "-ls", "4", "10", "18", "3"}, defaultsModelFile);

        assertEquals(0.5, defaultsModel.getLearningRate(), PARAMETER_TOLERANCE);
        assertEquals(0.1, defaultsModel.getMomentumWeight(), PARAMETER_TOLERANCE);
        assertEquals(0.0, defaultsModel.getRegularizationWeight(), PARAMETER_TOLERANCE);
        assertEquals(4, defaultsModel.getLayerSize(0) - 1);
        assertEquals(10, defaultsModel.getLayerSize(1) - 1);
        assertEquals(18, defaultsModel.getLayerSize(2) - 1);
        assertEquals(3, defaultsModel.getLayerSize(3));
    }

    /**
     * Writes {@code Datasets.IRIS} to a temporary {@code iris.csv} file.
     *
     * @return the file containing the dataset
     */
    private File writeIrisDataset() throws Exception {
        File irisFile = getTestTempFile("iris.csv");
        writeLines(irisFile, Datasets.IRIS);
        return irisFile;
    }

    /**
     * Runs the command-line trainer with {@code args} and loads the resulting model.
     *
     * @param args command-line arguments for {@link TrainMultilayerPerceptron#main}
     * @param modelFile the model path passed via {@code -mo} in {@code args}
     * @return the model deserialized from {@code modelFile}
     */
    private MultilayerPerceptron trainModel(String[] args, File modelFile) throws Exception {
        TrainMultilayerPerceptron.main(args);
        return new MultilayerPerceptron(modelFile.getAbsolutePath());
    }
}
