package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/ml/classification/JavaLogisticRegressionSuite.class */
/**
 * Java-API tests for {@link LogisticRegression} and {@link LogisticRegressionModel}.
 *
 * <p>The tests cover:
 * <ul>
 *   <li>default params;</li>
 *   <li>setters and per-call {@code ParamPair} overrides on {@code fit}/{@code transform};</li>
 *   <li>the rawPrediction / probability / prediction columns produced by {@code transform};</li>
 *   <li>the training summary.</li>
 * </ul>
 *
 * <p>Each test gets a fresh local {@link SparkSession}; see {@link #setUp()} for the fixture.
 */
public class JavaLogisticRegressionSuite implements Serializable {
    /** Absolute tolerance used for every floating-point comparison in this suite. */
    private static final double EPS = 1.0E-5d;

    private transient SparkSession spark;
    private transient JavaSparkContext jsc;
    private transient Dataset<Row> dataset;
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Starts a local {@link SparkSession} and builds the shared fixture.
     *
     * <p>The labeled points come from
     * {@code LogisticRegressionSuite.generateLogisticInputAsList(1.0, 1.0, 100, 42)}. They are
     * parallelized into 2 partitions, converted to a DataFrame of {@link LabeledPoint} and
     * registered as the temp view {@code "dataset"}.
     */
    @Before
    public void setUp() {
        spark = SparkSession.builder()
            .master("local")
            .appName("JavaLogisticRegressionSuite")
            .getOrCreate();
        jsc = new JavaSparkContext(spark.sparkContext());
        datasetRDD = jsc.parallelize(
            LogisticRegressionSuite.generateLogisticInputAsList(1.0d, 1.0d, 100, 42), 2);
        dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
        dataset.createOrReplaceTempView("dataset");
    }

    /**
     * Stops the session and drops every fixture reference.
     *
     * <p>This method is null-safe because JUnit 4 still runs {@code @After} methods when
     * {@code @Before} throws. An NPE here would only add noise on top of the real setup failure.
     */
    @After
    public void tearDown() {
        if (spark != null) {
            spark.stop();
        }
        spark = null;
        jsc = null;
        dataset = null;
        datasetRDD = null;
    }

    /**
     * A default-constructed estimator fits the fixture.
     *
     * <p>The fitted model must expose the expected default threshold and column names.
     */
    @Test
    public void logisticRegressionDefaultParams() {
        LogisticRegression lr = new LogisticRegression();
        Assert.assertEquals("label", lr.getLabelCol());
        LogisticRegressionModel model = lr.fit(dataset);
        model.transform(dataset).createOrReplaceTempView("prediction");
        // The result is discarded on purpose. Collecting forces the transform to run, so the test
        // fails if any of the selected output columns cannot be computed.
        spark.sql("SELECT label, probability, prediction FROM prediction").collectAsList();
        Assert.assertEquals(0.5d, model.getThreshold(), EPS);
        Assert.assertEquals("features", model.getFeaturesCol());
        Assert.assertEquals("prediction", model.getPredictionCol());
        Assert.assertEquals("probability", model.getProbabilityCol());
    }

