Skip to content

Commit b711f23

Browse files
authored
feat(parquet): add union method to RowSelection (#6308)
Complements the existing RowSelection::intersection method. Useful for Or-ing row selections together, in contrast to intersection's use when AND-ing selections
1 parent ee2f75a commit b711f23

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

parquet/src/arrow/arrow_reader/selection.rs

+132
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@ impl RowSelection {
343343
intersect_row_selections(&self.selectors, &other.selectors)
344344
}
345345

346+
/// Compute the union of two [`RowSelection`]
347+
/// For example:
348+
/// self: NNYYYYNNYYNYN
349+
/// other: NYNNNNNNN
350+
///
351+
/// returned: NYYYYYNNYYNYN
352+
pub fn union(&self, other: &Self) -> Self {
353+
union_row_selections(&self.selectors, &other.selectors)
354+
}
355+
346356
/// Returns `true` if this [`RowSelection`] selects any rows
347357
pub fn selects_any(&self) -> bool {
348358
self.selectors.iter().any(|x| !x.skip)
@@ -536,6 +546,92 @@ fn intersect_row_selections(left: &[RowSelector], right: &[RowSelector]) -> RowS
536546
iter.collect()
537547
}
538548

549+
/// Combine two lists of `RowSelector` return the union of them
550+
/// For example:
551+
/// self: NNYYYYNNYYNYN
552+
/// other: NYNNNNNNY
553+
///
554+
/// returned: NYYYYYNNYYNYN
555+
///
556+
/// This can be removed from here once RowSelection::union is in parquet::arrow
557+
fn union_row_selections(left: &[RowSelector], right: &[RowSelector]) -> RowSelection {
558+
let mut l_iter = left.iter().copied().peekable();
559+
let mut r_iter = right.iter().copied().peekable();
560+
561+
let iter = std::iter::from_fn(move || {
562+
loop {
563+
let l = l_iter.peek_mut();
564+
let r = r_iter.peek_mut();
565+
566+
match (l, r) {
567+
(Some(a), _) if a.row_count == 0 => {
568+
l_iter.next().unwrap();
569+
}
570+
(_, Some(b)) if b.row_count == 0 => {
571+
r_iter.next().unwrap();
572+
}
573+
(Some(l), Some(r)) => {
574+
return match (l.skip, r.skip) {
575+
// Skip both ranges
576+
(true, true) => {
577+
if l.row_count < r.row_count {
578+
let skip = l.row_count;
579+
r.row_count -= l.row_count;
580+
l_iter.next();
581+
Some(RowSelector::skip(skip))
582+
} else {
583+
let skip = r.row_count;
584+
l.row_count -= skip;
585+
r_iter.next();
586+
Some(RowSelector::skip(skip))
587+
}
588+
}
589+
// Keep rows from left
590+
(false, true) => {
591+
if l.row_count < r.row_count {
592+
r.row_count -= l.row_count;
593+
l_iter.next()
594+
} else {
595+
let r_row_count = r.row_count;
596+
l.row_count -= r_row_count;
597+
r_iter.next();
598+
Some(RowSelector::select(r_row_count))
599+
}
600+
}
601+
// Keep rows from right
602+
(true, false) => {
603+
if l.row_count < r.row_count {
604+
let l_row_count = l.row_count;
605+
r.row_count -= l_row_count;
606+
l_iter.next();
607+
Some(RowSelector::select(l_row_count))
608+
} else {
609+
l.row_count -= r.row_count;
610+
r_iter.next()
611+
}
612+
}
613+
// Keep at least one
614+
_ => {
615+
if l.row_count < r.row_count {
616+
r.row_count -= l.row_count;
617+
l_iter.next()
618+
} else {
619+
l.row_count -= r.row_count;
620+
r_iter.next()
621+
}
622+
}
623+
};
624+
}
625+
(Some(_), None) => return l_iter.next(),
626+
(None, Some(_)) => return r_iter.next(),
627+
(None, None) => return None,
628+
}
629+
}
630+
});
631+
632+
iter.collect()
633+
}
634+
539635
#[cfg(test)]
540636
mod tests {
541637
use super::*;
@@ -1213,4 +1309,40 @@ mod tests {
12131309
]
12141310
);
12151311
}
1312+
1313+
#[test]
1314+
fn test_union() {
1315+
let selection = RowSelection::from(vec![RowSelector::select(1048576)]);
1316+
let result = selection.union(&selection);
1317+
assert_eq!(result, selection);
1318+
1319+
// NYNYY
1320+
let a = RowSelection::from(vec![
1321+
RowSelector::skip(10),
1322+
RowSelector::select(10),
1323+
RowSelector::skip(10),
1324+
RowSelector::select(20),
1325+
]);
1326+
1327+
// NNYYNYN
1328+
let b = RowSelection::from(vec![
1329+
RowSelector::skip(20),
1330+
RowSelector::select(20),
1331+
RowSelector::skip(10),
1332+
RowSelector::select(10),
1333+
RowSelector::skip(10),
1334+
]);
1335+
1336+
let result = a.union(&b);
1337+
1338+
// NYYYYYN
1339+
assert_eq!(
1340+
result.iter().collect::<Vec<_>>(),
1341+
vec![
1342+
&RowSelector::skip(10),
1343+
&RowSelector::select(50),
1344+
&RowSelector::skip(10),
1345+
]
1346+
);
1347+
}
12161348
}

0 commit comments

Comments
 (0)