Skip to content

Commit a49bc40

Browse files
authored
Add checked UnionFind methods (petgraph#730)
PR provides the following `UnionFind` methods, that are free from panics: ```rust pub fn try_find(&self, mut x: K) -> Option<K> pub fn try_find_mut(&mut self, x: K) -> Option<K> pub fn try_equiv(&self, x: K, y: K) -> Result<bool, K> pub fn try_union(&mut self, x: K, y: K) -> Result<bool, K> ``` Some of old methods were rewritten using the `try` twins with the addition of `.unwrap` or `.expect`, for example: ```rust pub fn find(&self, x: K) -> K { self.try_find(x).expect("The index is out of bounds") } ``` If you consider these changes to be critical & breaking, I will return back the old `assert`ions. I'm open to discussion!
1 parent 278c53b commit a49bc40

File tree

2 files changed

+152
-16
lines changed

2 files changed

+152
-16
lines changed

src/unionfind.rs

+55-16
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,25 @@ where
9595
///
9696
/// **Panics** if `x` is out of bounds.
9797
pub fn find(&self, x: K) -> K {
98-
assert!(x.index() < self.len());
99-
unsafe {
100-
let mut x = x;
101-
loop {
102-
// Use unchecked indexing because we can trust the internal set ids.
103-
let xparent = *get_unchecked(&self.parent, x.index());
104-
if xparent == x {
105-
break;
106-
}
107-
x = xparent;
98+
self.try_find(x).expect("The index is out of bounds")
99+
}
100+
101+
/// Return the representative for `x` or `None` if `x` is out of bounds.
102+
pub fn try_find(&self, mut x: K) -> Option<K> {
103+
if x.index() >= self.len() {
104+
return None;
105+
}
106+
107+
loop {
108+
// Use unchecked indexing because we can trust the internal set ids.
109+
let xparent = unsafe { *get_unchecked(&self.parent, x.index()) };
110+
if xparent == x {
111+
break;
108112
}
109-
x
113+
x = xparent;
110114
}
115+
116+
Some(x)
111117
}
112118

113119
/// Return the representative for `x`.
@@ -121,6 +127,17 @@ where
121127
unsafe { self.find_mut_recursive(x) }
122128
}
123129

130+
/// Return the representative for `x` or `None` if `x` is out of bounds.
131+
///
132+
/// Write back the found representative, flattening the internal
133+
/// datastructure in the process and quicken future lookups.
134+
pub fn try_find_mut(&mut self, x: K) -> Option<K> {
135+
if x.index() >= self.len() {
136+
return None;
137+
}
138+
Some(unsafe { self.find_mut_recursive(x) })
139+
}
140+
124141
unsafe fn find_mut_recursive(&mut self, mut x: K) -> K {
125142
let mut parent = *get_unchecked(&self.parent, x.index());
126143
while parent != x {
@@ -134,24 +151,46 @@ where
134151

135152
/// Returns `true` if the given elements belong to the same set, and returns
136153
/// `false` otherwise.
154+
///
155+
/// **Panics** if `x` or `y` is out of bounds.
137156
pub fn equiv(&self, x: K, y: K) -> bool {
138157
self.find(x) == self.find(y)
139158
}
140159

160+
/// Returns `Ok(true)` if the given elements belong to the same set, and returns
161+
/// `Ok(false)` otherwise.
162+
///
163+
/// If `x` or `y` are out of bounds, it returns `Err` with the first bad index found.
164+
pub fn try_equiv(&self, x: K, y: K) -> Result<bool, K> {
165+
let xrep = self.try_find(x).ok_or(x)?;
166+
let yrep = self.try_find(y).ok_or(y)?;
167+
Ok(xrep == yrep)
168+
}
169+
141170
/// Unify the two sets containing `x` and `y`.
142171
///
143172
/// Return `false` if the sets were already the same, `true` if they were unified.
144173
///
145174
/// **Panics** if `x` or `y` is out of bounds.
146175
pub fn union(&mut self, x: K, y: K) -> bool {
176+
self.try_union(x, y).unwrap()
177+
}
178+
179+
/// Unify the two sets containing `x` and `y`.
180+
///
181+
/// Return `Ok(false)` if the sets were already the same, `Ok(true)` if they were unified.
182+
///
183+
/// If `x` or `y` are out of bounds, it returns `Err` with first found bad index.
184+
/// But if `x == y`, the result will be `Ok(false)` even if the indexes go out of bounds.
185+
pub fn try_union(&mut self, x: K, y: K) -> Result<bool, K> {
147186
if x == y {
148-
return false;
187+
return Ok(false);
149188
}
150-
let xrep = self.find_mut(x);
151-
let yrep = self.find_mut(y);
189+
let xrep = self.try_find_mut(x).ok_or(x)?;
190+
let yrep = self.try_find_mut(y).ok_or(y)?;
152191

153192
if xrep == yrep {
154-
return false;
193+
return Ok(false);
155194
}
156195

157196
let xrepu = xrep.index();
@@ -169,7 +208,7 @@ where
169208
self.rank[xrepu] += 1;
170209
}
171210
}
172-
true
211+
Ok(true)
173212
}
174213

175214
/// Return a vector mapping each element to its representative.

tests/unionfind.rs

+97
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,37 @@ fn uf_test() {
3333
assert_eq!(set.len(), 3);
3434
}
3535

36+
#[test]
37+
fn uf_test_checked() {
38+
let n = 8;
39+
let mut u = UnionFind::new(n);
40+
for i in 0..n {
41+
assert_eq!(u.try_find(i), Some(i));
42+
assert_eq!(u.try_find_mut(i), Some(i));
43+
assert_eq!(u.try_union(i, i), Ok(false));
44+
}
45+
46+
assert!(u.try_union(0, 1).is_ok());
47+
assert_eq!(u.try_find(0), u.try_find(1));
48+
assert!(u.try_find(0).is_some());
49+
assert!(u.try_union(1, 3).is_ok());
50+
assert!(u.try_union(1, 4).is_ok());
51+
assert!(u.try_union(4, 7).is_ok());
52+
assert_eq!(u.try_find(0), u.try_find(3));
53+
assert_eq!(u.try_find(1), u.try_find(3));
54+
assert!(u.try_find(0).is_some());
55+
assert!(u.try_find(1).is_some());
56+
assert!(u.try_find(0) != u.try_find(2));
57+
assert_eq!(u.try_find(7), u.try_find(0));
58+
assert!(u.try_union(5, 6).is_ok());
59+
assert_eq!(u.try_find(6), u.try_find(5));
60+
assert!(u.try_find(6) != u.try_find(7));
61+
62+
// check that there are now 3 disjoint sets
63+
let set = (0..n).map(|i| u.find(i)).collect::<HashSet<_>>();
64+
assert_eq!(set.len(), 3);
65+
}
66+
3667
#[test]
3768
fn uf_test_with_equiv() {
3869
let n = 8;
@@ -61,6 +92,34 @@ fn uf_test_with_equiv() {
6192
assert_eq!(set.len(), 3);
6293
}
6394

95+
#[test]
96+
fn uf_test_with_checked_equiv() {
97+
let n = 8;
98+
let mut u = UnionFind::new(n);
99+
for i in 0..n {
100+
assert_eq!(u.find(i), i);
101+
assert_eq!(u.find_mut(i), i);
102+
assert_eq!(u.try_equiv(i, i), Ok(true));
103+
}
104+
105+
u.union(0, 1);
106+
assert_eq!(u.try_equiv(0, 1), Ok(true));
107+
u.union(1, 3);
108+
u.union(1, 4);
109+
u.union(4, 7);
110+
assert_eq!(u.try_equiv(0, 7), Ok(true));
111+
assert_eq!(u.try_equiv(1, 3), Ok(true));
112+
assert_eq!(u.try_equiv(0, 2), Ok(false));
113+
assert_eq!(u.try_equiv(7, 0), Ok(true));
114+
u.union(5, 6);
115+
assert_eq!(u.try_equiv(6, 5), Ok(true));
116+
assert_eq!(u.try_equiv(6, 7), Ok(false));
117+
118+
// check that there are now 3 disjoint sets
119+
let set = (0..n).map(|i| u.find(i)).collect::<HashSet<_>>();
120+
assert_eq!(set.len(), 3);
121+
}
122+
64123
#[test]
65124
fn uf_rand() {
66125
let n = 1 << 14;
@@ -89,6 +148,20 @@ fn uf_u8() {
89148
}
90149
}
91150

151+
#[test]
152+
fn uf_u8_checked() {
153+
let n = 256;
154+
let mut rng = ChaChaRng::from_rng(thread_rng()).unwrap();
155+
let mut u = UnionFind::<u8>::new(n);
156+
for _ in 0..(n * 8) {
157+
let a = rng.gen();
158+
let b = rng.gen();
159+
let ar = u.try_find(a).unwrap();
160+
let br = u.try_find(b).unwrap();
161+
assert_eq!(ar != br, u.try_union(a, b).unwrap());
162+
}
163+
}
164+
92165
#[test]
93166
fn labeling() {
94167
let mut u = UnionFind::<u32>::new(48);
@@ -136,3 +209,27 @@ fn uf_incremental() {
136209
.collect::<HashSet<_>>();
137210
assert_eq!(set.len(), 3);
138211
}
212+
213+
#[test]
214+
fn uf_test_out_of_bounds() {
215+
let n = 8;
216+
let mut u = UnionFind::new(n);
217+
for i in 0..n {
218+
u.find(i);
219+
u.find_mut(i);
220+
u.union(i, i);
221+
}
222+
223+
assert!(u.try_find(50).is_none());
224+
assert!(u.try_find_mut(50).is_none());
225+
226+
assert_eq!(u.try_union(1, 50), Err(50));
227+
assert_eq!(u.try_union(50, 1), Err(50));
228+
assert_eq!(u.try_union(30, 50), Err(30));
229+
assert_eq!(u.try_union(50, 30), Err(50));
230+
231+
assert_eq!(u.try_equiv(1, 50), Err(50));
232+
assert_eq!(u.try_equiv(50, 1), Err(50));
233+
assert_eq!(u.try_equiv(30, 50), Err(30));
234+
assert_eq!(u.try_equiv(50, 30), Err(50));
235+
}

0 commit comments

Comments
 (0)