package org.apache.spark.ml.classification;

import breeze.linalg.Vector$;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.TestingUtils$;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.scalactic.Bool$;
import org.scalatest.Args;
import org.scalatest.BeforeAndAfterAll;
import org.scalatest.ConfigMap;
import org.scalatest.FunSuiteLike;
import org.scalatest.Status;
import org.scalatest.Tag;
import scala.Array$;
import scala.Option;
import scala.Predef$;
import scala.math.Numeric$DoubleIsFractional$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: NaiveBayesSuite.scala */
/**
 * Test suite for Naive Bayes classification in {@code org.apache.spark.ml.classification}
 * (exercises {@code NaiveBayesModel}).
 *
 * <p>NOTE(review): this file is decompiler output of the compiled Scala class, not
 * hand-written Java. Several constructs are not valid Java as written:
 * <ul>
 *   <li>calls such as {@code BeforeAndAfterAll.class.beforeAll(this)} invoke a method on a
 *       class literal; presumably they render calls into Scala trait implementation
 *       classes (e.g. {@code BeforeAndAfterAll$class});</li>
 *   <li>a {@code final} field is assigned from a non-constructor method;</li>
 *   <li>the {@code @ScalaSignature} string literal spans two source lines.</li>
 * </ul>
 * The {@code NaiveBayesSuite$$anonfun$N} classes referenced below hold compiled Scala
 * closures whose bodies are not visible in this file. Presumably any real change belongs
 * in the original {@code NaiveBayesSuite.scala}.
 *
 * <p>The constructor registers four tests. The public helpers check predictions, fitted
 * model parameters, and per-row class probabilities.
 */
@ScalaSignature(bytes = "\u0006\u0001}3A!\u0001\u0002\u0001\u001b\tya*Y5wK\n\u000b\u00170Z:Tk&$XM\u0003\u0002\u0004\t\u0005q1\r\\1tg&4\u0017nY1uS>t'BA\u0003\u0007\u0003\tiGN\u0003\u0002\b\u0011\u0005)1\u000f]1sW*\u0011\u0011BC\u0001\u0007CB\f7\r[3\u000b\u0003-\t1a\u001c:h\u0007\u0001\u00192\u0001\u0001\b\u0013!\ty\u0001#D\u0001\u0007\u0013\t\tbAA\u0007Ta\u0006\u00148NR;o'VLG/\u001a\t\u0003'ai\u0011\u0001\u0006\u0006\u0003+Y\tA!\u001e;jY*\u0011qCB\u0001\u0006[2d\u0017NY\u0005\u00033Q\u0011Q#\u0014'mS\n$Vm\u001d;Ta\u0006\u00148nQ8oi\u0016DH\u000fC\u0003\u001c\u0001\u0011\u0005A$\u0001\u0004=S:LGO\u0010\u000b\u0002;A\u0011a\u0004A\u0007\u0002\u0005!)\u0001\u0005\u0001C\u0001C\u0005\u0011b/\u00197jI\u0006$X\r\u0015:fI&\u001cG/[8o)\t\u0011\u0003\u0006\u0005\u0002$M5\tAEC\u0001&\u0003\u0015\u00198-\u00197b\u0013\t9CE\u0001\u0003V]&$\b\"B\u0015 \u0001\u0004Q\u0013a\u00059sK\u0012L7\r^5p]\u0006sG\rT1cK2\u001c\bCA\u0016/\u001b\u0005a#BA\u0017\u0007\u0003\r\u0019\u0018\u000f\\\u0005\u0003_1\u0012\u0011\u0002R1uC\u001a\u0013\u0018-\\3\t\u000bE\u0002A\u0011\u0001\u001a\u0002!Y\fG.\u001b3bi\u0016lu\u000eZ3m\r&$H\u0003\u0002\u00124w\u0001CQ\u0001\u000e\u0019A\u0002U\na\u0001]5ECR\f\u0007C\u0001\u001c:\u001b\u00059$B\u0001\u001d\u0017\u0003\u0019a\u0017N\\1mO&\u0011!h\u000e\u0002\u0007-\u0016\u001cGo\u001c:\t\u000bq\u0002\u0004\u0019A\u001f\u0002\u0013QDW\r^1ECR\f\u0007C\u0001\u001c?\u0013\tytG\u0001\u0004NCR\u0014\u0018\u000e\u001f\u0005\u0006\u0003B\u0002\rAQ\u0001\u0006[>$W\r\u001c\t\u0003=\rK!\u0001\u0012\u0002\u0003\u001f9\u000b\u0017N^3CCf,7/T8eK2DQA\u0012\u0001\u0005\u0002\u001d\u000b\u0001%\u001a=qK\u000e$X\rZ'vYRLgn\\7jC2\u0004&o\u001c2bE&d\u0017\u000e^5fgR\u0019Q\u0007S%\t\u000b\u0005+\u0005\u0019\u0001\"\t\u000b)+\u0005\u0019A\u001b\u0002\u000f\u0019,\u0017\r^;sK\")A\n\u0001C\u0001\u001b\u0006qR\r\u001f9fGR,GMQ3s]>,H\u000e\\5Qe>\u0014\u0017MY5mSRLWm\u001d\u000b\u0004k9{\u0005\"B!L\u0001\u0004\u0011\u0005\"\u0002&L\u0001\u0004)\u0004\"B)\u0001\t\u0003\u0011\u0016!\u0006<
bY&$\u0017\r^3Qe>\u0014\u0017MY5mSRLWm\u001d\u000b\u0005EM+f\u000bC\u0003U!\u0002\u0007!&A\fgK\u0006$XO]3B]\u0012\u0004&o\u001c2bE&d\u0017\u000e^5fg\")\u0011\t\u0015a\u0001\u0005\")q\u000b\u0015a\u00011\u0006IQn\u001c3fYRK\b/\u001a\t\u00033rs!a\t.\n\u0005m#\u0013A\u0002)sK\u0012,g-\u0003\u0002^=\n11\u000b\u001e:j]\u001eT!a\u0017\u0013")
/* loaded from: input_file:org/apache/spark/ml/classification/NaiveBayesSuite.class */
public class NaiveBayesSuite extends SparkFunSuite implements MLlibTestSparkContext {
    /** Backing field for the {@code sc()} / {@code sc_$eq} pair required by {@code MLlibTestSparkContext}. */
    private transient SparkContext sc;
    /** Backing field for the {@code sqlContext()} / {@code sqlContext_$eq} pair required by {@code MLlibTestSparkContext}. */
    private transient SQLContext sqlContext;
    /**
     * Backing field for ScalaTest's {@code BeforeAndAfterAll} setting of the same name.
     * It is assigned only through the synthetic setter further down.
     * NOTE(review): presumably that setter is called while the trait is initialized in the
     * constructor. Assigning a {@code final} field there is a decompiler artifact.
     */
    private final boolean invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected;

