package org.apache.spark.mllib.clustering;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;
import scala.Tuple2;
import scala.Tuple3;

/**
 * Java-API tests for the MLlib LDA implementation.
 *
 * <p>Expected values (topic count, vocabulary size, topic matrix, topic descriptions, toy model and
 * toy data) are taken from the companion {@code LDASuite} fixtures so that this suite checks the
 * same data through the Java-facing entry points.
 */
public class JavaLDASuite extends SharedSparkSession {
    private static final int tinyK = LDASuite.tinyK();
    private static final int tinyVocabSize = LDASuite.tinyVocabSize();
    private static final Matrix tinyTopics = LDASuite.tinyTopics();
    private static final Tuple2<int[], double[]>[] tinyTopicDescription = LDASuite.tinyTopicDescription();

    /** Tiny corpus as (document id, term-count vector) pairs; rebuilt in {@link #setUp()}. */
    private JavaPairRDD<Long, Vector> corpus;
    private LocalLDAModel toyModel = LDASuite.toyModel();
    private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();

    /**
     * Keeps only documents whose term vector has a non-zero L1 norm.
     *
     * <p>Declared as a static nested class (instead of an anonymous class) so the function object
     * does not carry a reference to the enclosing test instance when it is handed to Spark.
     */
    private static final class NonEmptyDocument implements Function<Tuple2<Long, Vector>, Boolean> {
        @Override
        public Boolean call(Tuple2<Long, Vector> doc) {
            return Vectors.norm(doc._2(), 1.0d) != 0.0d;
        }
    }

