package org.apache.mahout.clustering.kmeans;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Closeables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.DummyRecordWriter;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/kmeans/TestKmeansClustering.class */
public final class TestKmeansClustering extends MahoutTestCase {
    /** Nine two-dimensional sample points shared by the tests in this class. */
    public static final double[][] REFERENCE = {new double[]{1.0d, 1.0d}, new double[]{2.0d, 1.0d}, new double[]{1.0d, 2.0d}, new double[]{2.0d, 2.0d}, new double[]{3.0d, 3.0d}, new double[]{4.0d, 4.0d}, new double[]{5.0d, 4.0d}, new double[]{4.0d, 5.0d}, new double[]{5.0d, 5.0d}};
    /**
     * Expected cluster sizes. Row {@code i} lists, per cluster, the point count asserted by
     * {@code testReferenceImplementation} when clustering starts from {@code i + 1} seed clusters;
     * the job tests only compare the row length against the number of output clusters.
     */
    private static final int[][] EXPECTED_NUM_POINTS = {new int[]{9}, new int[]{4, 5}, new int[]{4, 4, 1}, new int[]{1, 2, 1, 5}, new int[]{1, 1, 1, 2, 4}, new int[]{1, 1, 1, 1, 1, 4}, new int[]{1, 1, 1, 1, 1, 2, 2}, new int[]{1, 1, 1, 1, 1, 1, 2, 1}, new int[]{1, 1, 1, 1, 1, 1, 1, 1, 1}};
    /** File system obtained in {@code setUp()}; passed to the helper that writes test points to files. */
    private FileSystem fs;

    /**
     * Runs the base-class set-up and then obtains the {@link FileSystem} for a fresh
     * default {@link Configuration}, storing it in {@code fs}.
     */
    @Override // org.apache.mahout.common.MahoutTestCase
    @Before
    public void setUp() throws Exception {
        super.setUp();
        Configuration defaultConf = new Configuration();
        this.fs = FileSystem.get(defaultConf);
    }

    /**
     * Converts each row of {@code dArr} into a {@link VectorWritable} that wraps a
     * {@link RandomAccessSparseVector} holding the row's values.
     *
     * @param dArr the raw points, one row per point
     * @return a new list with one writable per input row, in input order
     */
    public static List<VectorWritable> getPointsWritable(double[][] dArr) {
        List<VectorWritable> writables = new ArrayList<VectorWritable>();
        for (int row = 0; row < dArr.length; row++) {
            double[] coordinates = dArr[row];
            RandomAccessSparseVector vector = new RandomAccessSparseVector(coordinates.length);
            vector.assign(coordinates);
            writables.add(new VectorWritable(vector));
        }
        return writables;
    }

    /**
     * Converts each row of {@code dArr} into a {@link SequentialAccessSparseVector}
     * holding the row's values.
     *
     * @param dArr the raw points, one row per point
     * @return a new list with one vector per input row, in input order
     */
    public static List<Vector> getPoints(double[][] dArr) {
        List<Vector> vectors = new ArrayList<Vector>();
        for (int row = 0; row < dArr.length; row++) {
            double[] coordinates = dArr[row];
            SequentialAccessSparseVector vector = new SequentialAccessSparseVector(coordinates.length);
            vector.assign(coordinates);
            vectors.add(vector);
        }
        return vectors;
    }

    /**
     * Runs a single k-means iteration over four points on the y axis, seeded with the first
     * and last point, using a Manhattan distance and a convergence threshold of 0.25.
     * Asserts that the two centers end up at (0, 0.125) and (0, 0.875) and that the
     * iteration reports convergence.
     */
    @Test
    public void testRunKMeansIterationConvergesInOneRunWithGivenDistanceThreshold() {
        // The literal must be a double[][]; a double[] initializer cannot contain nested arrays.
        double[][] rawPoints = new double[][] {
            {0.0d, 0.0d},
            {0.0d, 0.25d},
            {0.0d, 0.75d},
            {0.0d, 1.0d},
        };
        List<Vector> points = getPoints(rawPoints);
        ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
        List<Cluster> clusters = Arrays.asList(
            new Cluster(points.get(0), 0, measure),
            new Cluster(points.get(3), 3, measure));

        boolean converged = KMeansClusterer.runKMeansIteration(points, clusters, measure, 0.25d);

        Vector firstCenter = clusters.get(0).getCenter();
        assertEquals(0.0d, firstCenter.get(0), 1.0E-6d);
        assertEquals(0.125d, firstCenter.get(1), 1.0E-6d);
        Vector secondCenter = clusters.get(1).getCenter();
        assertEquals(0.0d, secondCenter.get(0), 1.0E-6d);
        assertEquals(0.875d, secondCenter.get(1), 1.0E-6d);
        assertTrue("KMeans iteration should be converged after a single run", converged);
    }

