package org.apache.spark.mllib.regression;

import java.io.File;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.util.TempDirectory;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$testImplicits$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils$;
import org.scalactic.Bool$;
import org.scalactic.Prettifier$;
import org.scalactic.source.Position;
import org.scalatest.Tag;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.math.Numeric$DoubleIsFractional$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.util.Random;

/* compiled from: RidgeRegressionSuite.scala */
@ScalaSignature(bytes = "\u0006\u0001i;Q!\u0001\u0002\t\n5\tACU5eO\u0016\u0014Vm\u001a:fgNLwN\\*vSR,'BA\u0002\u0005\u0003)\u0011Xm\u001a:fgNLwN\u001c\u0006\u0003\u000b\u0019\tQ!\u001c7mS\nT!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001!\tqq\"D\u0001\u0003\r\u0015\u0001\"\u0001#\u0003\u0012\u0005Q\u0011\u0016\u000eZ4f%\u0016<'/Z:tS>t7+^5uKN\u0019qB\u0005\r\u0011\u0005M1R\"\u0001\u000b\u000b\u0003U\tQa]2bY\u0006L!a\u0006\u000b\u0003\r\u0005s\u0017PU3g!\t\u0019\u0012$\u0003\u0002\u001b)\ta1+\u001a:jC2L'0\u00192mK\")Ad\u0004C\u0001;\u00051A(\u001b8jiz\"\u0012!\u0004\u0005\b?=\u0011\r\u0011\"\u0001!\u0003\u0015iw\u000eZ3m+\u0005\t\u0003C\u0001\b#\u0013\t\u0019#A\u0001\u000bSS\u0012<WMU3he\u0016\u001c8/[8o\u001b>$W\r\u001c\u0005\u0007K=\u0001\u000b\u0011B\u0011\u0002\r5|G-\u001a7!\u0011\u001d9s\"!A\u0005\n!\n1B]3bIJ+7o\u001c7wKR\t\u0011\u0006\u0005\u0002+_5\t1F\u0003\u0002-[\u0005!A.\u00198h\u0015\u0005q\u0013\u0001\u00026bm\u0006L!\u0001M\u0016\u0003\r=\u0013'.Z2u\r\u0011\u0001\"\u0001\u0001\u001a\u0014\u0007E\u001at\u0007\u0005\u00025k5\ta!\u0003\u00027\r\ti1\u000b]1sW\u001a+hnU;ji\u0016\u0004\"\u0001O\u001e\u000e\u0003eR!A\u000f\u0003\u0002\tU$\u0018\u000e\\\u0005\u0003ye\u0012Q#\u0014'mS\n$Vm\u001d;Ta\u0006\u00148nQ8oi\u0016DH\u000fC\u0003\u001dc\u0011\u0005a\bF\u0001@!\tq\u0011\u0007C\u0003Bc\u0011\u0005!)A\bqe\u0016$\u0017n\u0019;j_:,%O]8s)\r\u0019e\t\u0016\t\u0003'\u0011K!!\u0012\u000b\u0003\r\u0011{WO\u00197f\u0011\u00159\u0005\t1\u0001I\u0003-\u0001(/\u001a3jGRLwN\\:\u0011\u0007%\u000b6I\u0004\u0002K\u001f:\u00111JT\u0007\u0002\u0019*\u0011Q\nD\u0001\u0007yI|w\u000e\u001e \n\u0003UI!\u0001\u0015\u000b\u0002\u000fA\f7m[1hK&\u0011!k\u0015\u0002\u0004'\u0016\f(B\u0001)\u0015\u0011\u0015)\u0006\t1\u0001W\u0003\u0015Ig\u000e];u!\rI\u0015k\u0016\t\u0003\u001daK!!\u0017\u0002\u0003\u00191\u000b'-\u001a7fIB{\u0017N\u001c;")
/* loaded from: input_file:org/apache/spark/mllib/regression/RidgeRegressionSuite.class */
/**
 * Decompiled test suite for ridge regression. The constructor registers two tests:
 * one comparing the held-out error of ridge regression against plain linear regression,
 * and one checking that a {@link RidgeRegressionModel} survives a save/load round trip.
 *
 * <p>Most members below are compiler-generated bridges for the Scala traits
 * {@code MLlibTestSparkContext} and {@code TempDirectory} (field accessors, super accessors
 * and the lazily created {@code testImplicits} object).
 *
 * <p>NOTE(review): this file is decompiler output. Calls to
 * {@code Prettifier$.MODULE$.default()} use a Java keyword as a method name and cannot be
 * expressed in Java source; they are kept exactly as the decompiler produced them.
 */
