1 package edu.uci.ics.jung.algorithms.shortestpath;
2
3 import java.util.Collection;
4 import java.util.HashSet;
5 import java.util.Set;
6
7 import com.google.common.base.Function;
8 import com.google.common.base.Functions;
9 import com.google.common.base.Supplier;
10
11 import edu.uci.ics.jung.graph.Graph;
12 import edu.uci.ics.jung.graph.util.Pair;
13
14
15
16
17
18
19
20
21
22
23 public class PrimMinimumSpanningTree<V,E> implements Function<Graph<V,E>,Graph<V,E>> {
24
25 protected Supplier<? extends Graph<V,E>> treeFactory;
26 protected Function<? super E,Double> weights;
27
28
29
30
31
32 public PrimMinimumSpanningTree(Supplier<? extends Graph<V,E>> supplier) {
33 this(supplier, Functions.constant(1.0));
34 }
35
36
37
38
39
40
41 public PrimMinimumSpanningTree(Supplier<? extends Graph<V,E>> supplier,
42 Function<? super E, Double> weights) {
43 this.treeFactory = supplier;
44 if(weights != null) {
45 this.weights = weights;
46 }
47 }
48
49
50
51
52 public Graph<V,E> apply(Graph<V,E> graph) {
53 Set<E> unfinishedEdges = new HashSet<E>(graph.getEdges());
54 Graph<V,E> tree = treeFactory.get();
55 V root = findRoot(graph);
56 if(graph.getVertices().contains(root)) {
57 tree.addVertex(root);
58 } else if(graph.getVertexCount() > 0) {
59
60 tree.addVertex(graph.getVertices().iterator().next());
61 }
62 updateTree(tree, graph, unfinishedEdges);
63
64 return tree;
65 }
66
67 protected V findRoot(Graph<V,E> graph) {
68 for(V v : graph.getVertices()) {
69 if(graph.getInEdges(v).size() == 0) {
70 return v;
71 }
72 }
73
74 if(graph.getVertexCount() > 0) {
75 return graph.getVertices().iterator().next();
76 }
77
78 return null;
79 }
80
81 protected void updateTree(Graph<V,E> tree, Graph<V,E> graph, Collection<E> unfinishedEdges) {
82 Collection<V> tv = tree.getVertices();
83 double minCost = Double.MAX_VALUE;
84 E nextEdge = null;
85 V nextVertex = null;
86 V currentVertex = null;
87 for(E e : unfinishedEdges) {
88
89 if(tree.getEdges().contains(e)) continue;
90
91
92 Pair<V> endpoints = graph.getEndpoints(e);
93 V first = endpoints.getFirst();
94 V second = endpoints.getSecond();
95 if((tv.contains(first) == true && tv.contains(second) == false)) {
96 if(weights.apply(e) < minCost) {
97 minCost = weights.apply(e);
98 nextEdge = e;
99 currentVertex = first;
100 nextVertex = second;
101 }
102 } else if((tv.contains(second) == true && tv.contains(first) == false)) {
103 if(weights.apply(e) < minCost) {
104 minCost = weights.apply(e);
105 nextEdge = e;
106 currentVertex = second;
107 nextVertex = first;
108 }
109 }
110 }
111
112 if(nextVertex != null && nextEdge != null) {
113 unfinishedEdges.remove(nextEdge);
114 tree.addEdge(nextEdge, currentVertex, nextVertex);
115 updateTree(tree, graph, unfinishedEdges);
116 }
117 }
118 }