1
2
3
4
5
6
7
8
9
10
11
12 package edu.uci.ics.jung.algorithms.util;
13
14 import java.util.ArrayList;
15 import java.util.LinkedList;
16 import java.util.List;
17 import java.util.Map;
18 import java.util.Queue;
19 import java.util.Random;
20
21
22
23
24
25
26
27
28
29
30
31
/**
 * Selects items pseudo-randomly with probability proportional to their weights.
 *
 * <p>Construction normalizes the weights and distributes them over a list of
 * "buckets". Each bucket holds a <i>light</i> item, a <i>heavy</i> item and the
 * probability with which the light item is chosen once that bucket has been
 * selected (the alias-method layout). {@link #nextItem()} then needs only one
 * uniformly chosen bucket index and one biased coin flip per draw.
 *
 * @param <T> the type of the items to choose from
 */
public class WeightedChoice<T>
{
    /** Lookup table of buckets built once by the constructor. */
    private final List<ItemPair<T>> item_pairs;

    /** Source of randomness; this is the caller's instance, not a copy. */
    private final Random random;

    /**
     * Default minimum residual weight. After a heavy item has topped up a
     * bucket, its leftover weight is put back in a queue only if it is
     * strictly greater than the threshold; smaller leftovers are treated as
     * rounding error and dropped.
     */
    public static final double DEFAULT_THRESHOLD = 0.00000000001;

    /**
     * Creates an instance using a new {@link Random} and
     * {@link #DEFAULT_THRESHOLD}.
     *
     * @param item_weights map from each item to its (strictly positive) weight
     * @throws IllegalArgumentException if the map is empty or a weight is
     *         not a finite number greater than 0
     */
    public WeightedChoice(Map<T, ? extends Number> item_weights)
    {
        this(item_weights, new Random(), DEFAULT_THRESHOLD);
    }

    /**
     * Creates an instance using a new {@link Random} and the given threshold.
     *
     * @param item_weights map from each item to its (strictly positive) weight
     * @param threshold minimum residual weight that is still re-queued
     * @throws IllegalArgumentException if the map is empty or a weight is
     *         not a finite number greater than 0
     */
    public WeightedChoice(Map<T, ? extends Number> item_weights, double threshold)
    {
        this(item_weights, new Random(), threshold);
    }

    /**
     * Creates an instance using the given random source and
     * {@link #DEFAULT_THRESHOLD}.
     *
     * @param item_weights map from each item to its (strictly positive) weight
     * @param random the random number generator used by {@link #nextItem()}
     * @throws IllegalArgumentException if the map is empty or a weight is
     *         not a finite number greater than 0
     */
    public WeightedChoice(Map<T, ? extends Number> item_weights, Random random)
    {
        this(item_weights, random, DEFAULT_THRESHOLD);
    }