    /**
     * Runs the parent set-up and then copies {@code LDASuite.tinyCorpus()} into a
     * two-partition {@code JavaPairRDD<Long, Vector>}.
     *
     * @throws IOException propagated from the parent {@code setUp()}
     */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
        for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
            // The document id is re-wrapped with an explicit Long cast so the element type
            // matches Tuple2<Long, Vector>.
            tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(), LDASuite.tinyCorpus()[i]._2()));
        }
        this.corpus = JavaPairRDD.fromJavaRDD(this.jsc.parallelize(tinyCorpus, 2));
    }

    /**
     * Builds a {@code LocalLDAModel} directly from the tiny topic matrix and checks its size
     * accessors, topic matrix and per-topic descriptions against the {@code LDASuite} fixtures.
     */
    @Test
    public void localLDAModel() {
        Matrix topics = LDASuite.tinyTopics();
        // Uniform vector with one entry per row of the topic matrix, each equal to 1 / numRows.
        double[] uniformWeights = new double[topics.numRows()];
        Arrays.fill(uniformWeights, 1.0d / topics.numRows());
        LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(uniformWeights), 1.0d, 100.0d);

        // Check: basic parameters
        Assert.assertEquals(tinyK, model.k());
        Assert.assertEquals(tinyVocabSize, model.vocabSize());
        Assert.assertEquals(tinyTopics, model.topicsMatrix());

        // Check: describeTopics() with all terms
        Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
        Assert.assertEquals(tinyK, fullTopicSummary.length);
        for (int i = 0; i < fullTopicSummary.length; i++) {
            Assert.assertArrayEquals(tinyTopicDescription[i]._1(), fullTopicSummary[i]._1());
            Assert.assertArrayEquals(tinyTopicDescription[i]._2(), fullTopicSummary[i]._2(), 1.0E-5d);
        }
    }

    /**
     * Trains LDA on the tiny corpus with the default optimizer and checks that the resulting
     * distributed model and its local conversion agree on basic properties, then exercises the
     * Java-specific accessors of the distributed model.
     */
    @Test
    public void distributedLDAModel() {
        int k = 3;
        double topicSmoothing = 1.2d;
        double termSmoothing = 1.2d;

        // Train a model
        LDA lda = new LDA();
        lda.setK(k)
            .setDocConcentration(topicSmoothing)
            .setTopicConcentration(termSmoothing)
            .setMaxIterations(5)
            .setSeed(12345L);

        // The same run(...) call is assigned to a plain LDAModel in onlineOptimizerCompatibility(),
        // so an explicit downcast is used here to obtain the distributed view.
        DistributedLDAModel model = (DistributedLDAModel) lda.run(this.corpus);

        // Check: basic parameters
        LocalLDAModel localModel = model.toLocal();
        Assert.assertEquals(k, model.k());
        Assert.assertEquals(k, localModel.k());
        Assert.assertEquals(tinyVocabSize, model.vocabSize());
        Assert.assertEquals(tinyVocabSize, localModel.vocabSize());
        Assert.assertEquals(localModel.topicsMatrix(), model.topicsMatrix());

        // Check: topic summaries
        Assert.assertEquals(k, model.describeTopics().length);
        Assert.assertEquals(k, localModel.describeTopics().length);

        // Check: log probabilities
        Assert.assertTrue(model.logLikelihood() < 0.0d);
        Assert.assertTrue(model.logPrior() < 0.0d);

        // Check: topic distributions. The count is compared against the documents with a
        // non-zero L1 norm rather than against the whole corpus.
        JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
        JavaPairRDD<Long, Vector> nonEmptyCorpus = this.corpus.filter(new NonEmptyDocument());
        Assert.assertEquals(nonEmptyCorpus.count(), topicDistributions.count());

        // Check: javaTopTopicsPerDocument
        Tuple3<Long, int[], double[]> topTopics = model.javaTopTopicsPerDocument(k).first();
        int[] topicIndices = topTopics._2();
        double[] topicWeights = topTopics._3();
        Assert.assertEquals(k, topicIndices.length);
        Assert.assertEquals(k, topicWeights.length);

        // Check: javaTopicAssignments. Both arrays are int[] and must have the same length.
        // NOTE(review): the decompiled source referenced an undeclared variable here; the two
        // int[] casts match javaTopicAssignments() -- confirm against the original test source.
        Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
        int[] termIndices = topicAssignment._2();
        int[] assignedTopics = topicAssignment._3();
        Assert.assertEquals(termIndices.length, assignedTopics.length);
    }

    /**
     * Trains LDA on the tiny corpus with an explicitly configured {@code OnlineLDAOptimizer} and
     * checks the basic properties of the returned model.
     */
    @Test
    public void onlineOptimizerCompatibility() {
        int k = 3;
        double topicSmoothing = 1.2d;
        double termSmoothing = 1.2d;

        // Train a model
        OnlineLDAOptimizer optimizer = new OnlineLDAOptimizer()
            .setTau0(1024.0d)
            .setKappa(0.51d)
            .setGammaShape(1.0E40d)
            .setMiniBatchFraction(0.5d);

        LDA lda = new LDA();
        lda.setK(k)
            .setDocConcentration(topicSmoothing)
            .setTopicConcentration(termSmoothing)
            .setMaxIterations(5)
            .setSeed(12345L)
            .setOptimizer(optimizer);

        LDAModel model = lda.run(this.corpus);

        // Check: basic parameters
        Assert.assertEquals(k, model.k());
        Assert.assertEquals(tinyVocabSize, model.vocabSize());

        // Check: topic summaries
        Assert.assertEquals(k, model.describeTopics().length);
    }

    /**
     * Exercises the Java overloads of {@code topicDistributions}, {@code logPerplexity} and
     * {@code logLikelihood} on the toy model. Only the topic-distribution count is asserted; the
     * other two calls are smoke checks that the methods can be invoked from Java without throwing.
     */
    @Test
    public void localLdaMethods() {
        JavaPairRDD<Long, Vector> pairedDocs =
            JavaPairRDD.fromJavaRDD(this.jsc.parallelize(this.toyData, 2));

        // Check: topicDistributions returns one entry per input document
        Assert.assertEquals(pairedDocs.count(), this.toyModel.topicDistributions(pairedDocs).count());

        // Check: logPerplexity is callable (result intentionally unused)
        this.toyModel.logPerplexity(pairedDocs);

        // Check: logLikelihood is callable on a single document (result intentionally unused)
        ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
        docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0d, 0.0d, 0.0d)));
        JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(this.jsc.parallelize(docsSingleWord));
        this.toyModel.logLikelihood(single);
    }
}
