Skip to content

Commit 3cfa7cd

Browse files
authored
perf: Optimize rolling_quantile with varying window sizes (#22353)
1 parent e66d3aa commit 3cfa7cd

File tree

17 files changed

+157
-123
lines changed

17 files changed

+157
-123
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ serde = { version = "1.0.188", features = ["derive", "rc"] }
7676
serde_json = "1"
7777
simd-json = { version = "0.14", features = ["known-key"] }
7878
simdutf8 = "0.1.4"
79+
skiplist = "0.5.1"
7980
slotmap = "1"
8081
sqlparser = "0.53"
8182
stacker = "0.1"

crates/polars-compute/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ polars-utils = { workspace = true }
2323
rand = { workspace = true }
2424
ryu = { workspace = true, optional = true }
2525
serde = { workspace = true, optional = true }
26+
skiplist = { workspace = true }
2627
strength_reduce = { workspace = true }
2728
strum_macros = { workspace = true }
2829

crates/polars-compute/src/rolling/min_max.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<'a, T> for MinMax
6161
start: usize,
6262
end: usize,
6363
params: Option<RollingFnParams>,
64+
_window_size: Option<usize>,
6465
) -> Self {
6566
assert!(params.is_none());
6667
let mut slf = Self {
@@ -99,7 +100,13 @@ impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<'a, T> for MinMax
99100
}
100101

101102
impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNoNulls<'a, T> for MinMaxWindow<'a, T, P> {
102-
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
103+
fn new(
104+
slice: &'a [T],
105+
start: usize,
106+
end: usize,
107+
params: Option<RollingFnParams>,
108+
_window_size: Option<usize>,
109+
) -> Self {
103110
assert!(params.is_none());
104111
let mut slf = Self {
105112
values: slice,

crates/polars-compute/src/rolling/no_nulls/mean.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@ impl<
2020
+ Sub<Output = T>,
2121
> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
2222
{
23-
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
23+
fn new(
24+
slice: &'a [T],
25+
start: usize,
26+
end: usize,
27+
params: Option<RollingFnParams>,
28+
window_size: Option<usize>,
29+
) -> Self {
2430
Self {
25-
sum: SumWindow::new(slice, start, end, params),
31+
sum: SumWindow::new(slice, start, end, params, window_size),
2632
}
2733
}
2834

crates/polars-compute/src/rolling/no_nulls/mod.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@ pub use sum::*;
2020
use super::*;
2121

2222
pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
23-
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self;
23+
fn new(
24+
slice: &'a [T],
25+
start: usize,
26+
end: usize,
27+
params: Option<RollingFnParams>,
28+
window_size: Option<usize>,
29+
) -> Self;
2430

2531
/// Update and recompute the window
2632
///
@@ -44,7 +50,7 @@ where
4450
{
4551
let len = values.len();
4652
let (start, end) = det_offsets_fn(0, window_size, len);
47-
let mut agg_window = Agg::new(values, start, end, params);
53+
let mut agg_window = Agg::new(values, start, end, params, Some(window_size));
4854
if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
4955
if validity.iter().all(|x| !x) {
5056
return Ok(Box::new(PrimitiveArray::<T>::new_null(

crates/polars-compute/src/rolling/no_nulls/moment.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@ impl<T: ToPrimitive + Copy, M: StateUpdate> MomentWindow<'_, T, M> {
2626
impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive, M: StateUpdate>
2727
RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>
2828
{
29-
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
29+
fn new(
30+
slice: &'a [T],
31+
start: usize,
32+
end: usize,
33+
params: Option<RollingFnParams>,
34+
_window_size: Option<usize>,
35+
) -> Self {
3036
let mut out = Self {
3137
slice,
3238
moment: M::new(params),

crates/polars-compute/src/rolling/no_nulls/quantile.rs

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,28 @@ impl<
2828
+ Sub<Output = T>,
2929
> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
3030
{
31-
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
31+
fn new(
32+
slice: &'a [T],
33+
start: usize,
34+
end: usize,
35+
params: Option<RollingFnParams>,
36+
window_size: Option<usize>,
37+
) -> Self {
3238
let params = params.unwrap();
3339
let RollingFnParams::Quantile(params) = params else {
3440
unreachable!("expected Quantile params");
3541
};
3642

3743
Self {
38-
sorted: SortedBuf::new(slice, start, end),
44+
sorted: SortedBuf::new(slice, start, end, window_size),
3945
prob: params.prob,
4046
method: params.method,
4147
}
4248
}
4349

4450
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
45-
let vals = self.sorted.update(start, end);
46-
let length = vals.len();
51+
self.sorted.update(start, end);
52+
let length = self.sorted.len();
4753

4854
let idx = match self.method {
4955
Linear => {
@@ -54,11 +60,11 @@ impl<
5460
let float_idx_top = (length_f - 1.0) * self.prob;
5561
let top_idx = float_idx_top.ceil() as usize;
5662
return if idx == top_idx {
57-
Some(unsafe { *vals.get_unchecked(idx) })
63+
Some(self.sorted.get(idx))
5864
} else {
5965
let proportion = T::from(float_idx_top - idx as f64).unwrap();
60-
let vi = unsafe { *vals.get_unchecked(idx) };
61-
let vj = unsafe { *vals.get_unchecked(top_idx) };
66+
let vi = self.sorted.get(idx);
67+
let vj = self.sorted.get(top_idx);
6268

6369
Some(proportion * (vj - vi) + vi)
6470
};
@@ -70,14 +76,9 @@ impl<
7076

7177
let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
7278
return if top_idx == idx {
73-
// SAFETY:
74-
// we are in bounds
75-
Some(unsafe { *vals.get_unchecked(idx) })
79+
Some(self.sorted.get(idx))
7680
} else {
77-
// SAFETY:
78-
// we are in bounds
79-
let (mid, mid_plus_1) =
80-
unsafe { (*vals.get_unchecked(idx), *vals.get_unchecked(idx + 1)) };
81+
let (mid, mid_plus_1) = (self.sorted.get(idx), (self.sorted.get(idx + 1)));
8182

8283
Some((mid + mid_plus_1) / (T::one() + T::one()))
8384
};
@@ -94,9 +95,7 @@ impl<
9495
Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
9596
};
9697

97-
// SAFETY:
98-
// we are in bounds
99-
Some(unsafe { *vals.get_unchecked(idx) })
98+
Some(self.sorted.get(idx))
10099
}
101100
}
102101

crates/polars-compute/src/rolling/no_nulls/sum.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,13 @@ impl<
7575
+ Add<Output = T>,
7676
> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T>
7777
{
78-
fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
78+
fn new(
79+
slice: &'a [T],
80+
start: usize,
81+
end: usize,
82+
_params: Option<RollingFnParams>,
83+
_window_size: Option<usize>,
84+
) -> Self {
7985
let (sum, err) = sum_kahan(&slice[start..end]);
8086
Self {
8187
slice,

crates/polars-compute/src/rolling/nulls/mean.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ impl<
2323
start: usize,
2424
end: usize,
2525
params: Option<RollingFnParams>,
26+
window_size: Option<usize>,
2627
) -> Self {
2728
Self {
28-
sum: SumWindow::new(slice, validity, start, end, params),
29+
sum: SumWindow::new(slice, validity, start, end, params, window_size),
2930
}
3031
}
3132

crates/polars-compute/src/rolling/nulls/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub trait RollingAggWindowNulls<'a, T: NativeType> {
2222
start: usize,
2323
end: usize,
2424
params: Option<RollingFnParams>,
25+
window_size: Option<usize>,
2526
) -> Self;
2627

2728
/// # Safety
@@ -48,7 +49,8 @@ where
4849
let len = values.len();
4950
let (start, end) = det_offsets_fn(0, window_size, len);
5051
// SAFETY: we are in bounds
51-
let mut agg_window = unsafe { Agg::new(values, validity, start, end, params) };
52+
let mut agg_window =
53+
unsafe { Agg::new(values, validity, start, end, params, Some(window_size)) };
5254

5355
let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
5456
.unwrap_or_else(|| {

crates/polars-compute/src/rolling/nulls/moment.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ impl<'a, T: NativeType + ToPrimitive + IsFloat + FromPrimitive, M: StateUpdate>
4545
start: usize,
4646
end: usize,
4747
params: Option<RollingFnParams>,
48+
_window_size: Option<usize>,
4849
) -> Self {
4950
let mut out = Self {
5051
slice,

crates/polars-compute/src/rolling/nulls/quantile.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,28 @@ impl<
3333
start: usize,
3434
end: usize,
3535
params: Option<RollingFnParams>,
36+
window_size: Option<usize>,
3637
) -> Self {
3738
let params = params.unwrap();
3839
let RollingFnParams::Quantile(params) = params else {
3940
unreachable!("expected Quantile params");
4041
};
4142
Self {
42-
sorted: SortedBufNulls::new(slice, validity, start, end),
43+
sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
4344
prob: params.prob,
4445
method: params.method,
4546
}
4647
}
4748

4849
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
49-
let (values, null_count) = self.sorted.update(start, end);
50+
let null_count = self.sorted.update(start, end);
51+
let mut length = self.sorted.len();
5052
// The min_periods issue will be taken care of when actually rolling
51-
if null_count == values.len() {
53+
if null_count == length {
5254
return None;
5355
}
5456
// Nulls are guaranteed to be at the front
55-
let values = &values[null_count..];
56-
let length = values.len();
57-
57+
length -= null_count;
5858
let mut idx = match self.method {
5959
QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
6060
QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
@@ -73,7 +73,8 @@ impl<
7373
QuantileMethod::Midpoint => {
7474
let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
7575
Some(
76-
(values.get_unchecked(idx).unwrap() + values.get_unchecked(top_idx).unwrap())
76+
(self.sorted.get(idx + null_count).unwrap()
77+
+ self.sorted.get(top_idx + null_count).unwrap())
7778
/ T::from::<f64>(2.0f64).unwrap(),
7879
)
7980
},
@@ -82,18 +83,18 @@ impl<
8283
let top_idx = f64::ceil(float_idx) as usize;
8384

8485
if top_idx == idx {
85-
Some(values.get_unchecked(idx).unwrap())
86+
Some(self.sorted.get(idx + null_count).unwrap())
8687
} else {
8788
let proportion = T::from(float_idx - idx as f64).unwrap();
8889
Some(
8990
proportion
90-
* (values.get_unchecked(top_idx).unwrap()
91-
- values.get_unchecked(idx).unwrap())
92-
+ values.get_unchecked(idx).unwrap(),
91+
* (self.sorted.get(top_idx + null_count).unwrap()
92+
- self.sorted.get(idx + null_count).unwrap())
93+
+ self.sorted.get(idx + null_count).unwrap(),
9394
)
9495
}
9596
},
96-
_ => Some(values.get_unchecked(idx).unwrap()),
97+
_ => Some(self.sorted.get(idx + null_count).unwrap()),
9798
}
9899
}
99100

crates/polars-compute/src/rolling/nulls/sum.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ impl<'a, T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + AddAssign
6969
start: usize,
7070
end: usize,
7171
_params: Option<RollingFnParams>,
72+
_window_size: Option<usize>,
7273
) -> Self {
7374
let mut out = Self {
7475
slice,

0 commit comments

Comments
 (0)