    /**
     * Params from setters, {@code setThreshold} on the model, and {@code ParamPair} arguments to
     * {@code transform}/{@code fit} are all honored.
     */
    @Test
    public void logisticRegressionWithSetters() {
        LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(1.0d)
            .setThreshold(0.6d)
            .setProbabilityCol("myProbability");
        LogisticRegressionModel model = lr.fit(dataset);

        // Params set on the estimator are visible through the fitted model's parent.
        LogisticRegression parent = model.parent();
        Assert.assertEquals(10L, parent.getMaxIter());
        Assert.assertEquals(1.0d, parent.getRegParam(), EPS);
        // Setting threshold 0.6 is expected to surface as thresholds {0.4, 0.6}.
        Assert.assertEquals(0.4d, parent.getThresholds()[0], EPS);
        Assert.assertEquals(0.6d, parent.getThresholds()[1], EPS);
        Assert.assertEquals(0.6d, parent.getThreshold(), EPS);
        Assert.assertEquals(0.6d, model.getThreshold(), EPS);

        // With threshold 1.0 on the model, every prediction is expected to be 0.
        model.setThreshold(1.0d);
        model.transform(dataset).createOrReplaceTempView("predAllZero");
        for (Row row
                : spark.sql("SELECT prediction, myProbability FROM predAllZero").collectAsList()) {
            Assert.assertEquals(0.0d, row.getDouble(0), EPS);
        }

        // ParamPairs passed to transform() override the model's threshold and probability column.
        // With threshold 0.0, at least one prediction must be non-zero.
        model.transform(dataset, model.threshold().w(0.0d), model.probabilityCol().w("myProb"))
            .createOrReplaceTempView("predNotAllZero");
        boolean anyNonZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero")
            .collectAsList()
            .stream()
            .anyMatch(row -> row.getDouble(0) != 0.0d);
        Assert.assertTrue("expected at least one non-zero prediction with threshold 0.0",
            anyNonZero);

        // ParamPairs passed to fit() are reflected in the resulting model and its parent.
        LogisticRegressionModel model2 = lr.fit(dataset,
            lr.maxIter().w(5),
            lr.regParam().w(0.1d),
            lr.threshold().w(0.4d),
            lr.probabilityCol().w("theProb"));
        LogisticRegression parent2 = model2.parent();
        Assert.assertEquals(5L, parent2.getMaxIter());
        Assert.assertEquals(0.1d, parent2.getRegParam(), EPS);
        Assert.assertEquals(0.4d, parent2.getThreshold(), EPS);
        Assert.assertEquals(0.4d, model2.getThreshold(), EPS);
        Assert.assertEquals("theProb", model2.getProbabilityCol());
    }

    /** Checks that the rawPrediction, probability and prediction columns are consistent. */
    @Test
    public void logisticRegressionPredictorClassifierMethods() {
        LogisticRegressionModel model = new LogisticRegression().fit(dataset);
        Assert.assertEquals(2L, model.numClasses());
        model.transform(dataset).createOrReplaceTempView("transformed");

        // For this binary model, probability[1] must equal the sigmoid of rawPrediction[1], and
        // probability[0] its complement.
        for (Row row
                : spark.sql("SELECT rawPrediction, probability FROM transformed").collectAsList()) {
            Vector raw = (Vector) row.get(0);
            Vector prob = (Vector) row.get(1);
            Assert.assertEquals(2L, raw.size());
            Assert.assertEquals(2L, prob.size());
            double sigmoidOfRaw = 1.0d / (1.0d + Math.exp(-raw.apply(1)));
            Assert.assertEquals(sigmoidOfRaw, prob.apply(1), EPS);
            Assert.assertEquals(1.0d - sigmoidOfRaw, prob.apply(0), EPS);
        }

        // The predicted label must index a maximal entry of the probability vector.
        for (Row row
                : spark.sql("SELECT prediction, probability FROM transformed").collectAsList()) {
            double prediction = row.getDouble(0);
            Vector prob = (Vector) row.get(1);
            double predictedProb = prob.apply((int) prediction);
            for (int i = 0; i < prob.size(); i++) {
                Assert.assertTrue("prediction " + prediction + " is not an argmax of " + prob,
                    predictedProb >= prob.apply(i));
            }
        }
    }

    /** The training summary's iteration count is consistent with its objective history. */
    @Test
    public void logisticRegressionTrainingSummary() {
        LogisticRegressionTrainingSummary summary =
            new LogisticRegression().fit(dataset).summary();
        // NOTE(review): some Spark releases define totalIterations as objectiveHistory.length - 1.
        // This expectation matches the Spark version this suite was compiled against; re-check
        // it when upgrading.
        Assert.assertEquals(summary.objectiveHistory().length, summary.totalIterations());
    }
}
