package org.apache.spark.mllib.regression;

import java.io.File;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.util.TempDirectory;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$testImplicits$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.ArrayImplicits$;
import org.apache.spark.util.Utils$;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Prettifier$;
import org.scalactic.source.Position;
import org.scalatest.Assertions;
import org.scalatest.Assertions$;
import scala.Array$;
import scala.MatchError;
import scala.Tuple2;
import scala.collection.IterableOnceOps;
import scala.collection.IterableOps;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.math.Numeric$DoubleIsFractional$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.util.Random;

/* compiled from: RidgeRegressionSuite.scala */
@ScalaSignature(bytes = "\u0006\u0005u;Q!\u0003\u0006\t\nU1Qa\u0006\u0006\t\naAQaJ\u0001\u0005\u0002!Bq!K\u0001C\u0002\u0013\u0005!\u0006\u0003\u0004/\u0003\u0001\u0006Ia\u000b\u0005\b_\u0005\t\t\u0011\"\u00031\r\u00119\"\u0002A\u001c\t\u000b\u001d2A\u0011\u0001\"\t\u000b\u00113A\u0011A#\u0002)IKGmZ3SK\u001e\u0014Xm]:j_:\u001cV/\u001b;f\u0015\tYA\"\u0001\u0006sK\u001e\u0014Xm]:j_:T!!\u0004\b\u0002\u000b5dG.\u001b2\u000b\u0005=\u0001\u0012!B:qCJ\\'BA\t\u0013\u0003\u0019\t\u0007/Y2iK*\t1#A\u0002pe\u001e\u001c\u0001\u0001\u0005\u0002\u0017\u00035\t!B\u0001\u000bSS\u0012<WMU3he\u0016\u001c8/[8o'VLG/Z\n\u0004\u0003ey\u0002C\u0001\u000e\u001e\u001b\u0005Y\"\"\u0001\u000f\u0002\u000bM\u001c\u0017\r\\1\n\u0005yY\"AB!osJ+g\r\u0005\u0002!K5\t\u0011E\u0003\u0002#G\u0005\u0011\u0011n\u001c\u0006\u0002I\u0005!!.\u0019<b\u0013\t1\u0013E\u0001\u0007TKJL\u0017\r\\5{C\ndW-\u0001\u0004=S:LGO\u0010\u000b\u0002+\u0005)Qn\u001c3fYV\t1\u0006\u0005\u0002\u0017Y%\u0011QF\u0003\u0002\u0015%&$w-\u001a*fOJ,7o]5p]6{G-\u001a7\u0002\r5|G-\u001a7!\u000319(/\u001b;f%\u0016\u0004H.Y2f)\u0005\t\u0004C\u0001\u001a6\u001b\u0005\u0019$B\u0001\u001b$\u0003\u0011a\u0017M\\4\n\u0005Y\u001a$AB(cU\u0016\u001cGoE\u0002\u0007qq\u0002\"!\u000f\u001e\u000e\u00039I!a\u000f\b\u0003\u001bM\u0003\u0018M]6Gk:\u001cV/\u001b;f!\ti\u0004)D\u0001?\u0015\tyD\"\u0001\u0003vi&d\u0017BA!?\u0005UiE\n\\5c)\u0016\u001cHo\u00159be.\u001cuN\u001c;fqR$\u0012a\u0011\t\u0003-\u0019\tq\u0002\u001d:fI&\u001cG/[8o\u000bJ\u0014xN\u001d\u000b\u0004\r&;\u0006C\u0001\u000eH\u0013\tA5D\u0001\u0004E_V\u0014G.\u001a\u0005\u0006\u0015\"\u0001\raS\u0001\faJ,G-[2uS>t7\u000fE\u0002M)\u001as!!\u0014*\u000f\u00059\u000bV\"A(\u000b\u0005A#\u0012A\u0002\u001fs_>$h(C\u0001\u001d\u0013\t\u00196$A\u0004qC\u000e\\\u0017mZ3\n\u0005U3&aA*fc*\u00111k\u0007\u0005\u00061\"\u0001\r!W\u0001\u0006S:\u0004X\u000f\u001e\t\u0004\u0019RS\u0006C\u0001\f\\\u0013\ta&B\u0001\u0007MC\n,G.\u001a3Q_&tG\u000f")
/* loaded from: input_file:org/apache/spark/mllib/regression/RidgeRegressionSuite.class */
/*
 * Test suite for ridge regression in org.apache.spark.mllib.regression.
 *
 * The constructor registers two tests:
 *   - "ridge regression can help avoid overfitting": trains LinearRegressionWithSGD and
 *     RidgeRegressionWithSGD on the same generated data (third constructor argument 0.0 vs 0.1,
 *     presumably the regularization parameter) and asserts the ridge model has the lower
 *     mean squared error on a held-out half of the data.
 *   - "model save/load": round-trips the companion object's RidgeRegressionModel through
 *     save/load and checks that weights and intercept are unchanged.
 *
 * NOTE(review): this file is decompiled from Scala. Trait members are implemented by calling the
 * static `name$` methods that scalac (2.12+) emits on trait interfaces; calling the same-named
 * instance method from inside its own override would recurse forever. `Prettifier$.MODULE$.default()`
 * is also not valid Java source (`default` is a keyword), so this file mirrors the bytecode rather
 * than being directly compilable.
 */
