package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.class */
/**
 * Java API test for {@link MultilayerPerceptronClassifier}: trains a small network on the
 * four-point XOR dataset and checks that every training point is classified correctly.
 */
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
    // transient: the suite implements Serializable (presumably so closures capturing it can be
    // shipped to executors), but the session itself must not be serialized along with it.
    private transient SparkSession spark;

    /** Starts a fresh local SparkSession before each test. */
    @Before
    public void setUp() {
        this.spark = SparkSession.builder()
            .master("local")
            .appName("JavaMultilayerPerceptronClassifierSuite")
            .getOrCreate();
    }

    /**
     * Stops the session started in {@link #setUp()}. The null check keeps a failed setUp from
     * being masked by a NullPointerException here.
     */
    @After
    public void tearDown() {
        if (this.spark != null) {
            this.spark.stop();
            this.spark = null;
        }
    }

    /**
     * Fits an MLPC with layers {2, 5, 2} on the XOR truth table and asserts that the prediction
     * for each training row equals its label.
     */
    @Test
    public void testMLPC() {
        // XOR truth table: label is 1.0 exactly when the two inputs differ.
        List<LabeledPoint> points = Arrays.asList(
            new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
            new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
            new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
            new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)));
        Dataset<Row> dataFrame = spark.createDataFrame(points, LabeledPoint.class);

        // 2 input features, one hidden layer of 5 units, 2 output classes.
        // A fixed seed keeps weight initialisation (and thus the outcome) reproducible.
        MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
            .setLayers(new int[] {2, 5, 2})
            .setBlockSize(1)
            .setSeed(123L)
            .setMaxIter(100);

        List<Row> predictions = mlpc.fit(dataFrame)
            .transform(dataFrame)
            .select("prediction", "label")
            .collectAsList();
        for (Row row : predictions) {
            double prediction = row.getDouble(0);
            double label = row.getDouble(1);
            // Both values are class indices (0.0 or 1.0), so an exact comparison is appropriate.
            Assert.assertEquals(label, prediction, 0.0);
        }
    }
}