    /**
     * Exercises the in-memory clusterer on {@code REFERENCE} for every k from 1 to the number
     * of points, seeding with the first k points, and checks the point count of each cluster
     * in the last returned cluster list against {@code EXPECTED_NUM_POINTS}.
     */
    @Test
    public void testReferenceImplementation() throws Exception {
        List<Vector> points = getPoints(REFERENCE);
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        for (int k = 1; k <= points.size(); k++) {
            System.out.println("Test k=" + k + ':');
            // Seed clusters are centered on the first k points.
            List<Cluster> seeds = Lists.newArrayList();
            for (int id = 0; id < k; id++) {
                seeds.add(new Cluster(points.get(id), id, measure));
            }
            List history = KMeansClusterer.clusterPoints(points, seeds, measure, 10, 0.001d);
            // Only the last entry of the returned list is inspected.
            List finalClusters = (List) history.get(history.size() - 1);
            int index = 0;
            for (Object element : finalClusters) {
                AbstractCluster cluster = (AbstractCluster) element;
                System.out.println(cluster.asFormatString((String[]) null));
                assertEquals("Cluster " + index + " test " + k, EXPECTED_NUM_POINTS[k - 1][index], cluster.getNumPoints());
                index++;
            }
        }
    }

    /**
     * Indexes the given clusters by their identifier.
     *
     * @param iterable the clusters to index
     * @return a new map from {@code getIdentifier()} to cluster; a later cluster with the
     *         same identifier replaces an earlier one
     */
    private static Map<String, Cluster> loadClusterMap(Iterable<Cluster> iterable) {
        Map<String, Cluster> byIdentifier = new HashMap<String, Cluster>();
        Iterator<Cluster> clusters = iterable.iterator();
        while (clusters.hasNext()) {
            Cluster current = clusters.next();
            byIdentifier.put(current.getIdentifier(), current);
        }
        return byIdentifier;
    }

    /**
     * Drives {@link KMeansMapper} directly for k = 1..9 seed clusters taken from the first
     * k reference points. For each k it checks that the mapper emitted exactly k keys and
     * that every emitted observation is at least as close to the cluster it was emitted
     * under as to any other seed cluster.
     */
    @Test
    public void testKMeansMapper() throws Exception {
        KMeansMapper mapper = new KMeansMapper();
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        Configuration conf = new Configuration();
        conf.set("org.apache.mahout.clustering.kmeans.measure", measure.getClass().getName());
        conf.set("org.apache.mahout.clustering.kmeans.convergence", "0.001");
        conf.set("org.apache.mahout.clustering.kmeans.path", "");
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        for (int k = 0; k < points.size(); k++) {
            // The writer must be parameterized so that its keys and values can be iterated
            // as Text / ClusterObservations below.
            DummyRecordWriter<Text, ClusterObservations> mapWriter = new DummyRecordWriter<Text, ClusterObservations>();
            Mapper.Context context = DummyRecordWriter.build(mapper, conf, mapWriter);
            List<Cluster> clusters = Lists.newArrayList();
            for (int i = 0; i < k + 1; i++) {
                Cluster cluster = new Cluster(points.get(i).get(), i, measure);
                // Add the center itself as an observation of the cluster.
                cluster.observe(cluster.getCenter(), 1.0d);
                clusters.add(cluster);
            }
            // Inject the clusters directly instead of loading them from the configured path.
            mapper.setup(clusters, measure);
            for (VectorWritable point : points) {
                mapper.map(new Text(), point, context);
            }
            assertEquals("Number of map results", k + 1, mapWriter.getData().size());
            Map<String, Cluster> clusterMap = loadClusterMap(clusters);
            for (Text key : mapWriter.getKeys()) {
                AbstractCluster assigned = clusterMap.get(key.toString());
                for (ClusterObservations value : mapWriter.getValue(key)) {
                    double distance = measure.distance(assigned.getCenter(), value.getS1());
                    for (Cluster other : clusters) {
                        assertTrue("distance error", distance <= measure.distance(value.getS1(), other.getCenter()));
                    }
                }
            }
        }
    }

