package org.apache.mahout.clustering.lda.cvb;

import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixUtils;
import org.apache.mahout.math.function.DoubleFunction;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.class */
/**
 * Tests for the CVB0 LDA trainers: the in-memory
 * {@link InMemoryCollapsedVariationalBayes0} and the {@link CVB0Driver} entry point.
 *
 * <p>Both tests build a random structured model, sample a corpus from it, train with a
 * range of candidate topic counts and collect the resulting perplexities.
 */
public final class TestCVBModelTrainer extends MahoutTestCase {
    // Both smoothing values are identical, so the order in which they are handed to the
    // trainers below cannot change the outcome.
    // NOTE(review): they are passed as (ALPHA, ETA); confirm against the callee signatures.
    private static final double ETA = 0.1d;
    private static final double ALPHA = 0.1d;

    /**
     * Trains the in-memory CVB0 implementation on a sampled corpus over the terms
     * {@code "a"} .. {@code "z"} for topic counts 1 .. 5 and prints the perplexity
     * returned for each topic count.
     *
     * <p>This test makes no assertions; it only verifies that training completes without
     * throwing. The corpus is sampled with the unseeded {@code RandomUtils.getRandom()}.
     *
     * @throws Exception if model generation or training fails
     */
    @Test
    public void testInMemoryCVB0() throws Exception {
        int numGeneratingTopics = 3;
        int numTerms = 26;
        int numTrials = 1;

        // One single-letter term per vocabulary entry: "a", "b", ..., "z".
        String[] terms = new String[numTerms];
        for (int t = 0; t < terms.length; t++) {
            terms[t] = String.valueOf((char) ('a' + t));
        }

        // Decay function handed to the model generator: 1 / (d + 1)^2.
        Matrix model = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction() {
            public double apply(double d) {
                return 1.0d / Math.pow(d + 1.0d, 2.0d);
            }
        });
        // NOTE(review): 100, 20, 1 are presumably document count, samples per document and
        // topics per document; confirm against ClusteringTestUtils.sampledCorpus.
        Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(model, RandomUtils.getRandom(), 100, 20, 1);

        List<Double> perplexities = Lists.newArrayList();
        for (int numTestTopics = 1; numTestTopics < 2 * numGeneratingTopics; numTestTopics++) {
            double[] trialPerplexities = new double[numTrials];
            for (int trial = 0; trial < numTrials; trial++) {
                InMemoryCollapsedVariationalBayes0 cvb = new InMemoryCollapsedVariationalBayes0(
                        sampledCorpus, terms, numTestTopics, ALPHA, ETA, 2, 1, 0.0d);
                cvb.setVerbose(true);
                trialPerplexities[trial] = cvb.iterateUntilConvergence(0.0d, 5, 0, 0.2d);
                System.out.println(trialPerplexities[trial]);
            }
            // After sorting, index 0 holds the smallest perplexity across the trials.
            Arrays.sort(trialPerplexities);
            System.out.println(Arrays.toString(trialPerplexities));
            perplexities.add(trialPerplexities[0]);
        }
        System.out.println(Joiner.on(",").join(perplexities));
    }

    /**
     * Writes a sampled corpus to a temp directory, runs {@link CVB0Driver} for topic counts
     * 2 .. 4, and asserts that the topic count with the lowest perplexity is 4.
     *
     * @throws Exception if corpus generation, I/O or a driver run fails
     */
    @Test
    public void testRandomStructuredModelViaMR() throws Exception {
        int numGeneratingTopics = 3;
        int numTerms = 9;
        long seed = 1234L;

        // Decay function handed to the model generator: 1 / (d + 1)^3.
        Matrix model = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction() {
            public double apply(double d) {
                return 1.0d / Math.pow(d + 1.0d, 3.0d);
            }
        });
        // NOTE(review): 500, 10, 1 are presumably document count, samples per document and
        // topics per document; confirm against ClusteringTestUtils.sampledCorpus.
        Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(model, RandomUtils.getRandom(seed), 500, 10, 1);

        Path corpusPath = getTestTempDirPath("corpus");
        MatrixUtils.write(corpusPath, getConfiguration(), sampledCorpus);

        List<Double> perplexities = Lists.newArrayList();
        int startTopic = numGeneratingTopics - 1;
        for (int numTestTopics = startTopic; numTestTopics < numGeneratingTopics + 2; numTestTopics++) {
            Path modelStatePath = getTestTempDirPath("topicTemp" + numTestTopics);
            Configuration conf = getConfiguration();
            new CVB0Driver().run(conf, corpusPath, (Path) null, numTestTopics, numTerms, ALPHA, ETA, 5, 1, 0.0d,
                    (Path) null, (Path) null, modelStatePath, seed, 0.2f, 2, 1, 3, 1, false);
            perplexities.add(lowestPerplexity(conf, modelStatePath));
        }

        // Pick the topic count whose run produced the strictly lowest perplexity; the list
        // index is offset by startTopic to recover the topic count. Stays -1 if no run
        // produced a perplexity below Double.MAX_VALUE.
        int bestTopic = -1;
        double lowest = Double.MAX_VALUE;
        for (int t = 0; t < perplexities.size(); t++) {
            double perplexity = perplexities.get(t);
            if (perplexity < lowest) {
                lowest = perplexity;
                bestTopic = t + startTopic;
            }
        }

        // Print before asserting so the values are visible in the log when the test fails.
        System.out.println("Perplexities: " + Joiner.on(", ").join(perplexities));
        // NOTE(review): the expected value is 4 although the model above is generated with
        // 3 topics; confirm that this is intentional given the failure message.
        assertEquals("The optimal number of topics is not that of the generating distribution", 4L, bestTopic);
    }

    /**
     * Returns the smallest perplexity that {@link CVB0Driver#readPerplexity} reports for
     * {@code modelPath}, probing index 2, 3, 4, ... until a NaN is returned.
     *
     * @param conf      configuration passed through to {@code readPerplexity}
     * @param modelPath path passed through to {@code readPerplexity}
     * @return the minimum value read, or {@code Double.MAX_VALUE} if the first read is NaN
     * @throws IOException if reading a perplexity fails
     */
    private static double lowestPerplexity(Configuration conf, Path modelPath) throws IOException {
        double lowest = Double.MAX_VALUE;
        // NOTE(review): the third argument is presumably an iteration number; NaN is treated
        // as "nothing more to read" and is the only way out of this loop.
        for (int iteration = 2; ; iteration++) {
            double perplexity = CVB0Driver.readPerplexity(conf, modelPath, iteration);
            if (Double.isNaN(perplexity)) {
                return lowest;
            }
            lowest = Math.min(perplexity, lowest);
        }
    }
}
