package org.apache.mahout.clustering;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.lang.NotImplementedException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.clustering.canopy.Canopy;
import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/TestClusterClassifier.class */
/**
 * Tests for {@link ClusterClassifier}: classifying points against several cluster model types,
 * round-tripping a classifier through a {@link SequenceFile}, and driving {@link ClusterIterator}
 * with k-means and Dirichlet clustering policies.
 */
public final class TestClusterClassifier extends MahoutTestCase {

    /** Builds a classifier over three DistanceMeasureClusters centered at [1,1], [0,0] and [-1,-1]. */
    private static ClusterClassifier newDMClassifier() {
        ArrayList models = new ArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
        models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models);
    }

    /** Builds a classifier over three k-means Clusters centered at [1,1], [0,0] and [-1,-1]. */
    private static ClusterClassifier newClusterClassifier() {
        ArrayList models = new ArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new Cluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new Cluster(new DenseVector(2), 1, measure));
        models.add(new Cluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models);
    }

    /** Builds a classifier over three SoftClusters centered at [1,1], [0,0] and [-1,-1]. */
    private static ClusterClassifier newSoftClusterClassifier() {
        ArrayList models = new ArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new SoftCluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new SoftCluster(new DenseVector(2), 1, measure));
        models.add(new SoftCluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models);
    }

    /**
     * Builds a classifier over three GaussianClusters centered at [1,1], [0,0] and [-1,-1],
     * each constructed with a second vector argument of all ones.
     */
    private static ClusterClassifier newGaussianClassifier() {
        ArrayList models = new ArrayList();
        models.add(new GaussianCluster(new DenseVector(2).assign(1.0d), new DenseVector(2).assign(1.0d), 0));
        models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2).assign(1.0d), 1));
        models.add(new GaussianCluster(new DenseVector(2).assign(-1.0d), new DenseVector(2).assign(1.0d), 2));
        return new ClusterClassifier(models);
    }

    /**
     * Writes {@code classifier} to a SequenceFile under the test temp dir and reads it back.
     *
     * @param classifier the classifier to serialize
     * @return the deserialized copy
     * @throws IOException if writing or reading the file fails
     */
    private ClusterClassifier writeAndRead(ClusterClassifier classifier) throws IOException {
        Configuration conf = new Configuration();
        Path path = new Path(getTestTempDirPath(), "output");
        FileSystem fs = FileSystem.get(path.toUri(), conf);
        writeClassifier(classifier, conf, path, fs);
        return readClassifier(conf, path, fs);
    }

    /**
     * Writes {@code classifier} as the single value (key {@code "test"}) of a SequenceFile at {@code path}.
     * The writer is closed even if the append fails.
     */
    private static void writeClassifier(ClusterClassifier classifier, Configuration conf, Path path, FileSystem fs) throws IOException {
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, ClusterClassifier.class);
        try {
            writer.append(new Text("test"), classifier);
        } finally {
            writer.close();
        }
    }

    /**
     * Reads the first record of the SequenceFile at {@code path} into a new ClusterClassifier.
     * Fails the test if the file holds no record, instead of silently returning an empty classifier.
     * The reader is closed even if reading fails.
     */
    private static ClusterClassifier readClassifier(Configuration conf, Path path, FileSystem fs) throws IOException {
        SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
        try {
            Text key = new Text();
            ClusterClassifier classifier = new ClusterClassifier();
            if (!reader.next(key, classifier)) {
                fail("No classifier record found in " + path);
            }
            return classifier;
        } finally {
            reader.close();
        }
    }

    /**
     * Asserts the formatted classification vectors for the points [0,0] and [2,2].
     *
     * @param classifier       the classifier under test
     * @param expectedAtOrigin expected formatted vector for [0,0]
     * @param expectedAtTwos   expected formatted vector for [2,2]
     */
    private static void assertClassification(ClusterClassifier classifier, String expectedAtOrigin, String expectedAtTwos) {
        assertEquals("[0,0]", expectedAtOrigin,
                AbstractCluster.formatVector(classifier.classify(new DenseVector(2)), (String[]) null));
        assertEquals("[2,2]", expectedAtTwos,
                AbstractCluster.formatVector(classifier.classify(new DenseVector(2).assign(2.0d)), (String[]) null));
    }

    /**
     * Round-trips {@code original} through a SequenceFile. Asserts that the model count and the
     * runtime class of the first model survive. The class is read via getClass() on the element
     * itself, so no cast is needed. A cast to the imported k-means Cluster would throw for
     * DistanceMeasureCluster and GaussianCluster models.
     */
    private void assertSerializationRoundTrip(ClusterClassifier original) throws IOException {
        ClusterClassifier copy = writeAndRead(original);
        assertEquals(original.getModels().size(), copy.getModels().size());
        assertEquals(original.getModels().get(0).getClass().getName(), copy.getModels().get(0).getClass().getName());
    }

    /**
     * Prints each model of {@code classifier} in its format-string form.
     * NOTE(review): the cast targets the imported k-means {@code Cluster}. Every classifier printed
     * here is seeded from {@link #newClusterClassifier()}, so this assumes the models are still
     * k-means clusters after iteration. Confirm this before reusing the helper with other priors.
     */
    private static void printModels(ClusterClassifier classifier) {
        for (Object model : classifier.getModels()) {
            System.out.println(((Cluster) model).asFormatString((String[]) null));
        }
    }

    @Test
    public void testDMClusterClassification() {
        assertClassification(newDMClassifier(), "[0.107, 0.787, 0.107]", "[0.867, 0.117, 0.016]");
    }

    @Test
    public void testCanopyClassification() {
        ArrayList models = new ArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new Canopy(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new Canopy(new DenseVector(2), 1, measure));
        models.add(new Canopy(new DenseVector(2).assign(-1.0d), 2, measure));
        assertClassification(new ClusterClassifier(models), "[0.107, 0.787, 0.107]", "[0.867, 0.117, 0.016]");
    }

    @Test
    public void testClusterClassification() {
        assertClassification(newClusterClassifier(), "[0.107, 0.787, 0.107]", "[0.867, 0.117, 0.016]");
    }

    @Test
    public void testMSCanopyClassification() {
        ArrayList models = new ArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new MeanShiftCanopy(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new MeanShiftCanopy(new DenseVector(2), 1, measure));
        models.add(new MeanShiftCanopy(new DenseVector(2).assign(-1.0d), 2, measure));
        try {
            new ClusterClassifier(models).classify(new DenseVector(2));
            fail("Expected NotImplementedException");
        } catch (NotImplementedException expected) {
            // Expected: classification is not supported for MeanShiftCanopy models.
        }
    }

    @Test
    public void testSoftClusterClassification() {
        assertClassification(newSoftClusterClassifier(), "[0.000, 1.000, 0.000]", "[0.735, 0.184, 0.082]");
    }

    @Test
    public void testGaussianClusterClassification() {
        assertClassification(newGaussianClassifier(), "[0.212, 0.576, 0.212]", "[0.952, 0.047, 0.000]");
    }

    @Test
    public void testDMClassifierSerialization() throws Exception {
        assertSerializationRoundTrip(newDMClassifier());
    }

    @Test
    public void testClusterClassifierSerialization() throws Exception {
        assertSerializationRoundTrip(newClusterClassifier());
    }

    @Test
    public void testSoftClusterClassifierSerialization() throws Exception {
        assertSerializationRoundTrip(newSoftClusterClassifier());
    }

    @Test
    public void testGaussianClassifierSerialization() throws Exception {
        assertSerializationRoundTrip(newGaussianClassifier());
    }

    @Test
    public void testClusterIteratorKMeans() {
        ClusterClassifier posterior = new ClusterIterator(new KMeansClusteringPolicy())
                .iterate(TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE), newClusterClassifier(), 5);
        assertEquals(3L, posterior.getModels().size());
        printModels(posterior);
    }

    @Test
    public void testClusterIteratorDirichlet() {
        ClusterClassifier posterior = new ClusterIterator(new DirichletClusteringPolicy(3, 1.0d))
                .iterate(TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE), newClusterClassifier(), 5);
        assertEquals(3L, posterior.getModels().size());
        printModels(posterior);
    }

    @Test
    public void testSeqFileClusterIteratorKMeans() throws IOException {
        Path pointsPath = getTestTempDirPath("points");
        Path priorPath = getTestTempDirPath("prior");
        Path outputPath = getTestTempDirPath("output");
        Configuration conf = new Configuration();
        // Resolve the FileSystem from the path URI, as writeAndRead does.
        FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
        ClusteringTestUtils.writePointsToFile(TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE),
                new Path(pointsPath, "file1"), fs, conf);

        Path priorClassifierPath = new Path(priorPath, "priorClassifier");
        ClusterClassifier prior = newClusterClassifier();
        writeClassifier(prior, conf, priorClassifierPath, fs);
        assertEquals(3L, prior.getModels().size());
        System.out.println("Prior");
        printModels(prior);

        new ClusterIterator(new KMeansClusteringPolicy()).iterate(pointsPath, priorClassifierPath, outputPath, 5);
        // The iterator is expected to emit one classifier file per iteration: classifier-1 .. classifier-5.
        for (int i = 1; i <= 5; i++) {
            System.out.println("Classifier-" + i);
            ClusterClassifier posterior = readClassifier(conf, new Path(outputPath, "classifier-" + i), fs);
            assertEquals(3L, posterior.getModels().size());
            printModels(posterior);
        }
    }
}