    /**
     * Runs {@link KMeansMapper} followed by {@link KMeansCombiner} for k = 1..9 seed clusters.
     * For each k it checks that the combiner emitted k keys with exactly one value each, and
     * that the combined observations total 9 points whose coordinate sums are 27 and 27.
     */
    @Test
    public void testKMeansCombiner() throws Exception {
        KMeansMapper mapper = new KMeansMapper();
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        Configuration conf = new Configuration();
        conf.set("org.apache.mahout.clustering.kmeans.measure", measure.getClass().getName());
        conf.set("org.apache.mahout.clustering.kmeans.convergence", "0.001");
        conf.set("org.apache.mahout.clustering.kmeans.path", "");
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        for (int k = 0; k < points.size(); k++) {
            // Writers are parameterized so their keys/values can be iterated with their real types.
            DummyRecordWriter<Text, ClusterObservations> mapWriter = new DummyRecordWriter<Text, ClusterObservations>();
            Mapper.Context mapContext = DummyRecordWriter.build(mapper, conf, mapWriter);
            List<Cluster> clusters = Lists.newArrayList();
            for (int i = 0; i < k + 1; i++) {
                Cluster cluster = new Cluster(points.get(i).get(), i, measure);
                // Add the center itself as an observation of the cluster.
                cluster.observe(cluster.getCenter(), 1.0d);
                clusters.add(cluster);
            }
            mapper.setup(clusters, measure);
            for (VectorWritable point : points) {
                mapper.map(new Text(), point, mapContext);
            }

            // Feed every mapper key and its values through the combiner.
            KMeansCombiner combiner = new KMeansCombiner();
            DummyRecordWriter<Text, ClusterObservations> combinerWriter = new DummyRecordWriter<Text, ClusterObservations>();
            Reducer.Context combinerContext = DummyRecordWriter.build(combiner, conf, combinerWriter, Text.class, ClusterObservations.class);
            for (Text key : mapWriter.getKeys()) {
                combiner.reduce(new Text(key), mapWriter.getValue(key), combinerContext);
            }
            assertEquals("Number of map results", k + 1, combinerWriter.getData().size());

            // Accumulate the point count (s0) and the point sum (s1) across all combined values.
            int count = 0;
            Vector total = new DenseVector(2);
            for (Text key : combinerWriter.getKeys()) {
                List<ClusterObservations> values = combinerWriter.getValue(key);
                assertEquals("too many values", 1L, values.size());
                ClusterObservations observations = values.get(0);
                count += (int) observations.getS0();
                total = total.plus(observations.getS1());
            }
            assertEquals("total points", 9L, count);
            assertEquals("point total[0]", 27L, (int) total.get(0));
            assertEquals("point total[1]", 27L, (int) total.get(1));
        }
    }

