package org.apache.mahout.clustering.kmeans;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.DummyReporter;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.AbstractVector;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/* loaded from: input_file:org/apache/mahout/clustering/kmeans/TestKmeansClustering.class */
/**
 * Tests for k-means clustering: the in-memory {@link KMeansClusterer}, the
 * Hadoop mapper / combiner / reducer exercised directly, and the full
 * {@link KMeansDriver} job seeded either with hand-built clusters or with
 * canopies produced by {@link CanopyDriver}.
 */
public class TestKmeansClustering extends MahoutTestCase {
    /** Nine 2-d reference points used as input by every test. */
    public static final double[][] reference = {new double[]{1.0d, 1.0d}, new double[]{2.0d, 1.0d}, new double[]{1.0d, 2.0d}, new double[]{2.0d, 2.0d}, new double[]{3.0d, 3.0d}, new double[]{4.0d, 4.0d}, new double[]{5.0d, 4.0d}, new double[]{4.0d, 5.0d}, new double[]{5.0d, 5.0d}};

    /**
     * {@code expectedNumPoints[k - 1][c]} is the expected number of points in cluster {@code c}
     * after clustering {@link #reference} with {@code k} initial clusters seeded from the
     * first {@code k} reference points.
     */
    private static final int[][] expectedNumPoints = {new int[]{9}, new int[]{4, 5}, new int[]{4, 4, 1}, new int[]{1, 2, 1, 5}, new int[]{1, 1, 1, 2, 4}, new int[]{1, 1, 1, 1, 1, 4}, new int[]{1, 1, 1, 1, 1, 2, 2}, new int[]{1, 1, 1, 1, 1, 1, 2, 1}, new int[]{1, 1, 1, 1, 1, 1, 1, 1, 1}};

    private FileSystem fs;

    /**
     * Recursively deletes a local file or directory if it exists (like {@code rm -rf}).
     * Best-effort: individual deletion failures are ignored.
     */
    private static void rmr(String path) {
        File file = new File(path);
        if (!file.exists()) {
            return;
        }
        if (file.isDirectory()) {
            // File.list() returns null on an I/O error; treat that as "no children"
            // instead of failing with a NullPointerException.
            String[] children = file.list();
            if (children != null) {
                for (String child : children) {
                    rmr(new File(file, child).getPath());
                }
            }
        }
        file.delete();
    }

    /** Starts every test from a clean local working directory. */
    @Override
    protected void setUp() throws Exception {
        super.setUp();
        rmr("output");
        rmr("testdata");
        fs = FileSystem.get(new Configuration());
    }

    /**
     * Wraps each row of {@code raw} in a {@link VectorWritable} around a
     * {@link RandomAccessSparseVector} named by its row index.
     */
    public static List<VectorWritable> getPointsWritable(double[][] raw) {
        List<VectorWritable> points = new ArrayList<VectorWritable>(raw.length);
        for (int row = 0; row < raw.length; row++) {
            double[] coords = raw[row];
            RandomAccessSparseVector vector = new RandomAccessSparseVector(String.valueOf(row), coords.length);
            vector.assign(coords);
            points.add(new VectorWritable(vector));
        }
        return points;
    }

    /** Converts each row of {@code raw} to a {@link SequentialAccessSparseVector} named by its row index. */
    public static List<Vector> getPoints(double[][] raw) {
        List<Vector> points = new ArrayList<Vector>(raw.length);
        for (int row = 0; row < raw.length; row++) {
            double[] coords = raw[row];
            SequentialAccessSparseVector vector = new SequentialAccessSparseVector(String.valueOf(row), coords.length);
            vector.assign(coords);
            points.add(vector);
        }
        return points;
    }

