/**
 * Copyright (c) 2009, The JUNG Authors
 *
 * All rights reserved.
 *
 * This software is open-source under the BSD license; see either
 * "license.txt" or
 * https://github.com/jrtom/jung/blob/master/LICENSE for a description.
 * Created on Jan 8, 2009
 *
 */
12  package edu.uci.ics.jung.algorithms.util;
13  
14  import java.util.ArrayList;
15  import java.util.LinkedList;
16  import java.util.List;
17  import java.util.Map;
18  import java.util.Queue;
19  import java.util.Random;
20  
/**
 * Selects items according to their probability in an arbitrary probability
 * distribution.  The distribution is specified by a {@code Map} from
 * items (of type {@code T}) to weights of type {@code Number}, supplied
 * to the constructor; these weights are normalized internally to act as
 * probabilities.
 *
 * <p>Construction builds a table with (nominally) one bucket per item.  Each
 * bucket holds a "light" item, a "heavy" item, and the conditional probability
 * of choosing the light one.  Selection picks a bucket uniformly at random and
 * then picks between its two items, so it takes O(1) time; the table requires
 * O(n) space.
 *
 * @param <T> the type of the items to be selected
 * @author Joshua O'Madadhain
 */
public class WeightedChoice<T>
{
	/** The selection table: one (light, heavy, P(light)) entry per bucket. */
	private final List<ItemPair> item_pairs;

	/** Source of the uniform integer and double values used by {@link #nextItem()}. */
	private final Random random;

	/**
	 * The default minimum value that is treated as a valid probability
	 * (as opposed to rounding error from floating-point operations).
	 */
	public static final double DEFAULT_THRESHOLD = 0.00000000001;

	/**
	 * Equivalent to {@code this(item_weights, new Random(), DEFAULT_THRESHOLD)}.
	 * @param item_weights a map from items to their weights
	 * @throws IllegalArgumentException if {@code item_weights} is empty or
	 *     contains a weight that is not a finite number greater than 0
	 */
	public WeightedChoice(Map<T, ? extends Number> item_weights)
	{
		this(item_weights, new Random(), DEFAULT_THRESHOLD);
	}

	/**
	 * Equivalent to {@code this(item_weights, new Random(), threshold)}.
	 * @param item_weights a map from items to their weights
	 * @param threshold the minimum value that is treated as a probability
	 *     (anything smaller will be considered equivalent to a floating-point rounding error)
	 * @throws IllegalArgumentException if {@code item_weights} is empty or
	 *     contains a weight that is not a finite number greater than 0
	 */
	public WeightedChoice(Map<T, ? extends Number> item_weights, double threshold)
	{
		this(item_weights, new Random(), threshold);
	}

	/**
	 * Equivalent to {@code this(item_weights, random, DEFAULT_THRESHOLD)}.
	 * @param item_weights a map from items to their weights
	 * @param random the Random instance to use for selection
	 * @throws IllegalArgumentException if {@code item_weights} is empty or
	 *     contains a weight that is not a finite number greater than 0
	 */
	public WeightedChoice(Map<T, ? extends Number> item_weights, Random random)
	{
		this(item_weights, random, DEFAULT_THRESHOLD);
	}