    /**
     * Runs mapper, combiner and {@link KMeansReducer} for k = 1..9 seed clusters and compares
     * the reducer output with one in-memory {@code KMeansClusterer.runKMeansIteration} pass
     * over the same points and seeds: the centers must be equal, and both must report
     * convergence only on the final run (loop index 8, i.e. nine seed clusters).
     */
    @Test
    public void testKMeansReducer() throws Exception {
        KMeansMapper mapper = new KMeansMapper();
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        Configuration conf = new Configuration();
        conf.set("org.apache.mahout.clustering.kmeans.measure", measure.getClass().getName());
        conf.set("org.apache.mahout.clustering.kmeans.convergence", "0.001");
        conf.set("org.apache.mahout.clustering.kmeans.path", "");
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        for (int k = 0; k < points.size(); k++) {
            System.out.println("K = " + k);
            // Writers are parameterized so their keys/values can be iterated with their real types.
            DummyRecordWriter<Text, ClusterObservations> mapWriter = new DummyRecordWriter<Text, ClusterObservations>();
            Mapper.Context mapContext = DummyRecordWriter.build(mapper, conf, mapWriter);
            List<Cluster> clusters = Lists.newArrayList();
            for (int i = 0; i < k + 1; i++) {
                clusters.add(new Cluster(points.get(i).get(), i, measure));
            }
            mapper.setup(clusters, new EuclideanDistanceMeasure());
            for (VectorWritable point : points) {
                mapper.map(new Text(), point, mapContext);
            }

            // Combine the mapper output.
            KMeansCombiner combiner = new KMeansCombiner();
            DummyRecordWriter<Text, ClusterObservations> combinerWriter = new DummyRecordWriter<Text, ClusterObservations>();
            Reducer.Context combinerContext = DummyRecordWriter.build(combiner, conf, combinerWriter, Text.class, ClusterObservations.class);
            for (Text key : mapWriter.getKeys()) {
                combiner.reduce(new Text(key), mapWriter.getValue(key), combinerContext);
            }

            // Reduce the combiner output; the reducer writes Cluster values.
            KMeansReducer reducer = new KMeansReducer();
            reducer.setup(clusters, measure);
            DummyRecordWriter<Text, Cluster> reducerWriter = new DummyRecordWriter<Text, Cluster>();
            Reducer.Context reducerContext = DummyRecordWriter.build(reducer, conf, reducerWriter, Text.class, ClusterObservations.class);
            for (Text key : combinerWriter.getKeys()) {
                reducer.reduce(new Text(key), combinerWriter.getValue(key), reducerContext);
            }
            assertEquals("Number of map results", k + 1, reducerWriter.getData().size());

            // Build an independent set of seed clusters and run one reference iteration on them.
            List<Cluster> reference = Lists.newArrayList();
            for (int i = 0; i < k + 1; i++) {
                reference.add(new Cluster(points.get(i).get(), i, measure));
            }
            List<Vector> pointVectors = Lists.newArrayList();
            for (VectorWritable point : points) {
                pointVectors.add(point.get());
            }
            boolean referenceConverged = KMeansClusterer.runKMeansIteration(pointVectors, reference, measure, 0.001d);
            if (k == 8) {
                assertTrue("not converged? " + k, referenceConverged);
            } else {
                assertFalse("converged? " + k, referenceConverged);
            }

            // Every reference cluster must have a reducer counterpart with an equal center.
            boolean allConverged = true;
            for (Cluster expected : reference) {
                Cluster actual = reducerWriter.getValue(new Text(expected.getIdentifier())).get(0);
                allConverged = allConverged && actual.isConverged();
                actual.computeParameters();
                assertEquals(expected.getCenter(), actual.getCenter());
            }
            if (k == 8) {
                assertTrue("not converged? " + k, allConverged);
            } else {
                assertFalse("converged? " + k, allConverged);
            }
        }
    }