    /**
     * Creates an instance using the given random source and threshold.
     *
     * <p>The weights are normalized to sum to 1 and split into "light" items
     * (normalized weight below {@code 1/n}, where {@code n} is the number of
     * items) and "heavy" items (everything else). Each bucket pairs at most one
     * light item with one heavy item; the heavy item donates the weight needed
     * to bring the bucket up to {@code 1/n} and whatever it has left over is
     * re-queued if it exceeds {@code threshold}.
     *
     * @param item_weights map from each item to its (strictly positive) weight
     * @param random the random number generator used by {@link #nextItem()};
     *        it is stored as-is, not copied
     * @param threshold minimum residual weight that is still re-queued;
     *        leftovers at or below this value are discarded
     * @throws IllegalArgumentException if the map is empty, if any weight is
     *         NaN, infinite or not greater than 0, or if the sum of the
     *         weights overflows to infinity
     */
    public WeightedChoice(Map<T, ? extends Number> item_weights, Random random,
            double threshold)
    {
        if (item_weights.isEmpty())
        {
            throw new IllegalArgumentException("Item weights must be non-empty");
        }

        int item_count = item_weights.size();
        item_pairs = new ArrayList<ItemPair<T>>(item_count);

        double sum = 0;
        for (Map.Entry<T, ? extends Number> entry : item_weights.entrySet())
        {
            double value = entry.getValue().doubleValue();
            // Written as !(value > 0) so that NaN is rejected as well;
            // "value <= 0" is false for NaN and would let it through.
            if (!(value > 0))
            {
                throw new IllegalArgumentException("Weights must be > 0");
            }
            if (Double.isInfinite(value))
            {
                throw new IllegalArgumentException("Weights must be finite");
            }
            sum += value;
        }
        if (Double.isInfinite(sum))
        {
            // Normalizing by an infinite sum would turn every weight into 0.
            throw new IllegalArgumentException("Sum of weights must be finite");
        }
        double bucket_weight = 1.0 / item_count;

        Queue<ItemPair<T>> light_weights = new LinkedList<ItemPair<T>>();
        Queue<ItemPair<T>> heavy_weights = new LinkedList<ItemPair<T>>();
        for (Map.Entry<T, ? extends Number> entry : item_weights.entrySet())
        {
            double value = entry.getValue().doubleValue() / sum;
            enqueueItem(entry.getKey(), value, bucket_weight, light_weights, heavy_weights);
        }

        // Repeat until both queues are empty.
        while (!heavy_weights.isEmpty() || !light_weights.isEmpty())
        {
            ItemPair<T> heavy_item = heavy_weights.poll();
            ItemPair<T> light_item = light_weights.poll();

            if (heavy_item == null)
            {
                // Only a light item is left (at least one queue was non-empty,
                // so light_item is non-null here). This happens when rounding,
                // or a residue dropped by the threshold, left no heavy item to
                // top the bucket up. Pairing it with a missing heavy item would
                // make nextItem() return null for part of this bucket, so the
                // light item fills the bucket alone. Once item_count buckets
                // exist the remaining weight is pure residue and is discarded,
                // which keeps the 1/item_count bucket size consistent.
                if (item_pairs.size() < item_count)
                {
                    item_pairs.add(new ItemPair<T>(light_item.light, light_item.light, 1.0));
                }
                continue;
            }

            double light_weight = 0;
            T light = null;
            if (light_item != null)
            {
                light_weight = light_item.weight;
                light = light_item.light;
            }

            T heavy = heavy_item.heavy;
            // Put the heavy item's leftover weight -- what was not needed to
            // make up the difference between the light weight and 1/n -- back
            // into the appropriate queue.
            double new_weight = heavy_item.weight - (bucket_weight - light_weight);
            if (new_weight > threshold)
            {
                enqueueItem(heavy, new_weight, bucket_weight, light_weights, heavy_weights);
            }

            // Scale from "share of total weight" to "share of this bucket".
            item_pairs.add(new ItemPair<T>(light, heavy, light_weight * item_count));
        }

        this.random = random;
    }

    /**
     * Adds {@code key} to the light queue if {@code value} is below
     * {@code threshold}, and to the heavy queue otherwise.
     *
     * @param key the item to enqueue
     * @param value the item's current (normalized) weight
     * @param threshold the light/heavy boundary, i.e. the per-bucket weight
     * @param light_weights queue receiving items lighter than the boundary
     * @param heavy_weights queue receiving all other items
     */
    private static <E> void enqueueItem(E key, double value, double threshold,
            Queue<ItemPair<E>> light_weights, Queue<ItemPair<E>> heavy_weights)
    {
        if (value < threshold)
        {
            light_weights.offer(new ItemPair<E>(key, null, value));
        }
        else
        {
            heavy_weights.offer(new ItemPair<E>(null, key, value));
        }
    }

    /**
     * Sets the seed of the random number generator used by this instance.
     * Because the generator supplied to the constructor is used directly,
     * this reseeds the caller's {@link Random} object.
     *
     * @param seed the new seed
     */
    public void setRandomSeed(long seed)
    {
        this.random.setSeed(seed);
    }

    /**
     * Returns the next pseudo-randomly chosen item: a bucket is selected
     * uniformly at random, then the bucket's light item is returned with the
     * bucket's stored probability and its heavy item otherwise.
     *
     * @return one of the keys of the weight map given at construction
     */
    public T nextItem()
    {
        ItemPair<T> item_pair = item_pairs.get(random.nextInt(item_pairs.size()));
        if (random.nextDouble() < item_pair.weight)
        {
            return item_pair.light;
        }
        return item_pair.heavy;
    }

    /**
     * One entry of the lookup table, also used as a queue element during
     * construction (where only one of the two item slots is filled and
     * {@code weight} is the item's remaining normalized weight).
     */
    private static final class ItemPair<E>
    {
        final E light;
        final E heavy;
        /** In the finished table: probability of returning {@code light}. */
        final double weight;

        private ItemPair(E light, E heavy, double weight)
        {
            this.light = light;
            this.heavy = heavy;
            this.weight = weight;
        }

        @Override
        public String toString()
        {
            return String.format("[L:%s, H:%s, %.3f]", light, heavy, weight);
        }
    }
}