public class RidgeRegressionSuite extends SparkFunSuite implements MLlibTestSparkContext {
    // Backing fields for the vars declared by MLlibTestSparkContext.
    private transient SparkSession spark;
    private transient SparkContext sc;
    private transient String checkpointDir;
    // Lazily initialised in testImplicits$lzycompute$1(); volatile because testImplicits() reads it
    // without holding the lock (double-checked locking).
    private volatile MLlibTestSparkContext$testImplicits$ testImplicits$module;
    // Backing field for TempDirectory's private `_tempDir` var.
    private File org$apache$spark$ml$util$TempDirectory$$_tempDir;

    /** Forwarder to the companion object's shared {@link RidgeRegressionModel}. */
    public static RidgeRegressionModel model() {
        return RidgeRegressionSuite$.MODULE$.model();
    }

    /**
     * Super accessor used by MLlibTestSparkContext's {@code beforeAll}: runs the next
     * {@code beforeAll} in the linearization, which is TempDirectory's implementation.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        TempDirectory.beforeAll$(this);
    }

    /**
     * Super accessor used by MLlibTestSparkContext's {@code afterAll}: runs the next
     * {@code afterAll} in the linearization, which is TempDirectory's implementation.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        TempDirectory.afterAll$(this);
    }

    /** Suite setup; delegates to MLlibTestSparkContext's trait implementation. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext, org.apache.spark.ml.util.TempDirectory
    public void beforeAll() {
        MLlibTestSparkContext.beforeAll$(this);
    }

    /** Suite teardown; delegates to MLlibTestSparkContext's trait implementation. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext, org.apache.spark.ml.util.TempDirectory
    public void afterAll() {
        MLlibTestSparkContext.afterAll$(this);
    }

    /** Delegates to MLlibTestSparkContext's {@code standardize} implementation. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public Instance[] standardize(Instance[] instanceArr) {
        return MLlibTestSparkContext.standardize$(this, instanceArr);
    }

    /** Super accessor used by TempDirectory's {@code beforeAll}: runs SparkFunSuite's version. */
    @Override // org.apache.spark.ml.util.TempDirectory
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$beforeAll() {
        super.beforeAll();
    }

    /** Super accessor used by TempDirectory's {@code afterAll}: runs SparkFunSuite's version. */
    @Override // org.apache.spark.ml.util.TempDirectory
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$afterAll() {
        super.afterAll();
    }

    /** Delegates to TempDirectory's {@code tempDir} implementation. */
    @Override // org.apache.spark.ml.util.TempDirectory
    public File tempDir() {
        return TempDirectory.tempDir$(this);
    }

