Skip to content

Commit ac2fa4d

Browse files
authored
Merge pull request #925 from jakeKonrad/master
Added A `par_chunk_by` method
2 parents 9ee7649 + e37ec9e commit ac2fa4d

File tree

5 files changed

+342
-0
lines changed

5 files changed

+342
-0
lines changed

src/slice/chunk_by.rs

+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
use crate::iter::plumbing::*;
2+
use crate::iter::*;
3+
use std::marker::PhantomData;
4+
use std::{fmt, mem};
5+
6+
trait ChunkBySlice<T>: AsRef<[T]> + Default + Send {
7+
fn split(self, index: usize) -> (Self, Self);
8+
9+
fn find(&self, pred: &impl Fn(&T, &T) -> bool, start: usize, end: usize) -> Option<usize> {
10+
self.as_ref()[start..end]
11+
.windows(2)
12+
.position(move |w| !pred(&w[0], &w[1]))
13+
.map(|i| i + 1)
14+
}
15+
16+
fn rfind(&self, pred: &impl Fn(&T, &T) -> bool, end: usize) -> Option<usize> {
17+
self.as_ref()[..end]
18+
.windows(2)
19+
.rposition(move |w| !pred(&w[0], &w[1]))
20+
.map(|i| i + 1)
21+
}
22+
}
23+
24+
impl<T: Sync> ChunkBySlice<T> for &[T] {
25+
fn split(self, index: usize) -> (Self, Self) {
26+
self.split_at(index)
27+
}
28+
}
29+
30+
impl<T: Send> ChunkBySlice<T> for &mut [T] {
31+
fn split(self, index: usize) -> (Self, Self) {
32+
self.split_at_mut(index)
33+
}
34+
}
35+
36+
struct ChunkByProducer<'p, T, Slice, Pred> {
37+
slice: Slice,
38+
pred: &'p Pred,
39+
tail: usize,
40+
marker: PhantomData<fn(&T)>,
41+
}
42+
43+
// Note: this implementation is very similar to `SplitProducer`.
44+
impl<T, Slice, Pred> UnindexedProducer for ChunkByProducer<'_, T, Slice, Pred>
45+
where
46+
Slice: ChunkBySlice<T>,
47+
Pred: Fn(&T, &T) -> bool + Send + Sync,
48+
{
49+
type Item = Slice;
50+
51+
fn split(self) -> (Self, Option<Self>) {
52+
if self.tail < 2 {
53+
return (Self { tail: 0, ..self }, None);
54+
}
55+
56+
// Look forward for the separator, and failing that look backward.
57+
let mid = self.tail / 2;
58+
let index = match self.slice.find(self.pred, mid, self.tail) {
59+
Some(i) => Some(mid + i),
60+
None => self.slice.rfind(self.pred, mid + 1),
61+
};
62+
63+
if let Some(index) = index {
64+
let (left, right) = self.slice.split(index);
65+
66+
let (left_tail, right_tail) = if index <= mid {
67+
// If we scanned backwards to find the separator, everything in
68+
// the right side is exhausted, with no separators left to find.
69+
(index, 0)
70+
} else {
71+
(mid + 1, self.tail - index)
72+
};
73+
74+
// Create the left split before the separator.
75+
let left = Self {
76+
slice: left,
77+
tail: left_tail,
78+
..self
79+
};
80+
81+
// Create the right split following the separator.
82+
let right = Self {
83+
slice: right,
84+
tail: right_tail,
85+
..self
86+
};
87+
88+
(left, Some(right))
89+
} else {
90+
// The search is exhausted, no more separators...
91+
(Self { tail: 0, ..self }, None)
92+
}
93+
}
94+
95+
fn fold_with<F>(self, mut folder: F) -> F
96+
where
97+
F: Folder<Self::Item>,
98+
{
99+
let Self {
100+
slice, pred, tail, ..
101+
} = self;
102+
103+
let (slice, tail) = if tail == slice.as_ref().len() {
104+
// No tail section, so just let `consume_iter` do it all.
105+
(Some(slice), None)
106+
} else if let Some(index) = slice.rfind(pred, tail) {
107+
// We found the last separator to complete the tail, so
108+
// end with that slice after `consume_iter` finds the rest.
109+
let (left, right) = slice.split(index);
110+
(Some(left), Some(right))
111+
} else {
112+
// We know there are no separators at all, so it's all "tail".
113+
(None, Some(slice))
114+
};
115+
116+
if let Some(mut slice) = slice {
117+
// TODO (MSRV 1.77) use either:
118+
// folder.consume_iter(slice.chunk_by(pred))
119+
// folder.consume_iter(slice.chunk_by_mut(pred))
120+
121+
folder = folder.consume_iter(std::iter::from_fn(move || {
122+
let len = slice.as_ref().len();
123+
if len > 0 {
124+
let i = slice.find(pred, 0, len).unwrap_or(len);
125+
let (head, tail) = mem::take(&mut slice).split(i);
126+
slice = tail;
127+
Some(head)
128+
} else {
129+
None
130+
}
131+
}));
132+
}
133+
134+
if let Some(tail) = tail {
135+
folder = folder.consume(tail);
136+
}
137+
138+
folder
139+
}
140+
}
141+
142+
/// Parallel iterator over slice in (non-overlapping) chunks separated by a predicate.
143+
///
144+
/// This struct is created by the [`par_chunk_by`] method on `&[T]`.
145+
///
146+
/// [`par_chunk_by`]: trait.ParallelSlice.html#method.par_chunk_by
147+
pub struct ChunkBy<'data, T, P> {
148+
pred: P,
149+
slice: &'data [T],
150+
}
151+
152+
impl<'data, T, P: Clone> Clone for ChunkBy<'data, T, P> {
153+
fn clone(&self) -> Self {
154+
ChunkBy {
155+
pred: self.pred.clone(),
156+
slice: self.slice,
157+
}
158+
}
159+
}
160+
161+
impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkBy<'data, T, P> {
162+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163+
f.debug_struct("ChunkBy")
164+
.field("slice", &self.slice)
165+
.finish()
166+
}
167+
}
168+
169+
impl<'data, T, P> ChunkBy<'data, T, P> {
170+
pub(super) fn new(slice: &'data [T], pred: P) -> Self {
171+
Self { pred, slice }
172+
}
173+
}
174+
175+
impl<'data, T, P> ParallelIterator for ChunkBy<'data, T, P>
176+
where
177+
T: Sync,
178+
P: Fn(&T, &T) -> bool + Send + Sync,
179+
{
180+
type Item = &'data [T];
181+
182+
fn drive_unindexed<C>(self, consumer: C) -> C::Result
183+
where
184+
C: UnindexedConsumer<Self::Item>,
185+
{
186+
bridge_unindexed(
187+
ChunkByProducer {
188+
tail: self.slice.len(),
189+
slice: self.slice,
190+
pred: &self.pred,
191+
marker: PhantomData,
192+
},
193+
consumer,
194+
)
195+
}
196+
}
197+
198+
/// Parallel iterator over slice in (non-overlapping) mutable chunks
199+
/// separated by a predicate.
200+
///
201+
/// This struct is created by the [`par_chunk_by_mut`] method on `&mut [T]`.
202+
///
203+
/// [`par_chunk_by_mut`]: trait.ParallelSliceMut.html#method.par_chunk_by_mut
204+
pub struct ChunkByMut<'data, T, P> {
205+
pred: P,
206+
slice: &'data mut [T],
207+
}
208+
209+
impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkByMut<'data, T, P> {
210+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211+
f.debug_struct("ChunkByMut")
212+
.field("slice", &self.slice)
213+
.finish()
214+
}
215+
}
216+
217+
impl<'data, T, P> ChunkByMut<'data, T, P> {
218+
pub(super) fn new(slice: &'data mut [T], pred: P) -> Self {
219+
Self { pred, slice }
220+
}
221+
}
222+
223+
impl<'data, T, P> ParallelIterator for ChunkByMut<'data, T, P>
224+
where
225+
T: Send,
226+
P: Fn(&T, &T) -> bool + Send + Sync,
227+
{
228+
type Item = &'data mut [T];
229+
230+
fn drive_unindexed<C>(self, consumer: C) -> C::Result
231+
where
232+
C: UnindexedConsumer<Self::Item>,
233+
{
234+
bridge_unindexed(
235+
ChunkByProducer {
236+
tail: self.slice.len(),
237+
slice: self.slice,
238+
pred: &self.pred,
239+
marker: PhantomData,
240+
},
241+
consumer,
242+
)
243+
}
244+
}