	/**
	 * Creates an instance with the specified mapping from items to weights,
	 * random number generator, and threshold value.
	 *
	 * <p>The mapping defines the weight for each item to be selected; this
	 * will be proportional to the probability of its selection.
	 * <p>The random number generator specifies the mechanism which will be
	 * used to provide uniform integer and double values.
	 * <p>The threshold indicates the minimum value that is treated as a valid
	 * probability (as opposed to rounding error from floating-point operations).
	 * @param item_weights a map from items to their weights
	 * @param random the Random instance to use for selection
	 * @param threshold the minimum value that is treated as a probability
	 *     (anything smaller will be considered equivalent to a floating-point rounding error)
	 * @throws IllegalArgumentException if {@code item_weights} is empty,
	 *     contains a weight that is not a finite number greater than 0,
	 *     or if the sum of the weights overflows
	 */
	public WeightedChoice(Map<T, ? extends Number> item_weights, Random random,
			double threshold)
	{
		if (item_weights.isEmpty())
			throw new IllegalArgumentException("Item weights must be non-empty");

		int item_count = item_weights.size();
		item_pairs = new ArrayList<ItemPair>(item_count);

		double sum = 0;
		for (Map.Entry<T, ? extends Number> entry : item_weights.entrySet())
		{
			double value = entry.getValue().doubleValue();
			if (value <= 0)
				throw new IllegalArgumentException("Weights must be > 0");
			// NaN is not caught by the comparison above, and an infinite weight
			// cannot be normalized; either would silently corrupt the table.
			if (Double.isNaN(value) || Double.isInfinite(value))
				throw new IllegalArgumentException("Weights must be finite numbers");
			sum += value;
		}
		if (Double.isInfinite(sum))
			throw new IllegalArgumentException("Sum of weights is too large to normalize");

		// every bucket in the table represents exactly 1/n of the probability mass
		double bucket_weight = 1.0 / item_count;

		// partition the normalized weights into those that under-fill a bucket
		// ("light") and those that fill or over-fill one ("heavy")
		Queue<ItemPair> light_weights = new LinkedList<ItemPair>();
		Queue<ItemPair> heavy_weights = new LinkedList<ItemPair>();
		for (Map.Entry<T, ? extends Number> entry : item_weights.entrySet())
		{
			double value = entry.getValue().doubleValue() / sum;
			enqueueItem(entry.getKey(), value, bucket_weight, light_weights, heavy_weights);
		}

		// repeat until both queues empty
		while (!heavy_weights.isEmpty() || !light_weights.isEmpty())
		{
			ItemPair heavy_item = heavy_weights.poll();
			ItemPair light_item = light_weights.poll();

			if (heavy_item == null)
			{
				// No heavy item is left to top up this bucket.  This happens when
				// a heavy remainder at or below the threshold was discarded, or
				// through floating-point rounding.  The bucket belongs entirely
				// to the light item; both slots are filled so that nextItem()
				// can never return null for this bucket.
				T sole = light_item.light;
				item_pairs.add(new ItemPair(sole, sole, 1.0));
				continue;
			}

			double light_weight = 0;
			T light = null;
			if (light_item != null)
			{
				light_weight = light_item.weight;
				light = light_item.light;
			}

			T heavy = heavy_item.heavy;
			// put the 'left over' weight from the heavy item--what wasn't
			// needed to make up the difference between the light weight and
			// 1/n--back in the appropriate queue
			double new_weight = heavy_item.weight - (bucket_weight - light_weight);
			if (new_weight > threshold)
				enqueueItem(heavy, new_weight, bucket_weight, light_weights, heavy_weights);

			// scale the light item's share of the bucket to a conditional
			// probability in [0, 1); a bucket with no light item gets 0, so
			// the heavy item is always returned for it
			item_pairs.add(new ItemPair(light, heavy, light_weight * item_count));
		}

		this.random = random;
	}

	/**
	 * Adds key/value to the appropriate queue.  Keys with values strictly less
	 * than {@code bucket_weight} get added to {@code light_weights} (stored in
	 * the pair's light slot); all others get added to {@code heavy_weights}
	 * (stored in the pair's heavy slot).
	 *
	 * @param key the item to enqueue
	 * @param value the (normalized) weight currently associated with the item
	 * @param bucket_weight the probability mass of a single bucket (1/n)
	 * @param light_weights the queue of items that under-fill a bucket
	 * @param heavy_weights the queue of items that fill or over-fill a bucket
	 */
	private void enqueueItem(T key, double value, double bucket_weight,
			Queue<ItemPair> light_weights, Queue<ItemPair> heavy_weights)
	{
		if (value < bucket_weight)
			light_weights.offer(new ItemPair(key, null, value));
		else
			heavy_weights.offer(new ItemPair(null, key, value));
	}

	/**
	 * Re-seeds the random number generator used by this instance.
	 * @param seed the seed to be used by the internal random number generator
	 */
	public void setRandomSeed(long seed)
	{
		this.random.setSeed(seed);
	}

	/**
	 * Retrieves an item with probability proportional to its weight in the
	 * {@code Map} provided in the input.
	 * @return an item chosen randomly based on its specified weight
	 */
	public T nextItem()
	{
		// pick a bucket uniformly, then choose between its light and heavy item
		ItemPair item_pair = item_pairs.get(random.nextInt(item_pairs.size()));
		if (random.nextDouble() < item_pair.weight)
			return item_pair.light;
		return item_pair.heavy;
	}

	/**
	 * Manages light object/heavy object/light conditional probability tuples.
	 * While the table is being built, instances are also used as queue entries
	 * holding a single item (in either slot) and its remaining weight.
	 */
	private class ItemPair
	{
		T light;
		T heavy;
		double weight;

		private ItemPair(T light, T heavy, double weight)
		{
			this.light = light;
			this.heavy = heavy;
			this.weight = weight;
		}

		@Override
		public String toString()
		{
			return String.format("[L:%s, H:%s, %.3f]", light, heavy, weight);
		}
	}
}
203 }