1
2
3
4
5
6
7
8
9
10
11
12
13
14 package edu.uci.ics.jung.algorithms.util;
15
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 public class KMeansClusterer<T>
42 {
43 protected int max_iterations;
44 protected double convergence_threshold;
45 protected Random rand;
46
47
48
49
50
51
52
53 public KMeansClusterer(int max_iterations, double convergence_threshold)
54 {
55 this.max_iterations = max_iterations;
56 this.convergence_threshold = convergence_threshold;
57 this.rand = new Random();
58 }
59
60
61
62
63
64 public KMeansClusterer()
65 {
66 this(100, 0.001);
67 }
68
69
70
71
72 public int getMaxIterations()
73 {
74 return max_iterations;
75 }
76
77
78
79
80 public void setMaxIterations(int max_iterations)
81 {
82 if (max_iterations < 0)
83 throw new IllegalArgumentException("max iterations must be >= 0");
84
85 this.max_iterations = max_iterations;
86 }
87
88
89
90
91 public double getConvergenceThreshold()
92 {
93 return convergence_threshold;
94 }
95
96
97
98
99 public void setConvergenceThreshold(double convergence_threshold)
100 {
101 if (convergence_threshold <= 0)
102 throw new IllegalArgumentException("convergence threshold " +
103 "must be > 0");
104
105 this.convergence_threshold = convergence_threshold;
106 }
107
108
109
110
111
112
113
114
115
116
117
118
119 @SuppressWarnings("unchecked")
120 public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters)
121 {
122 if (object_locations == null || object_locations.isEmpty())
123 throw new IllegalArgumentException("'objects' must be non-empty");
124
125 if (num_clusters < 2 || num_clusters > object_locations.size())
126 throw new IllegalArgumentException("number of clusters " +
127 "must be >= 2 and <= number of objects (" +
128 object_locations.size() + ")");
129
130
131 Set<double[]> centroids = new HashSet<double[]>();
132
133 Object[] obj_array = object_locations.keySet().toArray();
134 Set<T> tried = new HashSet<T>();
135
136
137 while (centroids.size() < num_clusters && tried.size() < object_locations.size())
138 {
139 T o = (T)obj_array[(int)(rand.nextDouble() * obj_array.length)];
140 tried.add(o);
141 double[] mean_value = object_locations.get(o);
142 boolean duplicate = false;
143 for (double[] cur : centroids)
144 {
145 if (Arrays.equals(mean_value, cur))
146 duplicate = true;
147 }
148 if (!duplicate)
149 centroids.add(mean_value);
150 }
151
152 if (tried.size() >= object_locations.size())
153 throw new NotEnoughClustersException();
154
155
156 Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);
157
158
159
160
161
162 int iterations = 0;
163 double max_movement = Double.POSITIVE_INFINITY;
164 while (iterations++ < max_iterations && max_movement > convergence_threshold)
165 {
166 max_movement = 0;
167 Set<double[]> new_centroids = new HashSet<double[]>();
168
169 for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet())
170 {
171 double[] centroid = entry.getKey();
172 Map<T, double[]> elements = entry.getValue();
173 ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
174
175 double[] mean = DiscreteDistribution.mean(locations);
176 max_movement = Math.max(max_movement,
177 Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
178 new_centroids.add(mean);
179 }
180
181
182
183
184 clusterMap = assignToClusters(object_locations, new_centroids);
185 }
186 return clusterMap.values();
187 }
188
189
190
191
192
193
194
195
196 protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids)
197 {
198 Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
199 for (double[] centroid : centroids)
200 clusterMap.put(centroid, new HashMap<T, double[]>());
201
202 for (Map.Entry<T, double[]> object_location : object_locations.entrySet())
203 {
204 T object = object_location.getKey();
205 double[] location = object_location.getValue();
206
207
208 Iterator<double[]> c_iter = centroids.iterator();
209 double[] closest = c_iter.next();
210 double distance = DiscreteDistribution.squaredError(location, closest);
211
212 while (c_iter.hasNext())
213 {
214 double[] centroid = c_iter.next();
215 double dist_cur = DiscreteDistribution.squaredError(location, centroid);
216 if (dist_cur < distance)
217 {
218 distance = dist_cur;
219 closest = centroid;
220 }
221 }
222 clusterMap.get(closest).put(object, location);
223 }
224
225 return clusterMap;
226 }
227
228
229
230
231
232
233 public void setSeed(int random_seed)
234 {
235 this.rand = new Random(random_seed);
236 }
237
238
239
240
241
242
243
244
245
246
247 @SuppressWarnings("serial")
248 public static class NotEnoughClustersException extends RuntimeException
249 {
250 @Override
251 public String getMessage()
252 {
253 return "Not enough distinct points in the input data set to form " +
254 "the requested number of clusters";
255 }
256 }
257 }