Skip to content

Commit e674d9a

Browse files
committed
union find
1 parent b3a246b commit e674d9a

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

graphs/unionfind.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# For disjoint set of nodes, we build trees that are not necessarily accurate with
2+
# the graph, we try to keep them balanced instead to make the find operation efficient
3+
# Uses:
4+
# Find if a graph has a cycle
5+
# Find number of connected components
6+
7+
class UnionFind():
8+
def __init__(self, n: int):
9+
self.par = {}
10+
self.rank = {}
11+
for i in range(n):
12+
self.par[i] = i
13+
self.rank[i] = 1
14+
def union(self, n1: int, n2: int) -> bool:
15+
p1, p2 = self.find(n1), self.find(n2)
16+
if p1 == p2: return False
17+
if self.rank[p1] < self.rank[p2]:
18+
self.par[p1] = p2
19+
elif self.rank[p1] > self.rank[p2]:
20+
self.par[p2] = p1
21+
else:
22+
self.par[p1] = p2
23+
self.rank[p2] += 1 # this is the only case where we need to adjust rank
24+
return True
25+
26+
def find(self, n: int) -> int:
27+
curr = n
28+
while self.par[curr] != curr:
29+
self.par[curr] = self.par[self.par[curr]] # path compression
30+
curr = self.par[curr]
31+
return curr
32+
33+
uf1 = UnionFind(3)
34+
for e1,e2 in [[0,1],[1,2],[0,2]]: # edges, has cycle
35+
if not uf1.union(e1, e2):
36+
print(f'cycle detected with edge: {e1} -> {e2}') # cycle detected with edge: 0 -> 2
37+
38+
print(f'{len(set(uf1.par.values()))} disjoint sets') # 1 disjoint sets
39+
40+
uf2 = UnionFind(3)
41+
for e1,e2 in [[0,1]]: # only one edge
42+
if not uf2.union(e1, e2):
43+
print(f'cycle detected with edge: {e1} -> {e2}')
44+
print(f'{len(set(uf2.par.values()))} disjoint sets') # 2 disjoint sets
45+

0 commit comments

Comments
 (0)