Skip to content

Commit 8561a04

Browse files
adamreicholdLSchueler
authored andcommitted
Redo parallel implementation of summator_incompr
This is a speed-up due to the call to `with_min_len` which however requires indexed parallel iterators which `ndarray` provides only via `axis_iter` and `axis_chunk_iter`, so this workaround is necessary until [1] is merged. [1] rust-ndarray/ndarray#1081
1 parent ca40a5e commit 8561a04

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

src/field.rs

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip};
1+
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Zip};
2+
use rayon::prelude::*;
23

34
pub fn summator(
45
cov_samples: ArrayView2<'_, f64>,
@@ -38,34 +39,47 @@ pub fn summator_incompr(
3839
assert_eq!(cov_samples.dim().1, z1.dim());
3940
assert_eq!(cov_samples.dim().1, z2.dim());
4041

41-
let mut summed_modes = Array2::<f64>::zeros(pos.dim());
42-
4342
// unit vector in x dir.
4443
let mut e1 = Array1::<f64>::zeros(pos.dim().0);
4544
e1[0] = 1.0;
4645

47-
Zip::from(cov_samples.columns())
48-
.and(z1)
49-
.and(z2)
50-
.for_each(|cov_samples, z1, z2| {
51-
let k_2 = cov_samples[0] / cov_samples.dot(&cov_samples);
52-
53-
Zip::from(pos.columns())
54-
.and(summed_modes.columns_mut())
55-
.par_for_each(|pos, mut summed_modes| {
56-
let phase = cov_samples.dot(&pos);
57-
let z12 = z1 * phase.cos() + z2 * phase.sin();
58-
59-
Zip::from(&mut summed_modes)
60-
.and(&e1)
61-
.and(cov_samples)
62-
.for_each(|sum, e1, cs| {
63-
*sum += (*e1 - cs * k_2) * z12;
64-
});
65-
});
66-
});
67-
68-
summed_modes
46+
// FIXME: Use Zip::from(cov_samples.columns()).and(z1).and(z1) after
47+
// https://github.com/rust-ndarray/ndarray/pull/1081 is merged.
48+
cov_samples
49+
.axis_iter(Axis(1))
50+
.into_par_iter()
51+
.zip(z1.axis_iter(Axis(0)))
52+
.zip(z2.axis_iter(Axis(0)))
53+
.with_min_len(100)
54+
.fold(
55+
|| Array2::<f64>::zeros(pos.dim()),
56+
|mut summed_modes, ((cov_samples, z1), z2)| {
57+
let k_2 = cov_samples[0] / cov_samples.dot(&cov_samples);
58+
let z1 = z1.into_scalar();
59+
let z2 = z2.into_scalar();
60+
61+
Zip::from(pos.columns())
62+
.and(summed_modes.columns_mut())
63+
.par_for_each(|pos, mut summed_modes| {
64+
let phase = cov_samples.dot(&pos);
65+
let z12 = z1 * phase.cos() + z2 * phase.sin();
66+
67+
Zip::from(&mut summed_modes)
68+
.and(&e1)
69+
.and(cov_samples)
70+
.for_each(|sum, e1, cs| {
71+
*sum += (*e1 - cs * k_2) * z12;
72+
});
73+
});
74+
75+
summed_modes
76+
},
77+
)
78+
.reduce_with(|mut lhs, rhs| {
79+
lhs += &rhs;
80+
lhs
81+
})
82+
.unwrap()
6983
}
7084

7185
#[cfg(test)]

0 commit comments

Comments
 (0)