001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math3.stat.clustering;
019
020 import java.util.ArrayList;
021 import java.util.Collection;
022 import java.util.Collections;
023 import java.util.List;
024 import java.util.Random;
025
026 import org.apache.commons.math3.exception.ConvergenceException;
027 import org.apache.commons.math3.exception.MathIllegalArgumentException;
028 import org.apache.commons.math3.exception.NumberIsTooSmallException;
029 import org.apache.commons.math3.exception.util.LocalizedFormats;
030 import org.apache.commons.math3.stat.descriptive.moment.Variance;
031 import org.apache.commons.math3.util.MathUtils;
032
033 /**
034 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035 * @param <T> type of the points to cluster
036 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037 * @version $Id: KMeansPlusPlusClusterer.java 1416643 2012-12-03 19:37:14Z tn $
038 * @since 2.0
039 */
040 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
041
042 /** Strategies to use for replacing an empty cluster. */
043 public static enum EmptyClusterStrategy {
044
045 /** Split the cluster with largest distance variance. */
046 LARGEST_VARIANCE,
047
048 /** Split the cluster with largest number of points. */
049 LARGEST_POINTS_NUMBER,
050
051 /** Create a cluster around the point farthest from its centroid. */
052 FARTHEST_POINT,
053
054 /** Generate an error. */
055 ERROR
056
057 }
058
059 /** Random generator for choosing initial centers. */
060 private final Random random;
061
062 /** Selected strategy for empty clusters. */
063 private final EmptyClusterStrategy emptyStrategy;
064
065 /** Build a clusterer.
066 * <p>
067 * The default strategy for handling empty clusters that may appear during
068 * algorithm iterations is to split the cluster with largest distance variance.
069 * </p>
070 * @param random random generator to use for choosing initial centers
071 */
072 public KMeansPlusPlusClusterer(final Random random) {
073 this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
074 }
075
076 /** Build a clusterer.
077 * @param random random generator to use for choosing initial centers
078 * @param emptyStrategy strategy to use for handling empty clusters that
079 * may appear during algorithm iterations
080 * @since 2.2
081 */
082 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
083 this.random = random;
084 this.emptyStrategy = emptyStrategy;
085 }
086
087 /**
088 * Runs the K-means++ clustering algorithm.
089 *
090 * @param points the points to cluster
091 * @param k the number of clusters to split the data into
092 * @param numTrials number of trial runs
093 * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
094 * for at each trial run. If negative, no maximum will be used
095 * @return a list of clusters containing the points
096 * @throws MathIllegalArgumentException if the data points are null or the number
097 * of clusters is larger than the number of data points
098 * @throws ConvergenceException if an empty cluster is encountered and the
099 * {@link #emptyStrategy} is set to {@code ERROR}
100 */
101 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
102 int numTrials, int maxIterationsPerTrial)
103 throws MathIllegalArgumentException, ConvergenceException {
104
105 // at first, we have not found any clusters list yet
106 List<Cluster<T>> best = null;
107 double bestVarianceSum = Double.POSITIVE_INFINITY;
108
109 // do several clustering trials
110 for (int i = 0; i < numTrials; ++i) {
111
112 // compute a clusters list
113 List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
114
115 // compute the variance of the current list
116 double varianceSum = 0.0;
117 for (final Cluster<T> cluster : clusters) {
118 if (!cluster.getPoints().isEmpty()) {
119
120 // compute the distance variance of the current cluster
121 final T center = cluster.getCenter();
122 final Variance stat = new Variance();
123 for (final T point : cluster.getPoints()) {
124 stat.increment(point.distanceFrom(center));
125 }
126 varianceSum += stat.getResult();
127
128 }
129 }
130
131 if (varianceSum <= bestVarianceSum) {
132 // this one is the best we have found so far, remember it
133 best = clusters;
134 bestVarianceSum = varianceSum;
135 }
136
137 }
138
139 // return the best clusters list found
140 return best;
141
142 }
143
144 /**
145 * Runs the K-means++ clustering algorithm.
146 *
147 * @param points the points to cluster
148 * @param k the number of clusters to split the data into
149 * @param maxIterations the maximum number of iterations to run the algorithm
150 * for. If negative, no maximum will be used
151 * @return a list of clusters containing the points
152 * @throws MathIllegalArgumentException if the data points are null or the number
153 * of clusters is larger than the number of data points
154 * @throws ConvergenceException if an empty cluster is encountered and the
155 * {@link #emptyStrategy} is set to {@code ERROR}
156 */
157 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
158 final int maxIterations)
159 throws MathIllegalArgumentException, ConvergenceException {
160
161 // sanity checks
162 MathUtils.checkNotNull(points);
163
164 // number of clusters has to be smaller or equal the number of data points
165 if (points.size() < k) {
166 throw new NumberIsTooSmallException(points.size(), k, false);
167 }
168
169 // create the initial clusters
170 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
171
172 // create an array containing the latest assignment of a point to a cluster
173 // no need to initialize the array, as it will be filled with the first assignment
174 int[] assignments = new int[points.size()];
175 assignPointsToClusters(clusters, points, assignments);
176
177 // iterate through updating the centers until we're done
178 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
179 for (int count = 0; count < max; count++) {
180 boolean emptyCluster = false;
181 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
182 for (final Cluster<T> cluster : clusters) {
183 final T newCenter;
184 if (cluster.getPoints().isEmpty()) {
185 switch (emptyStrategy) {
186 case LARGEST_VARIANCE :
187 newCenter = getPointFromLargestVarianceCluster(clusters);
188 break;
189 case LARGEST_POINTS_NUMBER :
190 newCenter = getPointFromLargestNumberCluster(clusters);
191 break;
192 case FARTHEST_POINT :
193 newCenter = getFarthestPoint(clusters);
194 break;
195 default :
196 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
197 }
198 emptyCluster = true;
199 } else {
200 newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
201 }
202 newClusters.add(new Cluster<T>(newCenter));
203 }
204 int changes = assignPointsToClusters(newClusters, points, assignments);
205 clusters = newClusters;
206
207 // if there were no more changes in the point-to-cluster assignment
208 // and there are no empty clusters left, return the current clusters
209 if (changes == 0 && !emptyCluster) {
210 return clusters;
211 }
212 }
213 return clusters;
214 }
215
216 /**
217 * Adds the given points to the closest {@link Cluster}.
218 *
219 * @param <T> type of the points to cluster
220 * @param clusters the {@link Cluster}s to add the points to
221 * @param points the points to add to the given {@link Cluster}s
222 * @param assignments points assignments to clusters
223 * @return the number of points assigned to different clusters as the iteration before
224 */
225 private static <T extends Clusterable<T>> int
226 assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
227 final int[] assignments) {
228 int assignedDifferently = 0;
229 int pointIndex = 0;
230 for (final T p : points) {
231 int clusterIndex = getNearestCluster(clusters, p);
232 if (clusterIndex != assignments[pointIndex]) {
233 assignedDifferently++;
234 }
235
236 Cluster<T> cluster = clusters.get(clusterIndex);
237 cluster.addPoint(p);
238 assignments[pointIndex++] = clusterIndex;
239 }
240
241 return assignedDifferently;
242 }
243
244 /**
245 * Use K-means++ to choose the initial centers.
246 *
247 * @param <T> type of the points to cluster
248 * @param points the points to choose the initial centers from
249 * @param k the number of centers to choose
250 * @param random random generator to use
251 * @return the initial centers
252 */
253 private static <T extends Clusterable<T>> List<Cluster<T>>
254 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
255
256 // Convert to list for indexed access. Make it unmodifiable, since removal of items
257 // would screw up the logic of this method.
258 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
259
260 // The number of points in the list.
261 final int numPoints = pointList.size();
262
263 // Set the corresponding element in this array to indicate when
264 // elements of pointList are no longer available.
265 final boolean[] taken = new boolean[numPoints];
266
267 // The resulting list of initial centers.
268 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
269
270 // Choose one center uniformly at random from among the data points.
271 final int firstPointIndex = random.nextInt(numPoints);
272
273 final T firstPoint = pointList.get(firstPointIndex);
274
275 resultSet.add(new Cluster<T>(firstPoint));
276
277 // Must mark it as taken
278 taken[firstPointIndex] = true;
279
280 // To keep track of the minimum distance squared of elements of
281 // pointList to elements of resultSet.
282 final double[] minDistSquared = new double[numPoints];
283
284 // Initialize the elements. Since the only point in resultSet is firstPoint,
285 // this is very easy.
286 for (int i = 0; i < numPoints; i++) {
287 if (i != firstPointIndex) { // That point isn't considered
288 double d = firstPoint.distanceFrom(pointList.get(i));
289 minDistSquared[i] = d*d;
290 }
291 }
292
293 while (resultSet.size() < k) {
294
295 // Sum up the squared distances for the points in pointList not
296 // already taken.
297 double distSqSum = 0.0;
298
299 for (int i = 0; i < numPoints; i++) {
300 if (!taken[i]) {
301 distSqSum += minDistSquared[i];
302 }
303 }
304
305 // Add one new data point as a center. Each point x is chosen with
306 // probability proportional to D(x)2
307 final double r = random.nextDouble() * distSqSum;
308
309 // The index of the next point to be added to the resultSet.
310 int nextPointIndex = -1;
311
312 // Sum through the squared min distances again, stopping when
313 // sum >= r.
314 double sum = 0.0;
315 for (int i = 0; i < numPoints; i++) {
316 if (!taken[i]) {
317 sum += minDistSquared[i];
318 if (sum >= r) {
319 nextPointIndex = i;
320 break;
321 }
322 }
323 }
324
325 // If it's not set to >= 0, the point wasn't found in the previous
326 // for loop, probably because distances are extremely small. Just pick
327 // the last available point.
328 if (nextPointIndex == -1) {
329 for (int i = numPoints - 1; i >= 0; i--) {
330 if (!taken[i]) {
331 nextPointIndex = i;
332 break;
333 }
334 }
335 }
336
337 // We found one.
338 if (nextPointIndex >= 0) {
339
340 final T p = pointList.get(nextPointIndex);
341
342 resultSet.add(new Cluster<T> (p));
343
344 // Mark it as taken.
345 taken[nextPointIndex] = true;
346
347 if (resultSet.size() < k) {
348 // Now update elements of minDistSquared. We only have to compute
349 // the distance to the new center to do this.
350 for (int j = 0; j < numPoints; j++) {
351 // Only have to worry about the points still not taken.
352 if (!taken[j]) {
353 double d = p.distanceFrom(pointList.get(j));
354 double d2 = d * d;
355 if (d2 < minDistSquared[j]) {
356 minDistSquared[j] = d2;
357 }
358 }
359 }
360 }
361
362 } else {
363 // None found --
364 // Break from the while loop to prevent
365 // an infinite loop.
366 break;
367 }
368 }
369
370 return resultSet;
371 }
372
373 /**
374 * Get a random point from the {@link Cluster} with the largest distance variance.
375 *
376 * @param clusters the {@link Cluster}s to search
377 * @return a random point from the selected cluster
378 * @throws ConvergenceException if clusters are all empty
379 */
380 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
381 throws ConvergenceException {
382
383 double maxVariance = Double.NEGATIVE_INFINITY;
384 Cluster<T> selected = null;
385 for (final Cluster<T> cluster : clusters) {
386 if (!cluster.getPoints().isEmpty()) {
387
388 // compute the distance variance of the current cluster
389 final T center = cluster.getCenter();
390 final Variance stat = new Variance();
391 for (final T point : cluster.getPoints()) {
392 stat.increment(point.distanceFrom(center));
393 }
394 final double variance = stat.getResult();
395
396 // select the cluster with the largest variance
397 if (variance > maxVariance) {
398 maxVariance = variance;
399 selected = cluster;
400 }
401
402 }
403 }
404
405 // did we find at least one non-empty cluster ?
406 if (selected == null) {
407 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
408 }
409
410 // extract a random point from the cluster
411 final List<T> selectedPoints = selected.getPoints();
412 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
413
414 }
415
416 /**
417 * Get a random point from the {@link Cluster} with the largest number of points
418 *
419 * @param clusters the {@link Cluster}s to search
420 * @return a random point from the selected cluster
421 * @throws ConvergenceException if clusters are all empty
422 */
423 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
424
425 int maxNumber = 0;
426 Cluster<T> selected = null;
427 for (final Cluster<T> cluster : clusters) {
428
429 // get the number of points of the current cluster
430 final int number = cluster.getPoints().size();
431
432 // select the cluster with the largest number of points
433 if (number > maxNumber) {
434 maxNumber = number;
435 selected = cluster;
436 }
437
438 }
439
440 // did we find at least one non-empty cluster ?
441 if (selected == null) {
442 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
443 }
444
445 // extract a random point from the cluster
446 final List<T> selectedPoints = selected.getPoints();
447 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
448
449 }
450
451 /**
452 * Get the point farthest to its cluster center
453 *
454 * @param clusters the {@link Cluster}s to search
455 * @return point farthest to its cluster center
456 * @throws ConvergenceException if clusters are all empty
457 */
458 private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
459
460 double maxDistance = Double.NEGATIVE_INFINITY;
461 Cluster<T> selectedCluster = null;
462 int selectedPoint = -1;
463 for (final Cluster<T> cluster : clusters) {
464
465 // get the farthest point
466 final T center = cluster.getCenter();
467 final List<T> points = cluster.getPoints();
468 for (int i = 0; i < points.size(); ++i) {
469 final double distance = points.get(i).distanceFrom(center);
470 if (distance > maxDistance) {
471 maxDistance = distance;
472 selectedCluster = cluster;
473 selectedPoint = i;
474 }
475 }
476
477 }
478
479 // did we find at least one non-empty cluster ?
480 if (selectedCluster == null) {
481 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
482 }
483
484 return selectedCluster.getPoints().remove(selectedPoint);
485
486 }
487
488 /**
489 * Returns the nearest {@link Cluster} to the given point
490 *
491 * @param <T> type of the points to cluster
492 * @param clusters the {@link Cluster}s to search
493 * @param point the point to find the nearest {@link Cluster} for
494 * @return the index of the nearest {@link Cluster} to the given point
495 */
496 private static <T extends Clusterable<T>> int
497 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
498 double minDistance = Double.MAX_VALUE;
499 int clusterIndex = 0;
500 int minCluster = 0;
501 for (final Cluster<T> c : clusters) {
502 final double distance = point.distanceFrom(c.getCenter());
503 if (distance < minDistance) {
504 minDistance = distance;
505 minCluster = clusterIndex;
506 }
507 clusterIndex++;
508 }
509 return minCluster;
510 }
511
512 }