    // ---- Accessors for the trait-declared vars (generated getter/setter pairs). ----

    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public SparkSession spark() {
        return this.spark;
    }

    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void spark_$eq(SparkSession sparkSession) {
        this.spark = sparkSession;
    }

    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public SparkContext sc() {
        return this.sc;
    }

    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void sc_$eq(SparkContext sparkContext) {
        this.sc = sparkContext;
    }

    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public String checkpointDir() {
        return this.checkpointDir;
    }

    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void checkpointDir_$eq(String str) {
        this.checkpointDir = str;
    }

    /** Returns the lazily created {@code testImplicits} object, creating it on first access. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public MLlibTestSparkContext$testImplicits$ testImplicits() {
        if (this.testImplicits$module == null) {
            testImplicits$lzycompute$1();
        }
        return this.testImplicits$module;
    }

    @Override // org.apache.spark.ml.util.TempDirectory
    public File org$apache$spark$ml$util$TempDirectory$$_tempDir() {
        return this.org$apache$spark$ml$util$TempDirectory$$_tempDir;
    }

    @Override // org.apache.spark.ml.util.TempDirectory
    public void org$apache$spark$ml$util$TempDirectory$$_tempDir_$eq(File file) {
        this.org$apache$spark$ml$util$TempDirectory$$_tempDir = file;
    }

    /**
     * Mean squared error of {@code seq} (predictions) against the labels of {@code seq2}.
     *
     * <p>Pairs are formed with Scala's {@code zip}, which stops at the shorter sequence, but the
     * sum is always divided by {@code seq.size()}. An empty {@code seq} yields NaN (0.0 / 0).
     *
     * @param seq  predicted values, as boxed doubles
     * @param seq2 points whose {@code label()} is the expected value for the matching prediction
     * @return sum of squared residuals divided by the number of predictions
     */
    public double predictionError(Seq<Object> seq, Seq<LabeledPoint> seq2) {
        // The raw IterableOps makes the lambda parameter an Object, so cast to Tuple2 explicitly.
        double sumOfSquares = BoxesRunTime.unboxToDouble(
            ((IterableOnceOps) ((IterableOps) seq.zip(seq2)).map(pair ->
                BoxesRunTime.boxToDouble($anonfun$predictionError$1((Tuple2) pair))))
                .sum(Numeric$DoubleIsFractional$.MODULE$));
        return sumOfSquares / seq.size();
    }

    /** Creates {@code testImplicits$module} under the instance lock if no other thread has. */
    private final void testImplicits$lzycompute$1() {
        synchronized (this) {
            // Re-check under the lock: another thread may have initialised it after our
            // unsynchronized check in testImplicits().
            if (this.testImplicits$module == null) {
                this.testImplicits$module = new MLlibTestSparkContext$testImplicits$(this);
            }
        }
    }

    /**
     * Squared residual for one {@code (prediction, point)} pair.
     *
     * @throws MatchError if {@code tuple2} is null, mirroring the Scala tuple pattern match
     */
    public static final /* synthetic */ double $anonfun$predictionError$1(Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        double prediction = tuple2._1$mcD$sp();
        double residual = prediction - ((LabeledPoint) tuple2._2()).label();
        return residual * residual;
    }

