package org.apache.mahout.clustering.lda;

import java.util.Iterator;
import java.util.Random;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.PoissonDistributionImpl;
import org.apache.mahout.clustering.lda.LDAInference;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/lda/TestLDAInference.class */
/**
 * Tests for {@link LDAInference}: runs inference over randomly generated
 * documents against a randomly generated {@link LDAState} and checks basic
 * invariants of each inferred document (word counts echoed back, non-null
 * gamma, non-positive phi values and a non-positive log likelihood).
 */
public final class TestLDAInference extends MahoutTestCase {
    /** Number of topics used for every randomly generated state. */
    private static final int NUM_TOPICS = 20;

    /** Added to each random topic/word weight so the argument of {@code Math.log} is strictly positive. */
    private static final double MIN_WEIGHT = 1.0E-10d;

    /** Slack allowed above zero when checking that the log likelihood is non-positive. */
    private static final double LOG_LIKELIHOOD_TOLERANCE = 1.0E-10d;

    /** Random source shared by the generators; re-created in {@link #setUp()} before each test. */
    private Random random;

    /**
     * Runs the parent fixture set-up, then obtains the random generator from
     * {@link RandomUtils#getRandom()}.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        this.random = RandomUtils.getRandom();
    }

    /**
     * Builds a dense vector of per-word values for one random document.
     *
     * @param numWords size of the vector (one entry per word)
     * @param sparsity parameter passed to the Poisson distribution each entry is drawn from
     * @return a dense vector with {@code numWords} entries
     * @throws MathException if the inverse cumulative probability cannot be computed
     */
    private Vector generateRandomDoc(int numWords, double sparsity) throws MathException {
        Vector doc = new DenseVector(numWords);
        PoissonDistributionImpl poisson = new PoissonDistributionImpl(sparsity);
        for (int word = 0; word < numWords; word++) {
            // Inverse-CDF sampling from a uniform draw. NOTE(review): the +1 presumably
            // compensates for inverseCumulativeProbability returning -1 for very small
            // probabilities -- confirm against the commons-math documentation.
            int sampled = poisson.inverseCumulativeProbability(this.random.nextDouble()) + 1;
            doc.setQuick(word, sampled);
        }
        return doc;
    }

    /**
     * Builds an {@link LDAState} whose topic/word matrix holds the logarithm of
     * random positive weights and whose per-topic array holds the logarithm of
     * each topic's summed weights.
     *
     * @param numWords  number of columns of the topic/word matrix
     * @param numTopics number of rows of the topic/word matrix
     * @return the randomly initialised state
     */
    private LDAState generateRandomState(int numWords, int numTopics) {
        // NOTE(review): presumably the topic smoothing parameter (alpha) -- confirm
        // against the LDAState constructor.
        double topicSmoothing = 50.0d / numTopics;
        DenseMatrix logWeights = new DenseMatrix(numTopics, numWords);
        double[] logTotals = new double[numTopics];
        for (int topic = 0; topic < numTopics; topic++) {
            double total = 0.0d;
            for (int word = 0; word < numWords; word++) {
                double weight = this.random.nextDouble() + MIN_WEIGHT;
                total += weight;
                logWeights.setQuick(topic, word, Math.log(weight));
            }
            logTotals[topic] = Math.log(total);
        }
        return new LDAState(numTopics, numWords, topicSmoothing, logWeights, logTotals,
                Double.NEGATIVE_INFINITY);
    }

    /**
     * Infers {@code numTests} random documents against one random state and
     * checks the invariants of every inferred document.
     *
     * @param numWords vocabulary size used for both the state and the documents
     * @param sparsity Poisson parameter used when generating documents
     * @param numTests number of random documents to infer
     * @throws MathException if document generation fails
     */
    private void runTest(int numWords, double sparsity, int numTests) throws MathException {
        LDAInference inference = new LDAInference(generateRandomState(numWords, NUM_TOPICS));
        for (int test = 0; test < numTests; test++) {
            Vector doc = generateRandomDoc(numWords, sparsity);
            LDAInference.InferredDocument inferred = inference.infer(doc);

            // Expected value first, actual second, so a failure message reads correctly.
            assertEquals("wordCounts", doc, inferred.getWordCounts());
            assertNotNull("gamma", inferred.getGamma());

            Iterator<Vector.Element> nonZero = doc.iterateNonZero();
            while (nonZero.hasNext()) {
                int word = nonZero.next().index();
                for (int topic = 0; topic < NUM_TOPICS; topic++) {
                    double logProb = inferred.phi(topic, word);
                    assertTrue(topic + " " + word + " logProb " + logProb, logProb <= 0.0d);
                }
            }
            assertTrue("log likelihood", inferred.getLogLikelihood() <= LOG_LIKELIHOOD_TOLERANCE);
        }
    }

    /** Small vocabulary (10 words), Poisson parameter 1.0, 5 documents. */
    @Test
    public void testLDAEasy() throws Exception {
        runTest(10, 1.0d, 5);
    }

    /** 100 words with a low Poisson parameter (0.4), 5 documents. */
    @Test
    public void testLDASparse() throws Exception {
        runTest(100, 0.4d, 5);
    }

    /** 100 words with a higher Poisson parameter (3.0), 5 documents. */
    @Test
    public void testLDADense() throws Exception {
        runTest(100, 3.0d, 5);
    }
}