    /** Returns the Spark context stored in this suite (accessor required by {@code MLlibTestSparkContext}). */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public SparkContext sc() {
        return this.sc;
    }

    /** Stores the Spark context. {@code sc_$eq} is the Scala-encoded name of the setter {@code sc_=}. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void sc_$eq(SparkContext sparkContext) {
        this.sc = sparkContext;
    }

    /** Returns the SQL context stored in this suite (accessor required by {@code MLlibTestSparkContext}). */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public SQLContext sqlContext() {
        return this.sqlContext;
    }

    /** Stores the SQL context. {@code sqlContext_$eq} is the Scala-encoded name of the setter {@code sqlContext_=}. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void sqlContext_$eq(SQLContext sQLContext) {
        this.sqlContext = sQLContext;
    }

    /**
     * Synthetic accessor that lets the {@code MLlibTestSparkContext} trait call
     * {@code super.beforeAll()}. In this class it resolves to {@code BeforeAndAfterAll}'s
     * implementation.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        BeforeAndAfterAll.class.beforeAll(this);
    }

    /**
     * Synthetic accessor that lets the {@code MLlibTestSparkContext} trait call
     * {@code super.afterAll()}. In this class it resolves to {@code BeforeAndAfterAll}'s
     * implementation.
     */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        BeforeAndAfterAll.class.afterAll(this);
    }

    /** Suite-level setup; forwards to the {@code MLlibTestSparkContext} trait implementation. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void beforeAll() {
        MLlibTestSparkContext.Cclass.beforeAll(this);
    }

    /** Suite-level teardown; forwards to the {@code MLlibTestSparkContext} trait implementation. */
    @Override // org.apache.spark.mllib.util.MLlibTestSparkContext
    public void afterAll() {
        MLlibTestSparkContext.Cclass.afterAll(this);
    }

    /** Returns the {@code BeforeAndAfterAll} setting held in the backing field of the same name. */
    public boolean invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected() {
        return this.invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected;
    }

    /**
     * Synthetic accessor that lets {@code BeforeAndAfterAll} call {@code super.run(...)}.
     * Here it resolves to {@code FunSuiteLike}'s {@code run}.
     */
    public /* synthetic */ Status org$scalatest$BeforeAndAfterAll$$super$run(Option option, Args args) {
        return FunSuiteLike.class.run(this, option, args);
    }

    /** Synthetic setter, called by {@code BeforeAndAfterAll}, for the backing field of the same name. */
    public void org$scalatest$BeforeAndAfterAll$_setter_$invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected_$eq(boolean z) {
        this.invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected = z;
    }

    /** Forwards the {@code ConfigMap} overload to the {@code BeforeAndAfterAll} trait implementation. */
    public void beforeAll(ConfigMap configMap) {
        BeforeAndAfterAll.class.beforeAll(this, configMap);
    }

    /** Forwards the {@code ConfigMap} overload to the {@code BeforeAndAfterAll} trait implementation. */
    public void afterAll(ConfigMap configMap) {
        BeforeAndAfterAll.class.afterAll(this, configMap);
    }

    /**
     * Runs the suite through {@code BeforeAndAfterAll}'s {@code run}.
     * That implementation reaches {@code FunSuiteLike.run} through the synthetic
     * {@code ...$$super$run} accessor defined in this class.
     */
    public Status run(Option<String> option, Args args) {
        return BeforeAndAfterAll.class.run(this, option, args);
    }

    /**
     * Asserts that fewer than one fifth of the rows in {@code dataFrame} are counted by the
     * row predicate {@code NaiveBayesSuite$$anonfun$15}.
     *
     * <p>NOTE(review): that predicate is a compiled closure not shown here. Presumably it
     * flags rows whose prediction differs from the label; confirm.
     *
     * <p>The threshold is {@code dataFrame.count() / 5} computed with long integer division,
     * and the comparison is strict ({@code <}). {@code collect()} and {@code count()} are
     * separate calls, so the DataFrame may be evaluated twice unless it is cached.
     *
     * @param dataFrame rows to check; all rows are materialized locally via {@code collect()}
     */
    public void validatePrediction(DataFrame dataFrame) {
        int count = Predef$.MODULE$.refArrayOps(dataFrame.collect()).count(new NaiveBayesSuite$$anonfun$15(this));
        long count2 = dataFrame.count() / 5;
        assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(BoxesRunTime.boxToInteger(count), "<", BoxesRunTime.boxToLong(count2), ((long) count) < count2), "");
    }

    /**
     * Asserts that a fitted model's {@code pi} and {@code theta} match expected values.
     *
     * <p>As the embedded assertion text shows, both sides are exponentiated element-wise
     * ({@code x => math.exp(x)}) and then compared with TestingUtils'
     * {@code ~== ... absTol(0.05)}. A failure reports "pi mismatch" or "theta mismatch".
     *
     * @param vector          expected {@code pi} (named {@code piData} in the embedded Scala
     *                        text); presumably log-space class priors, given the exp
     * @param matrix          expected {@code theta} (named {@code thetaData} in the embedded
     *                        Scala text); presumably log-space conditional probabilities
     * @param naiveBayesModel model under test
     */
    public void validateModelFit(Vector vector, Matrix matrix, NaiveBayesModel naiveBayesModel) {
        // $tilde$eq$eq is the Scala-encoded name of the `~==` operator (see the assertion text).
        assertionsHelper().macroAssert(Bool$.MODULE$.simpleMacroBool(TestingUtils$.MODULE$.VectorWithAlmostEquals(Vectors$.MODULE$.dense((double[]) Predef$.MODULE$.doubleArrayOps(naiveBayesModel.pi().toArray()).map(new NaiveBayesSuite$$anonfun$1(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).$tilde$eq$eq(TestingUtils$.MODULE$.VectorWithAlmostEquals(Vectors$.MODULE$.dense((double[]) Predef$.MODULE$.doubleArrayOps(vector.toArray()).map(new NaiveBayesSuite$$anonfun$2(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).absTol(0.05d)), "org.apache.spark.mllib.util.TestingUtils.VectorWithAlmostEquals(org.apache.spark.mllib.linalg.Vectors.dense(scala.this.Predef.doubleArrayOps(model.pi.toArray).map[Double, Array[Double]]({\n  ((x: Double) => scala.math.`package`.exp(x))\n})(scala.this.Array.canBuildFrom[Double]((ClassTag.Double: scala.reflect.ClassTag[Double]))))).~==(org.apache.spark.mllib.util.TestingUtils.VectorWithAlmostEquals(org.apache.spark.mllib.linalg.Vectors.dense(scala.this.Predef.doubleArrayOps(piData.toArray).map[Double, Array[Double]]({\n  ((x: Double) => scala.math.`package`.exp(x))\n})(scala.this.Array.canBuildFrom[Double]((ClassTag.Double: scala.reflect.ClassTag[Double]))))).absTol(0.05))"), "pi mismatch");
        assertionsHelper().macroAssert(Bool$.MODULE$.simpleMacroBool(TestingUtils$.MODULE$.MatrixWithAlmostEquals(naiveBayesModel.theta().map(new NaiveBayesSuite$$anonfun$3(this))).$tilde$eq$eq(TestingUtils$.MODULE$.MatrixWithAlmostEquals(matrix.map(new NaiveBayesSuite$$anonfun$4(this))).absTol(0.05d)), "org.apache.spark.mllib.util.TestingUtils.MatrixWithAlmostEquals(model.theta.map({\n  ((x: Double) => scala.math.`package`.exp(x))\n})).~==(org.apache.spark.mllib.util.TestingUtils.MatrixWithAlmostEquals(thetaData.map({\n  ((x: Double) => scala.math.`package`.exp(x))\n})).absTol(0.05))"), "theta mismatch");
    }

    /**
     * Computes reference class probabilities for a multinomial model and one feature vector.
     *
     * <p>It forms {@code pi + theta * feature} using Breeze vector addition and maps each
     * element through {@code NaiveBayesSuite$$anonfun$5}. It then maps each element through
     * a closure that captures the sum of that array.
     * NOTE(review): presumably {@code $5} is {@code math.exp} and the second closure divides
     * by the sum, so the result sums to 1. Both are compiled closures not shown here; confirm.
     *
     * @param naiveBayesModel model supplying {@code pi} and {@code theta}
     * @param vector          feature vector; must be compatible with {@code theta().multiply}
     * @return dense vector with one entry per element of {@code pi}
     */
    public Vector expectedMultinomialProbabilities(NaiveBayesModel naiveBayesModel, Vector vector) {
        double[] dArr = (double[]) Predef$.MODULE$.doubleArrayOps(((breeze.linalg.Vector) naiveBayesModel.pi().toBreeze().$plus(naiveBayesModel.theta().multiply(vector).toBreeze(), Vector$.MODULE$.v_v_Idempotent_Op_Double_OpAdd())).toArray$mcD$sp(ClassTag$.MODULE$.Double())).map(new NaiveBayesSuite$$anonfun$5(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        // The sum of dArr is computed once and captured by the per-element closure.
        return Vectors$.MODULE$.dense((double[]) Predef$.MODULE$.doubleArrayOps(dArr).map(new NaiveBayesSuite$$anonfun$expectedMultinomialProbabilities$1(this, BoxesRunTime.unboxToDouble(Predef$.MODULE$.doubleArrayOps(dArr).sum(Numeric$DoubleIsFractional$.MODULE$))), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())));
    }

    /**
     * Computes reference class probabilities for a Bernoulli model and one feature vector.
     *
     * <p>It forms {@code pi + theta * feature + theta2 * feature2}, where:
     * <ul>
     *   <li>{@code theta2} is {@code theta} mapped through {@code NaiveBayesSuite$$anonfun$6};</li>
     *   <li>{@code feature2} is the feature array mapped through {@code NaiveBayesSuite$$anonfun$7}.</li>
     * </ul>
     * Each element is then mapped through {@code NaiveBayesSuite$$anonfun$8}. Finally each
     * element goes through a closure that captures the array's sum, as in the multinomial
     * helper above.
     *
     * <p>NOTE(review): presumably {@code $6} is {@code x => log(1 - exp(x))}, {@code $7} is
     * {@code x => 1 - x}, {@code $8} is {@code math.exp}, and the last step divides by the
     * sum. That would give the Bernoulli Naive Bayes posterior for 0/1 features. The closures
     * are not visible here; confirm.
     *
     * @param naiveBayesModel model supplying {@code pi} and {@code theta}
     * @param vector          feature vector; must be compatible with {@code theta().multiply}
     * @return dense vector with one entry per element of {@code pi}
     */
    public Vector expectedBernoulliProbabilities(NaiveBayesModel naiveBayesModel, Vector vector) {
        double[] dArr = (double[]) Predef$.MODULE$.doubleArrayOps(((breeze.linalg.Vector) ((breeze.linalg.Vector) naiveBayesModel.pi().toBreeze().$plus(naiveBayesModel.theta().multiply(vector).toBreeze(), Vector$.MODULE$.v_v_Idempotent_Op_Double_OpAdd())).$plus(naiveBayesModel.theta().map(new NaiveBayesSuite$$anonfun$6(this)).multiply(Vectors$.MODULE$.dense((double[]) Predef$.MODULE$.doubleArrayOps(vector.toArray()).map(new NaiveBayesSuite$$anonfun$7(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).toBreeze(), Vector$.MODULE$.v_v_Idempotent_Op_Double_OpAdd())).toArray$mcD$sp(ClassTag$.MODULE$.Double())).map(new NaiveBayesSuite$$anonfun$8(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        // The sum of dArr is computed once and captured by the per-element closure.
        return Vectors$.MODULE$.dense((double[]) Predef$.MODULE$.doubleArrayOps(dArr).map(new NaiveBayesSuite$$anonfun$expectedBernoulliProbabilities$1(this, BoxesRunTime.unboxToDouble(Predef$.MODULE$.doubleArrayOps(dArr).sum(Numeric$DoubleIsFractional$.MODULE$))), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())));
    }

    /**
     * Applies the per-row check {@code NaiveBayesSuite$$anonfun$validateProbabilities$1} to
     * every row of {@code dataFrame}. All rows are materialized locally via {@code collect()}.
     *
     * <p>NOTE(review): the check is a compiled closure capturing {@code naiveBayesModel} and
     * {@code str}. Presumably it compares each row's probability output against
     * {@code expectedMultinomialProbabilities} or {@code expectedBernoulliProbabilities},
     * choosing between them by {@code str}; confirm.
     *
     * @param dataFrame       rows to check
     * @param naiveBayesModel model whose parameters define the expected probabilities
     * @param str             presumably the model type name; confirm against the closure
     */
    public void validateProbabilities(DataFrame dataFrame, NaiveBayesModel naiveBayesModel, String str) {
        Predef$.MODULE$.refArrayOps(dataFrame.collect()).foreach(new NaiveBayesSuite$$anonfun$validateProbabilities$1(this, naiveBayesModel, str));
    }

    /**
     * Initializes the mixed-in traits, then registers the suite's tests in declaration order.
     * Each test is registered with no tags. Test bodies live in the compiled closures
     * {@code NaiveBayesSuite$$anonfun$9}, {@code $10}, {@code $11} and {@code $13}, which are
     * not visible here.
     */
    public NaiveBayesSuite() {
        // Trait initializers run first, BeforeAndAfterAll before MLlibTestSparkContext,
        // mirroring the Scala mix-in order. Do not reorder.
        BeforeAndAfterAll.class.$init$(this);
        MLlibTestSparkContext.Cclass.$init$(this);
        test("params", Predef$.MODULE$.wrapRefArray(new Tag[0]), new NaiveBayesSuite$$anonfun$9(this));
        test("naive bayes: default params", Predef$.MODULE$.wrapRefArray(new Tag[0]), new NaiveBayesSuite$$anonfun$10(this));
        test("Naive Bayes Multinomial", Predef$.MODULE$.wrapRefArray(new Tag[0]), new NaiveBayesSuite$$anonfun$11(this));
        test("Naive Bayes Bernoulli", Predef$.MODULE$.wrapRefArray(new Tag[0]), new NaiveBayesSuite$$anonfun$13(this));
    }
}
