View Javadoc
1   package edu.uci.ics.jung.algorithms.shortestpath;
2   
3   import java.util.Collection;
4   import java.util.HashSet;
5   import java.util.Set;
6   
7   import com.google.common.base.Function;
8   import com.google.common.base.Functions;
9   import com.google.common.base.Supplier;
10  
11  import edu.uci.ics.jung.graph.Graph;
12  import edu.uci.ics.jung.graph.util.Pair;
13  
14  /**
15   * For the input Graph, creates a MinimumSpanningTree
16   * using a variation of Prim's algorithm.
17   * 
18   * @author Tom Nelson - tomnelson@dev.java.net
19   *
20   * @param <V> the vertex type
21   * @param <E> the edge type
22   */
23  public class PrimMinimumSpanningTree<V,E> implements Function<Graph<V,E>,Graph<V,E>> {
24  	
25  	protected Supplier<? extends Graph<V,E>> treeFactory;
26  	protected Function<? super E,Double> weights; 
27  	
28  	/**
29  	 * Creates an instance which generates a minimum spanning tree assuming constant edge weights.
30  	 * @param supplier used to create the tree instances
31  	 */
32  	public PrimMinimumSpanningTree(Supplier<? extends Graph<V,E>> supplier) {
33  		this(supplier, Functions.constant(1.0));
34  	}
35  
36      /**
37       * Creates an instance which generates a minimum spanning tree using the input edge weights.
38  	 * @param supplier used to create the tree instances
39  	 * @param weights the edge weights to use for defining the MST
40       */
41  	public PrimMinimumSpanningTree(Supplier<? extends Graph<V,E>> supplier, 
42  			Function<? super E, Double> weights) {
43  		this.treeFactory = supplier;
44  		if(weights != null) {
45  			this.weights = weights;
46  		}
47  	}
48  	
49  	/**
50  	 * @param graph the Graph to find MST in
51  	 */
52      public Graph<V,E> apply(Graph<V,E> graph) {
53  		Set<E> unfinishedEdges = new HashSet<E>(graph.getEdges());
54  		Graph<V,E> tree = treeFactory.get();
55  		V root = findRoot(graph);
56  		if(graph.getVertices().contains(root)) {
57  			tree.addVertex(root);
58  		} else if(graph.getVertexCount() > 0) {
59  			// pick an arbitrary vertex to make root
60  			tree.addVertex(graph.getVertices().iterator().next());
61  		}
62  		updateTree(tree, graph, unfinishedEdges);
63  		
64  		return tree;
65  	}
66      
67      protected V findRoot(Graph<V,E> graph) {
68      	for(V v : graph.getVertices()) {
69      		if(graph.getInEdges(v).size() == 0) {
70      			return v;
71      		}
72      	}
73      	// if there is no obvious root, pick any vertex
74      	if(graph.getVertexCount() > 0) {
75      		return graph.getVertices().iterator().next();
76      	}
77      	// this graph has no vertices
78      	return null;
79      }
80  	
81  	protected void updateTree(Graph<V,E> tree, Graph<V,E> graph, Collection<E> unfinishedEdges) {
82  		Collection<V> tv = tree.getVertices();
83  		double minCost = Double.MAX_VALUE;
84  		E nextEdge = null;
85  		V nextVertex = null;
86  		V currentVertex = null;
87  		for(E e : unfinishedEdges) {
88  			
89  			if(tree.getEdges().contains(e)) continue;
90  			// find the lowest cost edge, get its opposite endpoint,
91  			// and then update forest from its Successors
92  			Pair<V> endpoints = graph.getEndpoints(e);
93  			V first = endpoints.getFirst();
94  			V second = endpoints.getSecond();
95  			if((tv.contains(first) == true && tv.contains(second) == false)) {
96  				if(weights.apply(e) < minCost) {
97  					minCost = weights.apply(e);
98  					nextEdge = e;
99  					currentVertex = first;
100 					nextVertex = second;
101 				}
102 			} else if((tv.contains(second) == true && tv.contains(first) == false)) {
103 				if(weights.apply(e) < minCost) {
104 					minCost = weights.apply(e);
105 					nextEdge = e;
106 					currentVertex = second;
107 					nextVertex = first;
108 				}
109 			}
110 		}
111 		
112 		if(nextVertex != null && nextEdge != null) {
113 			unfinishedEdges.remove(nextEdge);
114 			tree.addEdge(nextEdge, currentVertex, nextVertex);
115 			updateTree(tree, graph, unfinishedEdges);
116 		}
117 	}
118 }