    /** Job configuration shared by the mapper / combiner / reducer unit tests. */
    private static JobConf newKMeansJobConf() {
        JobConf conf = new JobConf();
        conf.set("org.apache.mahout.clustering.kmeans.measure", "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
        conf.set("org.apache.mahout.clustering.kmeans.convergence", "0.001");
        conf.set("org.apache.mahout.clustering.kmeans.path", "");
        return conf;
    }

    /**
     * Creates {@code k} clusters with ids {@code 0..k-1}, centered on the first {@code k} points.
     *
     * @param observeCenter if true, each cluster's center is also added to it as an observed point
     */
    private static List<Cluster> seedClusters(List<VectorWritable> points, int k, boolean observeCenter) {
        List<Cluster> clusters = new ArrayList<Cluster>(k);
        for (int id = 0; id < k; id++) {
            Cluster cluster = new Cluster(points.get(id).get(), id);
            if (observeCenter) {
                cluster.addPoint(cluster.getCenter());
            }
            clusters.add(cluster);
        }
        return clusters;
    }

    /** Feeds every point through the (already configured) mapper into {@code output}. */
    private static void mapPoints(KMeansMapper mapper, List<VectorWritable> points,
                                  DummyOutputCollector<Text, KMeansInfo> output) throws IOException {
        for (VectorWritable point : points) {
            mapper.map(new Text(), point, output, (Reporter) null);
        }
    }

    /** Runs a fresh {@link KMeansCombiner} once per mapper output key and returns what it emitted. */
    private static DummyOutputCollector<Text, KMeansInfo> combine(DummyOutputCollector<Text, KMeansInfo> mapped)
            throws IOException {
        KMeansCombiner combiner = new KMeansCombiner();
        DummyOutputCollector<Text, KMeansInfo> combined = new DummyOutputCollector<Text, KMeansInfo>();
        for (String key : mapped.getKeys()) {
            combiner.reduce(new Text(key), mapped.getValue(key).iterator(), combined, (Reporter) null);
        }
        return combined;
    }

    /** Writes {@code points} twice, to testdata/points/file1 and testdata/points/file2. */
    private void writeReferencePoints(List<VectorWritable> points, Configuration conf) throws IOException {
        File pointsDir = new File("testdata/points");
        if (!pointsDir.exists()) {
            pointsDir.mkdirs();
        }
        ClusteringTestUtils.writePointsToFile(points, "testdata/points/file1", fs, conf);
        ClusteringTestUtils.writePointsToFile(points, "testdata/points/file2", fs, conf);
    }

    /**
     * Reads a {@code Text}/{@code Text} sequence file and collects each record under its value.
     * The collector's keys are therefore the distinct record values, which the assertions in the
     * driver tests treat as cluster ids. The reader is always closed.
     */
    private static DummyOutputCollector<Text, Text> readClusteredPoints(FileSystem fileSystem, Path path,
                                                                        Configuration conf) throws IOException {
        DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
        SequenceFile.Reader reader = new SequenceFile.Reader(fileSystem, path, conf);
        try {
            Text key = new Text();
            Text value = new Text();
            while (reader.next(key, value)) {
                collector.collect(value, key);
                // The collector keeps references, so fresh instances are needed per record.
                key = new Text();
                value = new Text();
            }
        } finally {
            reader.close();
        }
        return collector;
    }

    /** The in-memory clusterer yields the expected cluster sizes for every k from 1 to 9. */
    public void testReferenceImplementation() throws Exception {
        List<Vector> points = getPoints(reference);
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        for (int k = 1; k <= points.size(); k++) {
            System.out.println("Test k=" + k + ':');
            List<Cluster> clusters = new ArrayList<Cluster>(k);
            for (int c = 0; c < k; c++) {
                clusters.add(new VisibleCluster(points.get(c)));
            }
            List<List<Cluster>> iterations = KMeansClusterer.clusterPoints(points, clusters, measure, 10, 0.001d);
            List<Cluster> finalClusters = iterations.get(iterations.size() - 1);
            for (int c = 0; c < finalClusters.size(); c++) {
                Cluster cluster = finalClusters.get(c);
                System.out.println(cluster.toString());
                assertEquals("Cluster " + c + " test " + k, expectedNumPoints[k - 1][c], cluster.getNumPoints());
            }
        }
    }

    /** Cluster std-dev is positive once more than one point has been added. */
    public void testStd() {
        List<Vector> points = getPoints(reference);
        Cluster cluster = new Cluster(points.get(0));
        for (Vector point : points) {
            cluster.addPoint(point);
            if (cluster.getNumPoints() > 1) {
                assertTrue(cluster.getStd() > 0.0d);
            }
        }
    }

    /** Indexes clusters by {@link Cluster#getIdentifier()}. */
    private static Map<String, Cluster> loadClusterMap(List<Cluster> clusters) {
        Map<String, Cluster> byId = new HashMap<String, Cluster>();
        for (Cluster cluster : clusters) {
            byId.put(cluster.getIdentifier(), cluster);
        }
        return byId;
    }

    /** The mapper emits one key per cluster, and each point goes to its nearest cluster. */
    public void testKMeansMapper() throws Exception {
        KMeansMapper mapper = new KMeansMapper();
        mapper.configure(newKMeansJobConf());
        List<VectorWritable> points = getPointsWritable(reference);
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        for (int k = 1; k <= points.size(); k++) {
            List<Cluster> clusters = seedClusters(points, k, true);
            mapper.config(clusters);
            DummyOutputCollector<Text, KMeansInfo> mapped = new DummyOutputCollector<Text, KMeansInfo>();
            mapPoints(mapper, points, mapped);
            assertEquals("Number of map results", k, mapped.getData().size());

            Map<String, Cluster> clusterMap = loadClusterMap(clusters);
            for (String clusterId : mapped.getKeys()) {
                Cluster assigned = clusterMap.get(clusterId);
                for (KMeansInfo info : mapped.getValue(clusterId)) {
                    double assignedDistance = measure.distance(assigned.getCenter(), info.getPointTotal());
                    // No other cluster center may be strictly closer than the assigned one.
                    for (Cluster other : clusters) {
                        assertTrue("distance error",
                                assignedDistance <= measure.distance(info.getPointTotal(), other.getCenter()));
                    }
                }
            }
        }
    }

    /** The combiner emits one value per cluster and preserves the total point count and coordinate sum. */
    public void testKMeansCombiner() throws Exception {
        KMeansMapper mapper = new KMeansMapper();
        mapper.configure(newKMeansJobConf());
        List<VectorWritable> points = getPointsWritable(reference);
        for (int k = 1; k <= points.size(); k++) {
            mapper.config(seedClusters(points, k, true));
            DummyOutputCollector<Text, KMeansInfo> mapped = new DummyOutputCollector<Text, KMeansInfo>();
            mapPoints(mapper, points, mapped);
            DummyOutputCollector<Text, KMeansInfo> combined = combine(mapped);
            assertEquals("Number of map results", k, combined.getData().size());

            int totalPoints = 0;
            Vector pointSum = new DenseVector(2);
            for (String clusterId : combined.getKeys()) {
                List<KMeansInfo> values = combined.getValue(clusterId);
                assertEquals("too many values", 1, values.size());
                KMeansInfo info = values.get(0);
                totalPoints += info.getPoints();
                pointSum = pointSum.plus(info.getPointTotal());
            }
            // 9 reference points; both their x and their y coordinates sum to 27.
            assertEquals("total points", 9, totalPoints);
            assertEquals("point total[0]", 27, (int) pointSum.get(0));
            assertEquals("point total[1]", 27, (int) pointSum.get(1));
        }
    }

    /**
     * One mapper, combiner, and reducer pass must match one iteration of the in-memory reference
     * clusterer: same centers, and the same convergence outcome.
     */
    public void testKMeansReducer() throws Exception {
        KMeansMapper mapper = new KMeansMapper();
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        JobConf conf = newKMeansJobConf();
        mapper.configure(conf);
        List<VectorWritable> points = getPointsWritable(reference);
        List<Vector> rawPoints = new ArrayList<Vector>(points.size());
        for (VectorWritable point : points) {
            rawPoints.add(point.get());
        }
        for (int i = 0; i < points.size(); i++) {
            int k = i + 1;
            System.out.println("K = " + i);
            List<Cluster> clusters = seedClusters(points, k, false);
            mapper.config(clusters);
            DummyOutputCollector<Text, KMeansInfo> mapped = new DummyOutputCollector<Text, KMeansInfo>();
            mapPoints(mapper, points, mapped);
            DummyOutputCollector<Text, KMeansInfo> combined = combine(mapped);

            KMeansReducer reducer = new KMeansReducer();
            reducer.configure(conf);
            reducer.config(clusters);
            DummyOutputCollector<Text, Cluster> reduced = new DummyOutputCollector<Text, Cluster>();
            for (String clusterId : combined.getKeys()) {
                reducer.reduce(new Text(clusterId), combined.getValue(clusterId).iterator(), reduced, new DummyReporter());
            }
            assertEquals("Number of map results", k, reduced.getData().size());

            // Same iteration with the in-memory clusterer on an independent set of seeds.
            List<Cluster> referenceClusters = seedClusters(points, k, false);
            boolean referenceConverged = KMeansClusterer.runKMeansIteration(rawPoints, referenceClusters, measure, 0.001d);
            // Only when every point seeds its own cluster (k == 9) should one iteration converge.
            if (i == 8) {
                assertTrue("not converged? " + i, referenceConverged);
            } else {
                assertFalse("converged? " + i, referenceConverged);
            }

            boolean allConverged = true;
            for (int c = 0; c < referenceClusters.size(); c++) {
                Cluster expected = referenceClusters.get(c);
                Cluster actual = reduced.getValue(expected.getIdentifier()).get(0);
                allConverged = allConverged && actual.isConverged();
                actual.recomputeCenter();
                assertTrue(c + " reference center: " + expected.getCenter().asFormatString() + " and cluster center:  "
                        + actual.getCenter().asFormatString() + " are not equal",
                        AbstractVector.equivalent(expected.getCenter(), actual.getCenter()));
            }
            if (i == 8) {
                assertTrue("not converged? " + i, allConverged);
            } else {
                assertFalse("converged? " + i, allConverged);
            }
        }
    }

    /** The full k-means job yields the expected number of clusters for k = 2..9. */
    public void testKMeansMRJob() throws Exception {
        List<VectorWritable> points = getPointsWritable(reference);
        Configuration conf = new Configuration();
        writeReferencePoints(points, conf);
        for (int i = 1; i < points.size(); i++) {
            int k = i + 1;
            System.out.println("testKMeansMRJob k= " + i);
            JobConf job = new JobConf(KMeansDriver.class);
            Path clustersPath = new Path("testdata/clusters/part-00000");
            FileSystem clusterFs = FileSystem.get(clustersPath.toUri(), job);
            SequenceFile.Writer writer = new SequenceFile.Writer(clusterFs, job, clustersPath, Text.class, Cluster.class);
            try {
                for (Cluster cluster : seedClusters(points, k, true)) {
                    writer.append(new Text(cluster.getIdentifier()), cluster);
                }
            } finally {
                writer.close();
            }

            HadoopUtil.overwriteOutput("output");
            KMeansDriver.runJob("testdata/points", "testdata/clusters", "output",
                    EuclideanDistanceMeasure.class.getName(), 0.001d, 10, k);
            assertTrue("output dir exists?", new File("output/points").exists());

            DummyOutputCollector<Text, Text> collector =
                    readClusteredPoints(clusterFs, new Path("output/points/part-00000"), conf);
            int[] expected = expectedNumPoints[i];
            // NOTE(review): for k = 3 the job output is expected to have one fewer distinct
            // cluster than the reference table lists; the reason is not visible here — TODO confirm.
            int expectedClusters = i == 2 ? expected.length - 1 : expected.length;
            assertEquals("clusters[" + i + ']', expectedClusters, collector.getKeys().size());
        }
    }

    /** The k-means job accepts canopy output as its initial clusters. */
    public void testKMeansWithCanopyClusterInput() throws Exception {
        List<VectorWritable> points = getPointsWritable(reference);
        Configuration conf = new Configuration();
        writeReferencePoints(points, conf);
        CanopyDriver.runJob("testdata/points", "testdata/canopies", ManhattanDistanceMeasure.class.getName(), 3.1d, 2.1d);
        KMeansDriver.runJob("testdata/points", "testdata/canopies", "output",
                EuclideanDistanceMeasure.class.getName(), 0.001d, 10, 1);

        File outputPoints = new File("output/points");
        assertTrue("output dir exists?", outputPoints.exists());
        assertEquals("output dir files?", 4, outputPoints.list().length);

        DummyOutputCollector<Text, Text> collector =
                readClusteredPoints(fs, new Path("output/points/part-00000"), conf);
        assertEquals("num points[0]", 4, collector.getValue("0").size());
        assertEquals("num points[1]", 5, collector.getValue("1").size());
    }
}
