package org.apache.mahout.clustering.iterator;

import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.Canopy;
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.clustering.kmeans.Kluster;
import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/iterator/TestClusterClassifier.class */
/**
 * Tests for {@link ClusterClassifier}: classification with several cluster model types,
 * round trips through sequence files, and the in-memory, sequential-file and MapReduce
 * entry points of {@link ClusterIterator}.
 */
public final class TestClusterClassifier extends MahoutTestCase {

    /** Iteration limit handed to every {@link ClusterIterator} entry point used in this class. */
    private static final int MAX_ITERATIONS = 5;

    /**
     * Iteration whose output directory the file-based tests expect to carry the "-final" suffix.
     * NOTE(review): this is lower than {@link #MAX_ITERATIONS}; presumably the run converges
     * early on the reference data - confirm against ClusterIterator.
     */
    private static final int FINAL_ITERATION = 4;

    /** Number of models in every classifier built by this class. */
    private static final long MODEL_COUNT = 3L;

    /**
     * Builds a k-means-policy classifier over three {@link DistanceMeasureCluster} models
     * (ids 0, 1, 2) whose 2-element centers are filled with 1, left at their initial value,
     * and filled with -1, all sharing one Manhattan distance measure.
     */
    private static ClusterClassifier newDMClassifier() {
        List<Cluster> models = Lists.newArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
        models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models, new KMeansClusteringPolicy());
    }

    /**
     * Same layout as {@link #newDMClassifier()} but with {@link Kluster} models.
     */
    private static ClusterClassifier newKlusterClassifier() {
        List<Cluster> models = Lists.newArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new Kluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new Kluster(new DenseVector(2), 1, measure));
        models.add(new Kluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models, new KMeansClusteringPolicy());
    }

    /**
     * Same layout as {@link #newKlusterClassifier()} but measuring with cosine distance.
     */
    private static ClusterClassifier newCosineKlusterClassifier() {
        List<Cluster> models = Lists.newArrayList();
        CosineDistanceMeasure measure = new CosineDistanceMeasure();
        models.add(new Kluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new Kluster(new DenseVector(2), 1, measure));
        models.add(new Kluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models, new KMeansClusteringPolicy());
    }

    /**
     * Same layout as {@link #newDMClassifier()} but with {@link SoftCluster} models and a
     * {@link FuzzyKMeansClusteringPolicy}.
     */
    private static ClusterClassifier newSoftClusterClassifier() {
        List<Cluster> models = Lists.newArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new SoftCluster(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new SoftCluster(new DenseVector(2), 1, measure));
        models.add(new SoftCluster(new DenseVector(2).assign(-1.0d), 2, measure));
        return new ClusterClassifier(models, new FuzzyKMeansClusteringPolicy());
    }

    /**
     * Writes the classifier to an "output" directory under the test temp dir and loads it
     * back into a fresh instance.
     *
     * @param clusterClassifier the classifier to persist
     * @return a new classifier populated from the files just written
     * @throws IOException if writing or reading the sequence files fails
     */
    private ClusterClassifier writeAndRead(ClusterClassifier clusterClassifier) throws IOException {
        Path target = new Path(getTestTempDirPath(), "output");
        clusterClassifier.writeToSeqFiles(target);
        ClusterClassifier restored = new ClusterClassifier();
        restored.readFromSeqFiles(getConfiguration(), target);
        return restored;
    }

    /** Prints the formatted description of every model held by the classifier to stdout. */
    private static void printModels(ClusterClassifier classifier) {
        for (Cluster model : classifier.getModels()) {
            System.out.println(model.asFormatString((String[]) null));
        }
    }

    /**
     * Classifies a fresh 2-element vector and a 2-element vector filled with 2, and compares
     * the formatted results with the expected strings.
     *
     * @param classifier the classifier under test
     * @param expectedAtOrigin expected formatted result for the point labelled "[0,0]"
     * @param expectedAtTwoTwo expected formatted result for the point labelled "[2,2]"
     */
    private void assertClassification(ClusterClassifier classifier, String expectedAtOrigin, String expectedAtTwoTwo) {
        assertEquals("[0,0]", expectedAtOrigin,
            AbstractCluster.formatVector(classifier.classify(new DenseVector(2)), (String[]) null));
        assertEquals("[2,2]", expectedAtTwoTwo,
            AbstractCluster.formatVector(classifier.classify(new DenseVector(2).assign(2.0d)), (String[]) null));
    }

    /**
     * Round-trips the classifier through {@link #writeAndRead(ClusterClassifier)} and checks
     * that the model count and the class of the first model are unchanged.
     */
    private void assertSurvivesRoundTrip(ClusterClassifier original) throws IOException {
        ClusterClassifier restored = writeAndRead(original);
        assertEquals(original.getModels().size(), restored.getModels().size());
        assertEquals(original.getModels().get(0).getClass().getName(),
            restored.getModels().get(0).getClass().getName());
    }

    /**
     * Runs the in-memory {@link ClusterIterator#iterate} over the k-means reference points with
     * a Kluster prior, checks that three models come back, and prints them.
     */
    private void runInMemoryIteration() {
        ClusterClassifier posterior = ClusterIterator.iterate(
            TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE), newKlusterClassifier(), MAX_ITERATIONS);
        assertEquals(MODEL_COUNT, posterior.getModels().size());
        printModels(posterior);
    }

    /** Writes the k-means reference points to "file1" inside the given directory. */
    private static void writeReferencePoints(Path pointsDir, Configuration conf) throws IOException {
        ClusteringTestUtils.writePointsToFile(
            TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE),
            new Path(pointsDir, "file1"),
            FileSystem.get(pointsDir.toUri(), conf),
            conf);
    }

    /** Checks the prior's model count and prints it under a "Prior" heading. */
    private void assertAndPrintPrior(ClusterClassifier prior) {
        assertEquals(MODEL_COUNT, prior.getModels().size());
        System.out.println("Prior");
        printModels(prior);
    }

    /**
     * Loads the classifier written for each iteration 1..{@link #FINAL_ITERATION} from
     * "clusters-N" (with a "-final" suffix on the last one), checks that each holds three
     * models, and prints them.
     *
     * @param conf configuration used to read the sequence files
     * @param outputDir directory the iterator wrote its per-iteration output to
     * @throws IOException if a classifier cannot be read
     */
    private void assertIterationOutputs(Configuration conf, Path outputDir) throws IOException {
        for (int iteration = 1; iteration <= FINAL_ITERATION; iteration++) {
            System.out.println("Classifier-" + iteration);
            String suffix = iteration == FINAL_ITERATION ? "-final" : "";
            ClusterClassifier posterior = new ClusterClassifier();
            posterior.readFromSeqFiles(conf, new Path(outputDir, "clusters-" + iteration + suffix));
            assertEquals(MODEL_COUNT, posterior.getModels().size());
            printModels(posterior);
        }
    }

    /** Classification with {@link DistanceMeasureCluster} models under the k-means policy. */
    @Test
    public void testDMClusterClassification() {
        assertClassification(newDMClassifier(), "[0.200, 0.600, 0.200]", "[0.493, 0.296, 0.211]");
    }

    /** Classification with {@link Canopy} models under the canopy policy. */
    @Test
    public void testCanopyClassification() {
        List<Cluster> models = Lists.newArrayList();
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        models.add(new Canopy(new DenseVector(2).assign(1.0d), 0, measure));
        models.add(new Canopy(new DenseVector(2), 1, measure));
        models.add(new Canopy(new DenseVector(2).assign(-1.0d), 2, measure));
        ClusterClassifier classifier = new ClusterClassifier(models, new CanopyClusteringPolicy());
        assertClassification(classifier, "[0.200, 0.600, 0.200]", "[0.493, 0.296, 0.211]");
    }

    /** Classification with {@link Kluster} models under the k-means policy. */
    @Test
    public void testClusterClassification() {
        assertClassification(newKlusterClassifier(), "[0.200, 0.600, 0.200]", "[0.493, 0.296, 0.211]");
    }

    /** Classification with {@link SoftCluster} models under the fuzzy k-means policy. */
    @Test
    public void testSoftClusterClassification() {
        assertClassification(newSoftClusterClassifier(), "[0.000, 1.000, 0.000]", "[0.735, 0.184, 0.082]");
    }

    /** Sequence-file round trip of a {@link DistanceMeasureCluster} classifier. */
    @Test
    public void testDMClassifierSerialization() throws Exception {
        assertSurvivesRoundTrip(newDMClassifier());
    }

    /** Sequence-file round trip of a {@link Kluster} classifier. */
    @Test
    public void testClusterClassifierSerialization() throws Exception {
        assertSurvivesRoundTrip(newKlusterClassifier());
    }

    /** Sequence-file round trip of a {@link SoftCluster} classifier. */
    @Test
    public void testSoftClusterClassifierSerialization() throws Exception {
        assertSurvivesRoundTrip(newSoftClusterClassifier());
    }

    /** In-memory iteration starting from a Kluster/k-means prior. */
    @Test
    public void testClusterIteratorKMeans() {
        runInMemoryIteration();
    }

    /**
     * NOTE(review): despite the name, this runs exactly the same Kluster/k-means scenario as
     * {@link #testClusterIteratorKMeans()}; no Dirichlet model or policy is involved. Kept
     * as-is so the set of executed tests does not change.
     */
    @Test
    public void testClusterIteratorDirichlet() {
        runInMemoryIteration();
    }

    /**
     * Drives {@link ClusterIterator#iterateSeq} over the reference points written to a
     * sequence file, then verifies the classifier written for each iteration.
     */
    @Test
    public void testSeqFileClusterIteratorKMeans() throws IOException {
        Path pointsDir = getTestTempDirPath("points");
        Path priorDir = getTestTempDirPath("prior");
        Path outputDir = getTestTempDirPath("output");
        Configuration conf = getConfiguration();
        writeReferencePoints(pointsDir, conf);

        Path priorPath = new Path(priorDir, "priorClassifier");
        ClusterClassifier prior = newKlusterClassifier();
        prior.writeToSeqFiles(priorPath);
        assertAndPrintPrior(prior);

        ClusterIterator.iterateSeq(conf, pointsDir, priorPath, outputDir, MAX_ITERATIONS);
        assertIterationOutputs(conf, outputDir);
    }

    /**
     * Same scenario as {@link #testSeqFileClusterIteratorKMeans()} but through
     * {@link ClusterIterator#iterateMR}; additionally writes the k-means policy next to the
     * prior before iterating.
     */
    @Test
    public void testMRFileClusterIteratorKMeans() throws Exception {
        Path pointsDir = getTestTempDirPath("points");
        Path priorDir = getTestTempDirPath("prior");
        Path outputDir = getTestTempDirPath("output");
        Configuration conf = getConfiguration();
        writeReferencePoints(pointsDir, conf);

        Path priorPath = new Path(priorDir, "priorClassifier");
        ClusterClassifier prior = newKlusterClassifier();
        prior.writeToSeqFiles(priorPath);
        ClusterClassifier.writePolicy(new KMeansClusteringPolicy(), priorPath);
        assertAndPrintPrior(prior);

        ClusterIterator.iterateMR(conf, pointsDir, priorPath, outputDir, MAX_ITERATIONS);
        assertIterationOutputs(conf, outputDir);
    }

    /** Classification with {@link Kluster} models measured by cosine distance. */
    @Test
    public void testCosineKlusterClassification() {
        assertClassification(newCosineKlusterClassifier(), "[0.333, 0.333, 0.333]", "[0.429, 0.429, 0.143]");
    }
}