    /**
     * Body of "ridge regression can help avoid overfitting".
     *
     * <p>Generates {@code 2 * 50} points from 20 weights drawn from {@code [-0.5, 0.5)}
     * (seed 42), trains both models on the first half and compares their
     * {@link #predictionError} on the second half.
     */
    private Object ridgeAvoidsOverfitting() {
        int numExamples = 50;
        int numFeatures = 20;
        Random random = new Random(42);
        double[] trueWeights = (double[]) Array$.MODULE$.fill(
            numFeatures, () -> random.nextDouble() - 0.5d, ClassTag$.MODULE$.Double());
        // Arguments: intercept 3.0, nPoints, seed 42, eps 10.0 (presumably the noise scale —
        // confirm against LinearDataGenerator).
        Seq<LabeledPoint> data = LinearDataGenerator$.MODULE$.generateLinearInput(
            3.0d, trueWeights, 2 * numExamples, 42, 10.0d);
        Seq<LabeledPoint> trainingData = (Seq<LabeledPoint>) data.take(numExamples);
        Seq<LabeledPoint> validationData = (Seq<LabeledPoint>) data.takeRight(numExamples);

        RDD<LabeledPoint> trainingRdd = this.sc()
            .parallelize(trainingData, 2, ClassTag$.MODULE$.apply(LabeledPoint.class)).cache();
        RDD<LabeledPoint> validationRdd = this.sc()
            .parallelize(validationData, 2, ClassTag$.MODULE$.apply(LabeledPoint.class)).cache();
        RDD<Vector> validationFeatures =
            validationRdd.map(point -> point.features(), ClassTag$.MODULE$.apply(Vector.class));

        double linearError = this.predictionError(
            ArrayImplicits$.MODULE$.SparkArrayOps(
                new LinearRegressionWithSGD(1.0d, 200, 0.0d, 1.0d).run(trainingRdd)
                    .predict(validationFeatures).collect())
                .toImmutableArraySeq(),
            validationData);
        double ridgeError = this.predictionError(
            ArrayImplicits$.MODULE$.SparkArrayOps(
                new RidgeRegressionWithSGD(1.0d, 200, 0.1d, 1.0d).run(trainingRdd)
                    .predict(validationFeatures).collect())
                .toImmutableArraySeq(),
            validationData);

        Bool ridgeIsBetter = Bool$.MODULE$.binaryMacroBool(
            BoxesRunTime.boxToDouble(ridgeError), "<", BoxesRunTime.boxToDouble(linearError),
            ridgeError < linearError, Prettifier$.MODULE$.default());
        Assertions.AssertionsHelper helper = Assertions$.MODULE$.assertionsHelper();
        return helper.macroAssert(
            ridgeIsBetter,
            "ridgeError (" + ridgeError + ") was not less than linearError(" + linearError + ")",
            Prettifier$.MODULE$.default(),
            new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 78));
    }

    /**
     * Body of "model save/load": saves the shared model to a fresh temp directory, loads it back
     * and asserts equal weights and intercept. The directory is always deleted afterwards.
     */
    private Object modelSaveLoad() {
        RidgeRegressionModel model = RidgeRegressionSuite$.MODULE$.model();
        File saveDir = Utils$.MODULE$.createTempDir();
        String path = saveDir.toURI().toString();
        try {
            model.save(this.sc(), path);
            RidgeRegressionModel loaded = RidgeRegressionModel$.MODULE$.load(this.sc(), path);

            Vector weights = model.weights();
            Vector loadedWeights = loaded.weights();
            // Scala `==` on references: null-safe, otherwise delegates to equals().
            boolean weightsEqual =
                weights == null ? loadedWeights == null : weights.equals(loadedWeights);
            Assertions$.MODULE$.assertionsHelper().macroAssert(
                Bool$.MODULE$.binaryMacroBool(weights, "==", loadedWeights, weightsEqual, Prettifier$.MODULE$.default()),
                "", Prettifier$.MODULE$.default(),
                new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 92));

            double intercept = model.intercept();
            double loadedIntercept = loaded.intercept();
            return Assertions$.MODULE$.assertionsHelper().macroAssert(
                Bool$.MODULE$.binaryMacroBool(
                    BoxesRunTime.boxToDouble(intercept), "==", BoxesRunTime.boxToDouble(loadedIntercept),
                    intercept == loadedIntercept, Prettifier$.MODULE$.default()),
                "", Prettifier$.MODULE$.default(),
                new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 93));
        } finally {
            Utils$.MODULE$.deleteRecursively(saveDir);
        }
    }

    /** Initialises the mixed-in traits (TempDirectory first) and registers the suite's tests. */
    public RidgeRegressionSuite() {
        TempDirectory.$init$(this);
        MLlibTestSparkContext.$init$((MLlibTestSparkContext) this);
        test("ridge regression can help avoid overfitting", Nil$.MODULE$,
            () -> ridgeAvoidsOverfitting(),
            new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 43));
        test("model save/load", Nil$.MODULE$,
            () -> modelSaveLoad(),
            new Position("RidgeRegressionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 82));
    }
}