    /**
     * Runs {@link KMeansDriver} with {@code method=sequential} for 2..9 seed clusters.
     * The reference points are written to two input files, the seed clusters to
     * {@code clusters/part-00000}, and after each run the number of distinct keys read from
     * {@code clusteredPoints/part-m-0} must equal the length of the matching
     * {@code EXPECTED_NUM_POINTS} row.
     */
    @Test
    public void testKMeansSeqJob() throws Exception {
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        Path pointsPath = getTestTempDirPath("points");
        Path clustersPath = getTestTempDirPath("clusters");
        Configuration conf = new Configuration();
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), this.fs, conf);
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), this.fs, conf);
        for (int k = 1; k < points.size(); k++) {
            System.out.println("testKMeansSeqJob k= " + k);
            Path seedFile = new Path(clustersPath, "part-00000");
            SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(seedFile.toUri(), conf), conf, seedFile, Text.class, Cluster.class);
            // The writer must stay open for the whole loop and be closed exactly once afterwards;
            // closing it per iteration would make every append after the first hit a closed writer.
            try {
                for (int i = 0; i < k + 1; i++) {
                    Cluster cluster = new Cluster(points.get(i).get(), i, measure);
                    // Add the center itself as an observation of the cluster.
                    cluster.observe(cluster.getCenter(), 1.0d);
                    writer.append(new Text(cluster.getIdentifier()), cluster);
                }
            } finally {
                Closeables.closeQuietly(writer);
            }

            Path outputPath = getTestTempDirPath("output");
            String[] args = {
                optKey("input"), pointsPath.toString(),
                optKey("clusters"), clustersPath.toString(),
                optKey("output"), outputPath.toString(),
                optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName(),
                optKey("convergenceDelta"), "0.001",
                optKey("maxIter"), "2",
                optKey("clustering"),
                optKey("overwrite"),
                optKey("method"), "sequential",
            };
            new KMeansDriver().run(args);

            Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
            int[] expect = EXPECTED_NUM_POINTS[k];
            // Group the clustered points by key; keys are the IntWritable cluster ids.
            DummyOutputCollector<IntWritable, Writable> collector = new DummyOutputCollector<IntWritable, Writable>();
            for (Pair<IntWritable, Writable> record
                : new SequenceFileIterable<IntWritable, Writable>(new Path(clusteredPointsPath, "part-m-0"), conf)) {
                collector.collect(record.getFirst(), record.getSecond());
            }
            assertEquals("clusters[" + k + ']', expect.length, collector.getKeys().size());
        }
    }

    /**
     * Runs {@link KMeansDriver} through {@link ToolRunner} (no {@code method} option) for
     * 2..9 seed clusters. The reference points are written to two input files, the seed
     * clusters to {@code clusters/part-00000}, and after each run the number of distinct keys
     * read from {@code clusteredPoints/part-m-00000} is compared with the length of the
     * matching {@code EXPECTED_NUM_POINTS} row.
     */
    @Test
    public void testKMeansMRJob() throws Exception {
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        Path pointsPath = getTestTempDirPath("points");
        Path clustersPath = getTestTempDirPath("clusters");
        Configuration conf = new Configuration();
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), this.fs, conf);
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), this.fs, conf);
        for (int k = 1; k < points.size(); k++) {
            System.out.println("testKMeansMRJob k= " + k);
            Path seedFile = new Path(clustersPath, "part-00000");
            SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(seedFile.toUri(), conf), conf, seedFile, Text.class, Cluster.class);
            // The writer must stay open for the whole loop and be closed exactly once afterwards;
            // closing it per iteration would make every append after the first hit a closed writer.
            try {
                for (int i = 0; i < k + 1; i++) {
                    Cluster cluster = new Cluster(points.get(i).get(), i, measure);
                    // Add the center itself as an observation of the cluster.
                    cluster.observe(cluster.getCenter(), 1.0d);
                    writer.append(new Text(cluster.getIdentifier()), cluster);
                }
            } finally {
                Closeables.closeQuietly(writer);
            }

            Path outputPath = getTestTempDirPath("output");
            String[] args = {
                optKey("input"), pointsPath.toString(),
                optKey("clusters"), clustersPath.toString(),
                optKey("output"), outputPath.toString(),
                optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName(),
                optKey("convergenceDelta"), "0.001",
                optKey("maxIter"), "2",
                optKey("clustering"),
                optKey("overwrite"),
            };
            ToolRunner.run(new Configuration(), new KMeansDriver(), args);

            Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
            int[] expect = EXPECTED_NUM_POINTS[k];
            // Group the clustered points by key; keys are the IntWritable cluster ids.
            DummyOutputCollector<IntWritable, Writable> collector = new DummyOutputCollector<IntWritable, Writable>();
            for (Pair<IntWritable, Writable> record
                : new SequenceFileIterable<IntWritable, Writable>(new Path(clusteredPointsPath, "part-m-00000"), conf)) {
                collector.collect(record.getFirst(), record.getSecond());
            }
            if (k == 2) {
                // One key fewer is expected for k == 2 — presumably one cluster receives no
                // points and so does not appear in the output; TODO confirm.
                assertEquals("clusters[" + k + ']', expect.length - 1, collector.getKeys().size());
            } else {
                assertEquals("clusters[" + k + ']', expect.length, collector.getKeys().size());
            }
        }
    }

    /**
     * Seeds k-means with canopy output: runs {@link CanopyDriver} (Manhattan distance,
     * t1 = 3.1, t2 = 2.1) over the reference points, feeds its {@code clusters-0-final}
     * directory to {@link KMeansDriver} (Euclidean distance, delta 0.001, at most 10
     * iterations), then checks that {@code clusteredPoints/part-m-00000} holds 4 points
     * under key 0 and 5 points under key 1.
     */
    @Test
    public void testKMeansWithCanopyClusterInput() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        Path pointsPath = getTestTempDirPath("points");
        Configuration conf = new Configuration();
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), this.fs, conf);
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), this.fs, conf);

        Path outputPath = getTestTempDirPath("output");
        CanopyDriver.run(conf, pointsPath, outputPath, new ManhattanDistanceMeasure(), 3.1d, 2.1d, false, false);
        Path canopyClusters = new Path(outputPath, "clusters-0-final");
        KMeansDriver.run(pointsPath, canopyClusters, outputPath, new EuclideanDistanceMeasure(), 0.001d, 10, true, false);

        Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
        // Group the clustered points by key; keys are IntWritable cluster ids, as the
        // lookups below rely on.
        DummyOutputCollector<IntWritable, Writable> collector = new DummyOutputCollector<IntWritable, Writable>();
        for (Pair<IntWritable, Writable> record
            : new SequenceFileIterable<IntWritable, Writable>(new Path(clusteredPointsPath, "part-m-00000"), conf)) {
            collector.collect(record.getFirst(), record.getSecond());
        }
        assertEquals("num points[0]", 4L, collector.getValue(new IntWritable(0)).size());
        assertEquals("num points[1]", 5L, collector.getValue(new IntWritable(1)).size());
    }
}