public class RidgeRegressionSuite extends SparkFunSuite implements MLlibTestSparkContext {
    // Backing storage for the abstract accessors declared by MLlibTestSparkContext.
    private transient SparkSession spark;
    private transient SparkContext sc;
    private transient String checkpointDir;
    // Lazily created; volatile because it is read outside the lock in testImplicits().
    private volatile MLlibTestSparkContext$testImplicits$ testImplicits$module;
    // Backing storage for TempDirectory's private _tempDir variable.
    private File org$apache$spark$ml$util$TempDirectory$$_tempDir;

    /**
     * Static forwarder to the companion object's shared model instance.
     *
     * @return the model held by {@code RidgeRegressionSuite$.MODULE$}
     */
    public static RidgeRegressionModel model() {
        return RidgeRegressionSuite$.MODULE$.model();
    }

    /**
     * Super accessor used by {@code MLlibTestSparkContext.beforeAll} to reach the next
     * {@code beforeAll} in the trait chain.
     *
     * <p>NOTE(review): the decompiled body called the virtual {@code beforeAll()}, which
     * re-enters this class's override and recurses without bound. Dispatching to
     * {@code TempDirectory} is presumed to be the original target (Scala 2.12 static trait
     * forwarder) - confirm against the class file.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        TempDirectory.beforeAll$(this);
    }

    /**
     * Super accessor used by {@code MLlibTestSparkContext.afterAll} to reach the next
     * {@code afterAll} in the trait chain.
     *
     * <p>NOTE(review): same decompiler artifact and presumed fix as the
     * {@code $$super$beforeAll} accessor above - confirm against the class file.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        TempDirectory.afterAll$(this);
    }

    /**
     * Suite-level setup hook.
     *
     * <p>NOTE(review): the decompiled body was a bare self-call (unbounded recursion). It now
     * delegates to the {@code MLlibTestSparkContext} trait implementation, which is presumed
     * to be the original target - confirm against the class file.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext, org.apache.spark.ml.util.TempDirectory
    public void beforeAll() {
        MLlibTestSparkContext.beforeAll$(this);
    }

    /**
     * Suite-level teardown hook.
     *
     * <p>NOTE(review): the decompiled body was a bare self-call (unbounded recursion). It now
     * delegates to the {@code MLlibTestSparkContext} trait implementation, which is presumed
     * to be the original target - confirm against the class file.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext, org.apache.spark.ml.util.TempDirectory
    public void afterAll() {
        MLlibTestSparkContext.afterAll$(this);
    }

    /** Super accessor that lets {@code TempDirectory} invoke the superclass {@code beforeAll}. */
    @Override // org.apache.spark.ml.util.TempDirectory
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$beforeAll() {
        super.beforeAll();
    }

    /** Super accessor that lets {@code TempDirectory} invoke the superclass {@code afterAll}. */
    @Override // org.apache.spark.ml.util.TempDirectory
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$afterAll() {
        super.afterAll();
    }

    /**
     * Returns the temporary directory provided by the {@code TempDirectory} trait.
     *
     * <p>NOTE(review): the decompiled body assigned the result of calling itself (unbounded
     * recursion). It now delegates to the trait implementation, presumed to be the original
     * target - confirm against the class file.
     *
     * @return the trait-managed temporary directory
     */
    @Override // org.apache.spark.ml.util.TempDirectory
    public File tempDir() {
        return TempDirectory.tempDir$(this);
    }

    /** @return the session stored by {@link #spark_$eq(SparkSession)} */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public SparkSession spark() {
        return this.spark;
    }

