|
1 |
| -use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip}; |
| 1 | +use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Zip}; |
| 2 | +use rayon::prelude::*; |
2 | 3 |
|
3 | 4 | pub fn summator(
|
4 | 5 | cov_samples: ArrayView2<'_, f64>,
|
@@ -38,34 +39,47 @@ pub fn summator_incompr(
|
38 | 39 | assert_eq!(cov_samples.dim().1, z1.dim());
|
39 | 40 | assert_eq!(cov_samples.dim().1, z2.dim());
|
40 | 41 |
|
41 |
| - let mut summed_modes = Array2::<f64>::zeros(pos.dim()); |
42 |
| - |
43 | 42 | // unit vector in x dir.
|
44 | 43 | let mut e1 = Array1::<f64>::zeros(pos.dim().0);
|
45 | 44 | e1[0] = 1.0;
|
46 | 45 |
|
47 |
| - Zip::from(cov_samples.columns()) |
48 |
| - .and(z1) |
49 |
| - .and(z2) |
50 |
| - .for_each(|cov_samples, z1, z2| { |
51 |
| - let k_2 = cov_samples[0] / cov_samples.dot(&cov_samples); |
52 |
| - |
53 |
| - Zip::from(pos.columns()) |
54 |
| - .and(summed_modes.columns_mut()) |
55 |
| - .par_for_each(|pos, mut summed_modes| { |
56 |
| - let phase = cov_samples.dot(&pos); |
57 |
| - let z12 = z1 * phase.cos() + z2 * phase.sin(); |
58 |
| - |
59 |
| - Zip::from(&mut summed_modes) |
60 |
| - .and(&e1) |
61 |
| - .and(cov_samples) |
62 |
| - .for_each(|sum, e1, cs| { |
63 |
| - *sum += (*e1 - cs * k_2) * z12; |
64 |
| - }); |
65 |
| - }); |
66 |
| - }); |
67 |
| - |
68 |
| - summed_modes |
| 46 | + // FIXME: Use Zip::from(cov_samples.columns()).and(z1).and(z1) after |
| 47 | + // https://github.com/rust-ndarray/ndarray/pull/1081 is merged. |
| 48 | + cov_samples |
| 49 | + .axis_iter(Axis(1)) |
| 50 | + .into_par_iter() |
| 51 | + .zip(z1.axis_iter(Axis(0))) |
| 52 | + .zip(z2.axis_iter(Axis(0))) |
| 53 | + .with_min_len(100) |
| 54 | + .fold( |
| 55 | + || Array2::<f64>::zeros(pos.dim()), |
| 56 | + |mut summed_modes, ((cov_samples, z1), z2)| { |
| 57 | + let k_2 = cov_samples[0] / cov_samples.dot(&cov_samples); |
| 58 | + let z1 = z1.into_scalar(); |
| 59 | + let z2 = z2.into_scalar(); |
| 60 | + |
| 61 | + Zip::from(pos.columns()) |
| 62 | + .and(summed_modes.columns_mut()) |
| 63 | + .par_for_each(|pos, mut summed_modes| { |
| 64 | + let phase = cov_samples.dot(&pos); |
| 65 | + let z12 = z1 * phase.cos() + z2 * phase.sin(); |
| 66 | + |
| 67 | + Zip::from(&mut summed_modes) |
| 68 | + .and(&e1) |
| 69 | + .and(cov_samples) |
| 70 | + .for_each(|sum, e1, cs| { |
| 71 | + *sum += (*e1 - cs * k_2) * z12; |
| 72 | + }); |
| 73 | + }); |
| 74 | + |
| 75 | + summed_modes |
| 76 | + }, |
| 77 | + ) |
| 78 | + .reduce_with(|mut lhs, rhs| { |
| 79 | + lhs += &rhs; |
| 80 | + lhs |
| 81 | + }) |
| 82 | + .unwrap() |
69 | 83 | }
|
70 | 84 |
|
71 | 85 | #[cfg(test)]
|
|
0 commit comments