View Javadoc
1   /*
2    * Copyright (c) 2003, The JUNG Authors
3    *
4    * All rights reserved.
5    *
6    * This software is open-source under the BSD license; see either
7    * "license.txt" or
8    * https://github.com/jrtom/jung/blob/master/LICENSE for a description.
9    */
10  /*
11   * Created on Aug 9, 2004
12   *
13   */
14  package edu.uci.ics.jung.algorithms.util;
15  
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
25  
26  
27  
28  /**
29   * Groups items into a specified number of clusters, based on their proximity in
30   * d-dimensional space, using the k-means algorithm. Calls to
31   * <code>cluster</code> will terminate when either of the two following
32   * conditions is true:
33   * <ul>
34   * <li>the number of iterations is &gt; <code>max_iterations</code> 
35   * <li>none of the centroids has moved as much as <code>convergence_threshold</code>
36   * since the previous iteration
37   * </ul>
38   * 
39   * @author Joshua O'Madadhain
40   */
41  public class KMeansClusterer<T>
42  {
43      protected int max_iterations;
44      protected double convergence_threshold;
45      protected Random rand;
46  
47      /**
48       * Creates an instance which will terminate when either the maximum number of 
49       * iterations has been reached, or all changes are smaller than the convergence threshold.
50       * @param max_iterations the maximum number of iterations to employ
51       * @param convergence_threshold the smallest change we want to track
52       */
53      public KMeansClusterer(int max_iterations, double convergence_threshold)
54      {
55          this.max_iterations = max_iterations;
56          this.convergence_threshold = convergence_threshold;
57          this.rand = new Random();
58      }
59  
60      /**
61       * Creates an instance with max iterations of 100 and convergence threshold
62       * of 0.001.
63       */
64      public KMeansClusterer()
65      {
66          this(100, 0.001);
67      }
68  
69      /**
70       * @return the maximum number of iterations
71       */
72      public int getMaxIterations()
73      {
74          return max_iterations;
75      }
76  
77      /**
78       * @param max_iterations the maximum number of iterations
79       */
80      public void setMaxIterations(int max_iterations)
81      {
82          if (max_iterations < 0)
83              throw new IllegalArgumentException("max iterations must be >= 0");
84  
85          this.max_iterations = max_iterations;
86      }
87  
88      /**
89       * @return the convergence threshold
90       */
91      public double getConvergenceThreshold()
92      {
93          return convergence_threshold;
94      }
95  
96      /**
97       * @param convergence_threshold the convergence threshold
98       */
99      public void setConvergenceThreshold(double convergence_threshold)
100     {
101         if (convergence_threshold <= 0)
102             throw new IllegalArgumentException("convergence threshold " +
103                 "must be > 0");
104 
105         this.convergence_threshold = convergence_threshold;
106     }
107 
108     /**
109      * Returns a <code>Collection</code> of clusters, where each cluster is
110      * represented as a <code>Map</code> of <code>Objects</code> to locations
111      * in d-dimensional space.
112      * @param object_locations  a map of the items to cluster, to
113      * <code>double</code> arrays that specify their locations in d-dimensional space.
114      * @param num_clusters  the number of clusters to create
115      * @return a clustering of the input objects in d-dimensional space
116      * @throws NotEnoughClustersException if {@code num_clusters} is larger than the number of
117      *     distinct points in object_locations
118      */
119     @SuppressWarnings("unchecked")
120     public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters)
121     {
122         if (object_locations == null || object_locations.isEmpty())
123             throw new IllegalArgumentException("'objects' must be non-empty");
124 
125         if (num_clusters < 2 || num_clusters > object_locations.size())
126             throw new IllegalArgumentException("number of clusters " +
127                 "must be >= 2 and <= number of objects (" +
128                 object_locations.size() + ")");
129 
130 
131         Set<double[]> centroids = new HashSet<double[]>();
132 
133         Object[] obj_array = object_locations.keySet().toArray();
134         Set<T> tried = new HashSet<T>();
135 
136         // create the specified number of clusters
137         while (centroids.size() < num_clusters && tried.size() < object_locations.size())
138         {
139             T o = (T)obj_array[(int)(rand.nextDouble() * obj_array.length)];
140             tried.add(o);
141             double[] mean_value = object_locations.get(o);
142             boolean duplicate = false;
143             for (double[] cur : centroids)
144             {
145                 if (Arrays.equals(mean_value, cur))
146                     duplicate = true;
147             }
148             if (!duplicate)
149                 centroids.add(mean_value);
150         }
151 
152         if (tried.size() >= object_locations.size())
153             throw new NotEnoughClustersException();
154 
155         // put items in their initial clusters
156         Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);
157 
158         // keep reconstituting clusters until either
159         // (a) membership is stable, or
160         // (b) number of iterations passes max_iterations, or
161         // (c) max movement of any centroid is <= convergence_threshold
162         int iterations = 0;
163         double max_movement = Double.POSITIVE_INFINITY;
164         while (iterations++ < max_iterations && max_movement > convergence_threshold)
165         {
166             max_movement = 0;
167             Set<double[]> new_centroids = new HashSet<double[]>();
168             // calculate new mean for each cluster
169             for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet())
170             {
171                 double[] centroid = entry.getKey();
172                 Map<T, double[]> elements = entry.getValue();
173                 ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
174 
175                 double[] mean = DiscreteDistribution.mean(locations);
176                 max_movement = Math.max(max_movement,
177                     Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
178                 new_centroids.add(mean);
179             }
180 
181             // TODO: check membership of clusters: have they changed?
182 
183             // regenerate cluster membership based on means
184             clusterMap = assignToClusters(object_locations, new_centroids);
185         }
186         return clusterMap.values();
187     }
188 
189     /**
190      * Assigns each object to the cluster whose centroid is closest to the
191      * object.
192      * @param object_locations  a map of objects to locations
193      * @param centroids         the centroids of the clusters to be formed
194      * @return a map of objects to assigned clusters
195      */
196     protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids)
197     {
198         Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
199         for (double[] centroid : centroids)
200             clusterMap.put(centroid, new HashMap<T, double[]>());
201 
202         for (Map.Entry<T, double[]> object_location : object_locations.entrySet())
203         {
204             T object = object_location.getKey();
205             double[] location = object_location.getValue();
206 
207             // find the cluster with the closest centroid
208             Iterator<double[]> c_iter = centroids.iterator();
209             double[] closest = c_iter.next();
210             double distance = DiscreteDistribution.squaredError(location, closest);
211 
212             while (c_iter.hasNext())
213             {
214                 double[] centroid = c_iter.next();
215                 double dist_cur = DiscreteDistribution.squaredError(location, centroid);
216                 if (dist_cur < distance)
217                 {
218                     distance = dist_cur;
219                     closest = centroid;
220                 }
221             }
222             clusterMap.get(closest).put(object, location);
223         }
224 
225         return clusterMap;
226     }
227 
228     /**
229      * Sets the seed used by the internal random number generator.
230      * Enables consistent outputs.
231      * @param random_seed the random seed to use
232      */
233     public void setSeed(int random_seed)
234     {
235         this.rand = new Random(random_seed);
236     }
237 
238     /**
239      * An exception that indicates that the specified data points cannot be
240      * clustered into the number of clusters requested by the user.
241      * This will happen if and only if there are fewer distinct points than
242      * requested clusters.  (If there are fewer total data points than
243      * requested clusters, <code>IllegalArgumentException</code> will be thrown.)
244      *
245      * @author Joshua O'Madadhain
246      */
247     @SuppressWarnings("serial")
248     public static class NotEnoughClustersException extends RuntimeException
249     {
250         @Override
251         public String getMessage()
252         {
253             return "Not enough distinct points in the input data set to form " +
254                     "the requested number of clusters";
255         }
256     }
257 }