    /** @param sparkSession the session to store for later retrieval via {@link #spark()} */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void spark_$eq(SparkSession sparkSession) {
        this.spark = sparkSession;
    }

    /** @return the context stored by {@link #sc_$eq(SparkContext)} */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public SparkContext sc() {
        return this.sc;
    }

    /** @param sparkContext the context to store for later retrieval via {@link #sc()} */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void sc_$eq(SparkContext sparkContext) {
        this.sc = sparkContext;
    }

    /** @return the checkpoint directory stored by {@link #checkpointDir_$eq(String)} */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public String checkpointDir() {
        return this.checkpointDir;
    }

    /** @param str the checkpoint directory to store */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void checkpointDir_$eq(String str) {
        this.checkpointDir = str;
    }

    /**
     * Returns the {@code testImplicits} object, creating it on first use.
     *
     * @return the single {@code testImplicits} instance bound to this suite
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public MLlibTestSparkContext$testImplicits$ testImplicits() {
        if (this.testImplicits$module == null) {
            testImplicits$lzycompute$1();
        }
        return this.testImplicits$module;
    }

    /** @return the current value of TempDirectory's backing {@code _tempDir} field */
    @Override // org.apache.spark.ml.util.TempDirectory
    public File org$apache$spark$ml$util$TempDirectory$$_tempDir() {
        return this.org$apache$spark$ml$util$TempDirectory$$_tempDir;
    }

    /** @param file the new value for TempDirectory's backing {@code _tempDir} field */
    @Override // org.apache.spark.ml.util.TempDirectory
    public void org$apache$spark$ml$util$TempDirectory$$_tempDir_$eq(File file) {
        this.org$apache$spark$ml$util$TempDirectory$$_tempDir = file;
    }

