@@ -343,6 +343,16 @@ impl RowSelection {
343
343
intersect_row_selections ( & self . selectors , & other. selectors )
344
344
}
345
345
346
+ /// Compute the union of two [`RowSelection`]
347
+ /// For example:
348
+ /// self: NNYYYYNNYYNYN
349
+ /// other: NYNNNNNNN
350
+ ///
351
+ /// returned: NYYYYYNNYYNYN
352
+ pub fn union ( & self , other : & Self ) -> Self {
353
+ union_row_selections ( & self . selectors , & other. selectors )
354
+ }
355
+
346
356
/// Returns `true` if this [`RowSelection`] selects any rows
347
357
pub fn selects_any ( & self ) -> bool {
348
358
self . selectors . iter ( ) . any ( |x| !x. skip )
@@ -536,6 +546,92 @@ fn intersect_row_selections(left: &[RowSelector], right: &[RowSelector]) -> RowS
536
546
iter. collect ( )
537
547
}
538
548
549
+ /// Combine two lists of `RowSelector` return the union of them
550
+ /// For example:
551
+ /// self: NNYYYYNNYYNYN
552
+ /// other: NYNNNNNNY
553
+ ///
554
+ /// returned: NYYYYYNNYYNYN
555
+ ///
556
+ /// This can be removed from here once RowSelection::union is in parquet::arrow
557
+ fn union_row_selections ( left : & [ RowSelector ] , right : & [ RowSelector ] ) -> RowSelection {
558
+ let mut l_iter = left. iter ( ) . copied ( ) . peekable ( ) ;
559
+ let mut r_iter = right. iter ( ) . copied ( ) . peekable ( ) ;
560
+
561
+ let iter = std:: iter:: from_fn ( move || {
562
+ loop {
563
+ let l = l_iter. peek_mut ( ) ;
564
+ let r = r_iter. peek_mut ( ) ;
565
+
566
+ match ( l, r) {
567
+ ( Some ( a) , _) if a. row_count == 0 => {
568
+ l_iter. next ( ) . unwrap ( ) ;
569
+ }
570
+ ( _, Some ( b) ) if b. row_count == 0 => {
571
+ r_iter. next ( ) . unwrap ( ) ;
572
+ }
573
+ ( Some ( l) , Some ( r) ) => {
574
+ return match ( l. skip , r. skip ) {
575
+ // Skip both ranges
576
+ ( true , true ) => {
577
+ if l. row_count < r. row_count {
578
+ let skip = l. row_count ;
579
+ r. row_count -= l. row_count ;
580
+ l_iter. next ( ) ;
581
+ Some ( RowSelector :: skip ( skip) )
582
+ } else {
583
+ let skip = r. row_count ;
584
+ l. row_count -= skip;
585
+ r_iter. next ( ) ;
586
+ Some ( RowSelector :: skip ( skip) )
587
+ }
588
+ }
589
+ // Keep rows from left
590
+ ( false , true ) => {
591
+ if l. row_count < r. row_count {
592
+ r. row_count -= l. row_count ;
593
+ l_iter. next ( )
594
+ } else {
595
+ let r_row_count = r. row_count ;
596
+ l. row_count -= r_row_count;
597
+ r_iter. next ( ) ;
598
+ Some ( RowSelector :: select ( r_row_count) )
599
+ }
600
+ }
601
+ // Keep rows from right
602
+ ( true , false ) => {
603
+ if l. row_count < r. row_count {
604
+ let l_row_count = l. row_count ;
605
+ r. row_count -= l_row_count;
606
+ l_iter. next ( ) ;
607
+ Some ( RowSelector :: select ( l_row_count) )
608
+ } else {
609
+ l. row_count -= r. row_count ;
610
+ r_iter. next ( )
611
+ }
612
+ }
613
+ // Keep at least one
614
+ _ => {
615
+ if l. row_count < r. row_count {
616
+ r. row_count -= l. row_count ;
617
+ l_iter. next ( )
618
+ } else {
619
+ l. row_count -= r. row_count ;
620
+ r_iter. next ( )
621
+ }
622
+ }
623
+ } ;
624
+ }
625
+ ( Some ( _) , None ) => return l_iter. next ( ) ,
626
+ ( None , Some ( _) ) => return r_iter. next ( ) ,
627
+ ( None , None ) => return None ,
628
+ }
629
+ }
630
+ } ) ;
631
+
632
+ iter. collect ( )
633
+ }
634
+
539
635
#[ cfg( test) ]
540
636
mod tests {
541
637
use super :: * ;
@@ -1213,4 +1309,40 @@ mod tests {
1213
1309
]
1214
1310
) ;
1215
1311
}
1312
+
1313
+ #[ test]
1314
+ fn test_union ( ) {
1315
+ let selection = RowSelection :: from ( vec ! [ RowSelector :: select( 1048576 ) ] ) ;
1316
+ let result = selection. union ( & selection) ;
1317
+ assert_eq ! ( result, selection) ;
1318
+
1319
+ // NYNYY
1320
+ let a = RowSelection :: from ( vec ! [
1321
+ RowSelector :: skip( 10 ) ,
1322
+ RowSelector :: select( 10 ) ,
1323
+ RowSelector :: skip( 10 ) ,
1324
+ RowSelector :: select( 20 ) ,
1325
+ ] ) ;
1326
+
1327
+ // NNYYNYN
1328
+ let b = RowSelection :: from ( vec ! [
1329
+ RowSelector :: skip( 20 ) ,
1330
+ RowSelector :: select( 20 ) ,
1331
+ RowSelector :: skip( 10 ) ,
1332
+ RowSelector :: select( 10 ) ,
1333
+ RowSelector :: skip( 10 ) ,
1334
+ ] ) ;
1335
+
1336
+ let result = a. union ( & b) ;
1337
+
1338
+ // NYYYYYN
1339
+ assert_eq ! (
1340
+ result. iter( ) . collect:: <Vec <_>>( ) ,
1341
+ vec![
1342
+ & RowSelector :: skip( 10 ) ,
1343
+ & RowSelector :: select( 50 ) ,
1344
+ & RowSelector :: skip( 10 ) ,
1345
+ ]
1346
+ ) ;
1347
+ }
1216
1348
}
0 commit comments