src/slice/mod.rs

+49
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//!
66
//! [std::slice]: https://doc.rust-lang.org/stable/std/slice/
77
8+
mod chunk_by;
89
mod chunks;
910
mod mergesort;
1011
mod quicksort;
@@ -22,6 +23,7 @@ use std::cmp::Ordering;
2223
use std::fmt::{self, Debug};
2324
use std::mem;
2425

26+
pub use self::chunk_by::{ChunkBy, ChunkByMut};
2527
pub use self::chunks::{Chunks, ChunksExact, ChunksExactMut, ChunksMut};
2628
pub use self::rchunks::{RChunks, RChunksExact, RChunksExactMut, RChunksMut};
2729

@@ -173,6 +175,29 @@ pub trait ParallelSlice<T: Sync> {
173175
assert!(chunk_size != 0, "chunk_size must not be zero");
174176
RChunksExact::new(chunk_size, self.as_parallel_slice())
175177
}
178+
179+
/// Returns a parallel iterator over the slice producing non-overlapping runs
180+
/// of elements using the predicate to separate them.
181+
///
182+
/// The predicate is called on two elements following themselves,
183+
/// it means the predicate is called on `slice[0]` and `slice[1]`
184+
/// then on `slice[1]` and `slice[2]` and so on.
185+
///
186+
/// # Examples
187+
///
188+
/// ```
189+
/// use rayon::prelude::*;
190+
/// let chunks: Vec<_> = [1, 2, 2, 3, 3, 3].par_chunk_by(|&x, &y| x == y).collect();
191+
/// assert_eq!(chunks[0], &[1]);
192+
/// assert_eq!(chunks[1], &[2, 2]);
193+
/// assert_eq!(chunks[2], &[3, 3, 3]);
194+
/// ```
195+
fn par_chunk_by<F>(&self, pred: F) -> ChunkBy<'_, T, F>
196+
where
197+
F: Fn(&T, &T) -> bool + Send + Sync,
198+
{
199+
ChunkBy::new(self.as_parallel_slice(), pred)
200+
}
176201
}
177202

