# Source: root/users/wpcarro/scratch/facebook/mst.py
# blob: 81aa5cd487443a28e881f5b0dbb7c5eeda6010c8
from heapq import heappush, heappop
import random

def to_vertex_list(graph):
    """Convert an edge list into an undirected adjacency list.

    Args:
        graph: iterable of ``(a, b, weight)`` triples, each describing an
            undirected edge between vertices ``a`` and ``b``.

    Returns:
        A dict mapping every vertex that appears in some edge to a list of
        ``(neighbor, weight)`` pairs. Each edge is recorded under both of its
        endpoints, so a self-loop ``(a, a, w)`` is listed twice under ``a``.
    """
    result = {}
    for a, b, kg in graph:
        result.setdefault(a, []).append((b, kg))
        result.setdefault(b, []).append((a, kg))
    return result


def mst(graph, start=None):
    """Return a minimum spanning tree of ``graph`` using Prim's algorithm.

    Args:
        graph: iterable of ``(a, b, weight)`` undirected edges. Vertex labels
            must be hashable and, because they take part in heap ordering
            when weights tie, mutually comparable.
        start: vertex to grow the tree from. When omitted, a vertex is chosen
            with ``random.choice`` (the original behavior).

    Returns:
        A list of ``(from, to)`` vertex pairs in the order the edges were
        added to the tree. Only the connected component containing ``start``
        is spanned. An empty ``graph`` yields ``[]``.

    Raises:
        KeyError: if ``start`` is given but is not an endpoint of any edge.
    """
    adjacency = to_vertex_list(graph)
    if not adjacency:
        return []
    if start is None:
        start = random.choice(list(adjacency.keys()))
    # The root is in the tree from the outset; without this, a self-loop on
    # the root would be popped first and accepted as a tree edge.
    seen = {start}
    result = []
    h = []
    for neighbor, kg in adjacency[start]:
        heappush(h, (kg, start, neighbor))
    while h:
        kg, src, dst = heappop(h)
        # Both endpoints already in the tree: taking the edge forms a cycle.
        if dst in seen:
            continue
        seen.add(dst)
        result.append((src, dst))
        for neighbor, weight in adjacency[dst]:
            # Edges back into the tree would only be discarded when popped.
            if neighbor not in seen:
                heappush(h, (weight, dst, neighbor))
    return result

# Sample undirected weighted graphs, each a list of (u, v, weight) edges.
graphs = [
    # Graph 1: vertices A-G, 11 edges.
    [
        ('A', 'B', 7), ('A', 'D', 5), ('B', 'D', 9),
        ('E', 'D', 15), ('F', 'D', 6), ('F', 'G', 11),
        ('F', 'E', 8), ('G', 'E', 9), ('C', 'E', 5),
        ('B', 'E', 7), ('B', 'C', 8),
    ],
    # Graph 2: vertices A-I, 14 edges.
    [
        ('A', 'B', 4), ('A', 'C', 8), ('B', 'C', 11),
        ('B', 'E', 8), ('C', 'D', 7), ('C', 'F', 1),
        ('D', 'E', 2), ('D', 'F', 6), ('E', 'G', 7),
        ('E', 'H', 4), ('F', 'H', 2), ('G', 'H', 14),
        ('G', 'I', 9), ('H', 'I', 10),
    ],
]

# Print the spanning-tree edge list computed for each sample graph.
for graph in graphs:
    print(
        mst(graph)
    )