package org.apache.spark.mllib.regression;

import java.io.Serializable;
import java.util.List;
import java.util.Random;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.class */
/**
 * Tests for {@link RidgeRegressionWithSGD} driven through the Java API.
 *
 * <p>Both tests train on the first half of a synthetic linear data set and evaluate on the second
 * half, asserting that a model trained with a non-zero regularization parameter has a lower
 * mean squared validation error than one trained with the regularization parameter set to zero.
 */
public class JavaRidgeRegressionSuite implements Serializable {
    /** Number of training examples; the same number of examples is held out for validation. */
    private static final int NUM_EXAMPLES = 50;
    /** Number of features (and therefore weights) in the synthetic data. */
    private static final int NUM_FEATURES = 20;
    /** Noise parameter handed to the data generator. */
    private static final double NOISE_STD = 10.0d;
    /** Seed used for both the weight vector and the data generator, keeping the tests repeatable. */
    private static final int SEED = 42;
    /** Number of SGD iterations used for every training run. */
    private static final int NUM_ITERATIONS = 200;
    /** SGD step size used for every training run. */
    private static final double STEP_SIZE = 1.0d;
    /** Regularization parameter used for the regularized model. */
    private static final double REG_PARAM = 0.1d;

    private transient SparkSession spark;
    private transient JavaSparkContext jsc;

    /** Starts a local Spark session and wraps its context for the Java API. */
    @Before
    public void setUp() {
        this.spark = SparkSession.builder()
            .master("local")
            .appName("JavaRidgeRegressionSuite")
            .getOrCreate();
        this.jsc = new JavaSparkContext(this.spark.sparkContext());
    }

    /** Stops the Spark session and drops both references so no stopped context is reused. */
    @After
    public void tearDown() {
        // setUp may have failed before the session was assigned; avoid a secondary NPE here.
        if (this.spark != null) {
            this.spark.stop();
        }
        this.spark = null;
        this.jsc = null;
    }

    /**
     * Computes the mean squared error of {@code model} over {@code validationData}.
     *
     * @param validationData labeled points to score
     * @param model the trained model whose predictions are compared with the labels
     * @return the sum of squared residuals divided by the number of points
     */
    private static double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
        double squaredErrorSum = 0.0d;
        for (LabeledPoint point : validationData) {
            // Keep the prediction as a primitive; boxing it adds nothing.
            double residual = model.predict(point.features()) - point.label();
            squaredErrorSum += residual * residual;
        }
        return squaredErrorSum / validationData.size();
    }

    /**
     * Generates a synthetic linear data set with a zero intercept and fixed-seed random weights.
     *
     * @param numPoints number of labeled points to generate
     * @param numFeatures length of the weight vector
     * @param noiseStd noise parameter forwarded to {@link LinearDataGenerator}
     * @return the generated labeled points
     */
    private static List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double noiseStd) {
        // Weights are uniform in [-0.5, 0.5): nextDouble() is in [0, 1) and is shifted by 0.5.
        Random random = new Random(SEED);
        double[] weights = new double[numFeatures];
        for (int j = 0; j < weights.length; j++) {
            weights[j] = random.nextDouble() - 0.5d;
        }
        return LinearDataGenerator.generateLinearInputAsList(0.0d, weights, numPoints, SEED, noiseStd);
    }

    /** Trains via a configured {@link RidgeRegressionWithSGD} instance, without and then with regularization. */
    @Test
    public void runRidgeRegressionUsingConstructor() {
        List<LabeledPoint> data = generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_STD);
        JavaRDD<LabeledPoint> trainingRdd = this.jsc.parallelize(data.subList(0, NUM_EXAMPLES));
        List<LabeledPoint> validationData = data.subList(NUM_EXAMPLES, 2 * NUM_EXAMPLES);

        RidgeRegressionWithSGD ridgeSgd = new RidgeRegressionWithSGD();
        ridgeSgd.optimizer()
            .setStepSize(STEP_SIZE)
            .setRegParam(0.0d)
            .setNumIterations(NUM_ITERATIONS);
        RidgeRegressionModel unregularizedModel = ridgeSgd.run(trainingRdd.rdd());
        double unregularizedErr = predictionError(validationData, unregularizedModel);

        // Re-use the same algorithm instance; only the regularization parameter changes.
        ridgeSgd.optimizer().setRegParam(REG_PARAM);
        RidgeRegressionModel regularizedModel = ridgeSgd.run(trainingRdd.rdd());
        double regularizedErr = predictionError(validationData, regularizedModel);

        Assert.assertTrue(
            "regularized error " + regularizedErr
                + " should be below unregularized error " + unregularizedErr,
            regularizedErr < unregularizedErr);
    }

    /** Trains via the static {@code RidgeRegressionWithSGD.train} helper, with and without regularization. */
    @Test
    public void runRidgeRegressionUsingStaticMethods() {
        List<LabeledPoint> data = generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_STD);
        JavaRDD<LabeledPoint> trainingRdd = this.jsc.parallelize(data.subList(0, NUM_EXAMPLES));
        List<LabeledPoint> validationData = data.subList(NUM_EXAMPLES, 2 * NUM_EXAMPLES);

        RidgeRegressionModel regularizedModel =
            RidgeRegressionWithSGD.train(trainingRdd.rdd(), NUM_ITERATIONS, STEP_SIZE, REG_PARAM);
        double regularizedErr = predictionError(validationData, regularizedModel);

        RidgeRegressionModel unregularizedModel =
            RidgeRegressionWithSGD.train(trainingRdd.rdd(), NUM_ITERATIONS, STEP_SIZE, 0.0d);
        double unregularizedErr = predictionError(validationData, unregularizedModel);

        Assert.assertTrue(
            "regularized error " + regularizedErr
                + " should be below unregularized error " + unregularizedErr,
            regularizedErr < unregularizedErr);
    }
}
