package org.apache.mahout.clustering.dirichlet;

import com.google.common.collect.Lists;
import java.util.List;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

/**
 * Smoke tests for {@link DirichletClusterer}. Each test builds a small synthetic data set of
 * 2-dimensional vectors, runs the clusterer, checks that a result is produced and prints the
 * clusters of every returned sample whose count exceeds {@link #SIGNIFICANT_COUNT}.
 */
public final class TestDirichletClustering extends MahoutTestCase {

    /** Number of iterations passed to {@code DirichletClusterer.cluster(int)}. */
    private static final int NUM_ITERATIONS = 30;

    /** Only clusters whose {@code count()} is strictly greater than this are printed. */
    private static final int SIGNIFICANT_COUNT = 2;

    /** Synthetic input points; reset before every test. */
    private List<VectorWritable> sampleData;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        sampleData = Lists.newArrayList();
    }

    /**
     * Appends {@code num} normally distributed vectors of dimension {@code card} to
     * {@link #sampleData}.
     *
     * <p>Coordinate 1 is drawn around {@code my}; every other coordinate (including 0) is drawn
     * around {@code mx}. For the 2-dimensional case this yields points centred on
     * {@code [mx, my]}, which is what the log line reports.
     *
     * @param num  number of vectors to generate
     * @param mx   mean of coordinate 0 (and of any coordinate beyond 1)
     * @param my   mean of coordinate 1
     * @param sd   standard deviation used for every coordinate
     * @param card dimension of each generated vector
     */
    private void generateSamples(int num, double mx, double my, double sd, int card) {
        System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=" + sd);
        for (int i = 0; i < num; i++) {
            DenseVector v = new DenseVector(card);
            for (int j = 0; j < card; j++) {
                // Previously every coordinate used mx, silently ignoring my.
                double mean = j == 1 ? my : mx;
                v.set(j, UncommonDistributions.rNorm(mean, sd));
            }
            sampleData.add(new VectorWritable(v));
        }
    }

    /**
     * Appends {@code num} 2-dimensional normally distributed vectors centred on
     * {@code [mx, my]} with standard deviation {@code sd}.
     */
    private void generateSamples(int num, double mx, double my, double sd) {
        generateSamples(num, mx, my, sd, 2);
    }

    /**
     * Populates {@link #sampleData} with 100 points: one broad cluster (40 points, mean [1,1],
     * sd 3) and two tight clusters (30 points each, means [1,0] and [0,1], sd 0.1).
     */
    private void generateThreeClusterSamples() {
        generateSamples(40, 1.0, 1.0, 3.0);
        generateSamples(30, 1.0, 0.0, 0.1);
        generateSamples(30, 0.0, 1.0, 0.1);
    }

    /**
     * Prints, for each sample in {@code result}, the clusters whose {@code count()} is strictly
     * greater than {@code significant}, followed by a trailing blank line.
     *
     * @param result      per-sample cluster arrays, in iteration order
     * @param significant minimum count (exclusive) for a cluster to be printed
     */
    private static void printResults(Iterable<Cluster[]> result, int significant) {
        int sampleIndex = 0;
        for (Cluster[] clusters : result) {
            System.out.print("sample[" + sampleIndex++ + "]= ");
            for (Cluster cluster : clusters) {
                if (cluster.count() > significant) {
                    System.out.print(cluster.asFormatString((String[]) null) + ", ");
                }
            }
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Clusters the three-cluster data set with a Gaussian model distribution.
     *
     * <p>NOTE(review): this body is identical to {@link #testDirichletGaussianCluster100()};
     * presumably it was meant to exercise a different model distribution — confirm.
     */
    @Test
    public void testDirichletCluster100() {
        System.out.println("testDirichletCluster100");
        generateThreeClusterSamples();
        // NOTE(review): the trailing constructor args (1.0, 10, 1, 0) are presumably
        // alpha0, numClusters, thin and burnin — confirm against DirichletClusterer.
        List<Cluster[]> result = new DirichletClusterer(sampleData,
                new GaussianClusterDistribution(new VectorWritable(new DenseVector(2))),
                1.0, 10, 1, 0).cluster(NUM_ITERATIONS);
        // Assert before printing so a null result fails here rather than as an NPE in printResults.
        assertNotNull(result);
        printResults(result, SIGNIFICANT_COUNT);
    }

    /** Clusters the three-cluster data set with a Gaussian model distribution. */
    @Test
    public void testDirichletGaussianCluster100() {
        System.out.println("testDirichletGaussianCluster100");
        generateThreeClusterSamples();
        List<Cluster[]> result = new DirichletClusterer(sampleData,
                new GaussianClusterDistribution(new VectorWritable(new DenseVector(2))),
                1.0, 10, 1, 0).cluster(NUM_ITERATIONS);
        assertNotNull(result);
        printResults(result, SIGNIFICANT_COUNT);
    }

    /** Clusters the three-cluster data set with a distance-measure model distribution. */
    @Test
    public void testDirichletDMCluster100() {
        System.out.println("testDirichletDMCluster100");
        generateThreeClusterSamples();
        List<Cluster[]> result = new DirichletClusterer(sampleData,
                new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2))),
                1.0, 10, 1, 0).cluster(NUM_ITERATIONS);
        assertNotNull(result);
        printResults(result, SIGNIFICANT_COUNT);
    }
}