178203
impl<T: Sync> ParallelSlice<T> for [T] {
@@ -704,6 +729,30 @@ pub trait ParallelSliceMut<T: Send> {
704729
{
705730
par_quicksort(self.as_parallel_slice_mut(), |a, b| f(a).lt(&f(b)));
706731
}
732+
733+
/// Returns a parallel iterator over the slice producing non-overlapping mutable
734+
/// runs of elements using the predicate to separate them.
735+
///
736+
/// The predicate is called on two elements following themselves,
737+
/// it means the predicate is called on `slice[0]` and `slice[1]`
738+
/// then on `slice[1]` and `slice[2]` and so on.
739+
///
740+
/// # Examples
741+
///
742+
/// ```
743+
/// use rayon::prelude::*;
744+
/// let mut xs = [1, 2, 2, 3, 3, 3];
745+
/// let chunks: Vec<_> = xs.par_chunk_by_mut(|&x, &y| x == y).collect();
746+
/// assert_eq!(chunks[0], &mut [1]);
747+
/// assert_eq!(chunks[1], &mut [2, 2]);
748+
/// assert_eq!(chunks[2], &mut [3, 3, 3]);
749+
/// ```
750+
fn par_chunk_by_mut<F>(&mut self, pred: F) -> ChunkByMut<'_, T, F>
751+
where
752+
F: Fn(&T, &T) -> bool + Send + Sync,
753+
{
754+
ChunkByMut::new(self.as_parallel_slice_mut(), pred)
755+
}
707756
}
708757

709758
impl<T: Send> ParallelSliceMut<T> for [T] {

src/slice/test.rs

+46
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use rand::distributions::Uniform;
55
use rand::seq::SliceRandom;
66
use rand::{thread_rng, Rng};
77
use std::cmp::Ordering::{Equal, Greater, Less};
8+
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
89

910
macro_rules! sort {
1011
($f:ident, $name:ident) => {
@@ -168,3 +169,48 @@ fn test_par_rchunks_exact_mut_remainder() {
168169
assert_eq!(c.take_remainder(), &[]);
169170
assert_eq!(c.len(), 2);
170171
}
172+
173+
#[test]
174+
fn slice_chunk_by() {
175+
let v: Vec<_> = (0..1000).collect();
176+
assert_eq!(v[..0].par_chunk_by(|_, _| todo!()).count(), 0);
177+
assert_eq!(v[..1].par_chunk_by(|_, _| todo!()).count(), 1);
178+
assert_eq!(v[..2].par_chunk_by(|_, _| true).count(), 1);
179+
assert_eq!(v[..2].par_chunk_by(|_, _| false).count(), 2);
180+
181+
let count = AtomicUsize::new(0);
182+
let par: Vec<_> = v
183+
.par_chunk_by(|x, y| {
184+
count.fetch_add(1, Relaxed);
185+
(x % 10 < 3) == (y % 10 < 3)
186+
})
187+
.collect();
188+
assert_eq!(count.into_inner(), v.len() - 1);
189+
190+
let seq: Vec<_> = v.chunk_by(|x, y| (x % 10 < 3) == (y % 10 < 3)).collect();
191+
assert_eq!(par, seq);
192+
}
193+
194+
#[test]
195+
fn slice_chunk_by_mut() {
196+
let mut v: Vec<_> = (0..1000).collect();
197+
assert_eq!(v[..0].par_chunk_by_mut(|_, _| todo!()).count(), 0);
198+
assert_eq!(v[..1].par_chunk_by_mut(|_, _| todo!()).count(), 1);
199+
assert_eq!(v[..2].par_chunk_by_mut(|_, _| true).count(), 1);
200+
assert_eq!(v[..2].par_chunk_by_mut(|_, _| false).count(), 2);
201+
202+
let mut v2 = v.clone();
203+
let count = AtomicUsize::new(0);
204+
let par: Vec<_> = v
205+
.par_chunk_by_mut(|x, y| {
206+
count.fetch_add(1, Relaxed);
207+
(x % 10 < 3) == (y % 10 < 3)
208+
})
209+
.collect();
210+
assert_eq!(count.into_inner(), v2.len() - 1);
211+
212+
let seq: Vec<_> = v2
213+
.chunk_by_mut(|x, y| (x % 10 < 3) == (y % 10 < 3))
214+
.collect();
215+
assert_eq!(par, seq);
216+
}

tests/clones.rs

+1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ fn clone_str() {
109109
fn clone_vec() {
110110
let v: Vec<_> = (0..1000).collect();
111111
check(v.par_iter());
112+
check(v.par_chunk_by(i32::eq));
112113
check(v.par_chunks(42));
113114
check(v.par_chunks_exact(42));
114115
check(v.par_rchunks(42));

tests/debug.rs

+2
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ fn debug_vec() {
121121
let mut v: Vec<_> = (0..10).collect();
122122
check(v.par_iter());
123123
check(v.par_iter_mut());
124+
check(v.par_chunk_by(i32::eq));
125+
check(v.par_chunk_by_mut(i32::eq));
124126
check(v.par_chunks(42));
125127
check(v.par_chunks_exact(42));
126128
check(v.par_chunks_mut(42));

0 commit comments

Comments
 (0)