    /**
     * Computes the mean squared difference between predictions and the labels of the
     * input points paired with them by position.
     *
     * <p>The squared differences of the zipped pairs are summed and divided by
     * {@code seq.size()}, i.e. by the number of predictions.
     *
     * @param seq  predicted values (boxed doubles)
     * @param seq2 labeled points whose labels are compared against the predictions
     * @return the summed squared error divided by the number of predictions
     */
    public double predictionError(Seq<Object> seq, Seq<LabeledPoint> seq2) {
        return BoxesRunTime.unboxToDouble(((TraversableOnce) ((TraversableLike) seq.zip(seq2, Seq$.MODULE$.canBuildFrom())).map(tuple2 -> {
            return BoxesRunTime.boxToDouble($anonfun$predictionError$1(tuple2));
        }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$)) / seq.size();
    }

    /**
     * Creates the {@code testImplicits} object under this instance's monitor.
     * The null check is repeated inside the lock so that only one instance is ever created
     * even if several threads pass the unsynchronized check in {@link #testImplicits()}.
     */
    private final void testImplicits$lzycompute$1() {
        synchronized (this) {
            if (this.testImplicits$module == null) {
                this.testImplicits$module = new MLlibTestSparkContext$testImplicits$(this);
            }
        }
    }

    /**
     * Squared error for one (prediction, labeled point) pair.
     *
     * @param tuple2 pair of a predicted double and the {@link LabeledPoint} it is compared to
     * @return {@code (prediction - label)^2}
     * @throws MatchError if {@code tuple2} is {@code null}
     */
    public static final /* synthetic */ double $anonfun$predictionError$1(Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        double _1$mcD$sp = tuple2._1$mcD$sp();
        LabeledPoint labeledPoint = (LabeledPoint) tuple2._2();
        return (_1$mcD$sp - labeledPoint.label()) * (_1$mcD$sp - labeledPoint.label());
    }

    /**
     * Runs the trait initializers and registers the suite's two tests.
     */
    public RidgeRegressionSuite() {
        TempDirectory.$init$(this);
        MLlibTestSparkContext.$init$((MLlibTestSparkContext) this);
        test("ridge regression can help avoid overfitting", Predef$.MODULE$.wrapRefArray(new Tag[0]), () -> {
            // Fixed seed keeps the generated weights identical between runs.
            Random random = new Random(42);
            // 20 weights drawn from [-0.5, 0.5); 2 * 50 generated points in total.
            Seq generateLinearInput = LinearDataGenerator$.MODULE$.generateLinearInput(3.0d, (double[]) Array$.MODULE$.fill(20, () -> {
                return random.nextDouble() - 0.5d;
            }, ClassTag$.MODULE$.Double()), 2 * 50, 42, 10.0d);
            // First 50 points are used for training, last 50 for evaluation.
            Seq seq = (Seq) generateLinearInput.take(50);
            Seq<LabeledPoint> seq2 = (Seq) generateLinearInput.takeRight(50);
            RDD cache = this.sc().parallelize(seq, 2, ClassTag$.MODULE$.apply(LabeledPoint.class)).cache();
            RDD cache2 = this.sc().parallelize(seq2, 2, ClassTag$.MODULE$.apply(LabeledPoint.class)).cache();
            // Baseline: unregularized linear regression.
            LinearRegressionWithSGD linearRegressionWithSGD = new LinearRegressionWithSGD();
            linearRegressionWithSGD.optimizer().setNumIterations(200).setStepSize(1.0d);
            double predictionError = this.predictionError(Predef$.MODULE$.wrapDoubleArray((double[]) linearRegressionWithSGD.run(cache).predict(cache2.map(labeledPoint -> {
                return labeledPoint.features();
            }, ClassTag$.MODULE$.apply(Vector.class))).collect()), seq2);
            // Same optimizer settings plus a regularization parameter of 0.1.
            RidgeRegressionWithSGD ridgeRegressionWithSGD = new RidgeRegressionWithSGD();
            ridgeRegressionWithSGD.optimizer().setNumIterations(200).setRegParam(0.1d).setStepSize(1.0d);
            double predictionError2 = this.predictionError(Predef$.MODULE$.wrapDoubleArray((double[]) ridgeRegressionWithSGD.run(cache).predict(cache2.map(labeledPoint2 -> {
                return labeledPoint2.features();
            }, ClassTag$.MODULE$.apply(Vector.class))).collect()), seq2);
            // The ridge model must have the smaller error on the evaluation points.
            return this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(BoxesRunTime.boxToDouble(predictionError2), "<", BoxesRunTime.boxToDouble(predictionError), predictionError2 < predictionError, Prettifier$.MODULE$.default()), new StringBuilder(45).append("ridgeError (").append(predictionError2).append(") was not less than linearError(").append(predictionError).append(")").toString(), Prettifier$.MODULE$.default(), new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 80));
        }, new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 42));
        test("model save/load", Predef$.MODULE$.wrapRefArray(new Tag[0]), () -> {
            RidgeRegressionModel model = RidgeRegressionSuite$.MODULE$.model();
            File createTempDir = Utils$.MODULE$.createTempDir(Utils$.MODULE$.createTempDir$default$1(), Utils$.MODULE$.createTempDir$default$2());
            String uri = createTempDir.toURI().toString();
            try {
                model.save(this.sc(), uri);
                RidgeRegressionModel load = RidgeRegressionModel$.MODULE$.load(this.sc(), uri);
                Vector weights = model.weights();
                Vector weights2 = load.weights();
                // Null-safe equality. The decompiled code kept only the null branch and left
                // the flag unassigned whenever the saved weights were non-null.
                boolean sameWeights = weights == null ? weights2 == null : weights.equals(weights2);
                this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(weights, "==", weights2, sameWeights, Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 94));
                double intercept = model.intercept();
                double intercept2 = load.intercept();
                return this.assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(BoxesRunTime.boxToDouble(intercept), "==", BoxesRunTime.boxToDouble(intercept2), intercept == intercept2, Prettifier$.MODULE$.default()), "", Prettifier$.MODULE$.default(), new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 95));
            } finally {
                // Always remove the scratch directory, even when an assertion fails.
                Utils$.MODULE$.deleteRecursively(createTempDir);
            }
        }, new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 84));
    }
}
