Parallelize Rust function apply_phase_shift_in_place #230
Comments
Hey @kevinsung I came up with a possible implementation which most probably has worse performance per iteration, but which the borrow checker accepts. We can iterate over the rows in parallel and filter out the ones whose index is not in indices:
let mut vec = vec.as_array_mut();
let indices = indices.as_array();
let shape = vec.shape();
let dim_b = shape[1] as i32;
vec.axis_iter_mut(Axis(0)).into_par_iter()
.enumerate()
.filter(|(i, _)| indices.iter().any(|idx| idx == i))
.for_each(|(_, mut row)| {
match row.as_slice_mut() {
Some(row) => unsafe {
zscal(dim_b, phase, row, 1);
},
None => panic!(
"Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
),
}
})
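As an aside, an editorial sketch that was not proposed in the thread: the membership test in the filter above scans indices once per row, which is O(dim * n_indices) overall. Collecting the indices into a HashSet first makes the per-row test O(1). The helper name scale_selected_rows is hypothetical, plain element-wise scaling stands in for the unsafe zscal call, and the sketch assumes ndarray's rayon feature, as the snippet above already does.
use std::collections::HashSet;
use ndarray::{ArrayViewMut2, Axis};
use ndarray::parallel::prelude::*;
use num_complex::Complex64;
// Hypothetical helper, not part of ffsim: same filter idea, but with a
// HashSet so the per-row membership test is O(1) instead of a linear scan.
fn scale_selected_rows(mut vec: ArrayViewMut2<Complex64>, phase: Complex64, indices: &[usize]) {
    let index_set: HashSet<usize> = indices.iter().copied().collect();
    vec.axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .filter(|(i, _)| index_set.contains(i))
        // Plain element-wise scaling instead of the BLAS zscal call, so the
        // sketch needs no unsafe block.
        .for_each(|(_, mut row)| row.mapv_inplace(|x| x * phase));
}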
Is there a script which I can use to test the performance of this function? I am also wondering what your implementation idea was, since I had no issues with the unsafe block.
Hi @S-Erik, thank you for looking into this! Here is a script that you can adapt to test the performance:
import cmath
import numpy as np
from ffsim._lib import apply_phase_shift_in_place
rng = np.random.default_rng(1234)
dim = 100
n_indices = 50
mat = rng.standard_normal((dim, dim)).astype(complex)
phase_shift = cmath.rect(1, rng.uniform(0, np.pi))
indices = rng.choice(dim, size=n_indices, replace=False).astype(np.uint64)
apply_phase_shift_in_place(mat, phase_shift, indices)
To be honest I don't remember at this point. Maybe I was mistaken.
Thanks for the quick answer. I tested the performance of the current implementation against the version I suggested, slightly changed to make it more readable (indices.contains on a Vec of indices plus a separate call to map):
let mut vec = vec.as_array_mut();
let indices = indices.as_array().to_vec();
let shape = vec.shape();
let dim_b = shape[1] as i32;
vec.axis_iter_mut(Axis(0)).into_par_iter()
.enumerate()
.filter(|(i, _)| indices.contains(i))
.map(|(_, row)| row)
.for_each(|mut row| {
match row.as_slice_mut() {
Some(row) => unsafe {
zscal(dim_b, phase, row, 1);
},
None => panic!(
"Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
),
}
})
For this I timed different numbers of indices and different matrix sizes.
We see slight performance improvements for larger matrices and larger numbers of indices with my filter approach. The main challenge in modifying the current implementation to use concurrent calls to zscal is that each thread needs a mutable reference into the array, which the compiler rejects. I found a stackoverflow discussion about a very similar problem. There it was also suggested to use my filter approach. @kevinsung what are your thoughts on that? Currently, I am not eager to implement a concurrent version without rayon.
Other Approaches
I also tried different approaches:
let mut vec = vec.as_array_mut();
let indices = indices.as_array().to_vec();
let shape = vec.shape();
let dim_b = shape[1] as i32;
let rows_bool: Vec<bool> = (0..vec.len_of(Axis(0)))
.map(|i| indices.contains(&i))
.collect();
vec.axis_iter_mut(Axis(0)).into_par_iter().zip(rows_bool).for_each(|(mut row, bool_val)| {
if bool_val {
match row.as_slice_mut() {
Some(row) => unsafe {
zscal(dim_b, phase, row, 1);
},
None => panic!(
"Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
),
}
}
})
let mut vec = vec.as_array_mut();
let indices = indices.as_array();
let shape = vec.shape();
let dim_b = shape[1] as i32;
let indices_mapped: Vec<Array1<Complex64>> = indices.into_par_iter().map(|&str0| {
let mut target = vec.row(str0).to_owned();
match target.as_slice_mut() {
Some(target) => unsafe {
zscal(dim_b, phase, target, 1);
},
None => panic!(
"Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
),
};
target
}).collect();
indices
.into_iter()
.zip(indices_mapped)
.for_each(|(&str0, val)| {
let mut target = vec.row_mut(str0);
target.assign(&val);
})
each of which were slower than the filter approach. My CPU info (…)
@S-Erik Thank you very much for your investigation!
You are exactly right about this. I've updated the opening post to reflect this underlying issue more accurately.
Makes sense.
I think we should implement the threading manually rather than use your rayon-based filter approach.
I was able to implement a concurrent version with manual threading using unsafe pointer de-referencing (see code below in section "Rust code"). For this I currently hard-coded the number of threads. This approach uses a bit more memory because the indices are copied into a Vec and each chunk is copied into its thread. I benchmarked this manual-threading approach against the current implementation and the filter approach from above (see python benchmark script in "Python script" section). For this I timed different numbers of indices and different matrix sizes.
We see similar (small) improvements as with the filter approach relative to the current implementation. Maybe for larger matrices and more indices we would see larger performance improvements. Unfortunately I am not able to test larger matrices since I would run out of memory.

I have to say that I am quite disappointed by the manual-threading implementation. I expected a large improvement (at least 2x) compared to the current implementation. Probably the overhead of creating threads here is a big relative performance hit, since the calculations performed in the threads are quite fast (see the pool-based sketch after the scripts below). What are your thoughts @kevinsung? I will also try to implement the manual-threading approach for issue #229 now that I know how to do it. Maybe we get a bigger performance gain there.

Rust code
This code passes all tests when I run the test suite.
/// Apply a phase shift to slices of a state vector.
#[pyfunction]
pub fn apply_phase_shift_in_place(
mut vec: PyReadwriteArray2<Complex64>,
phase: Complex64,
indices: PyReadonlyArray1<usize>,
) {
let mut vec = vec.as_array_mut();
let indices = indices.as_array().to_vec();
let shape = vec.shape();
let dim_b = shape[1] as i32;
let num_threads = 5;
if indices.len() == 0 {
return;
}
let mut chunk_size = indices.len();
if indices.len() >= num_threads {
chunk_size = indices.len() / num_threads;
}
let mut handles = vec![];
// With "as usize" the address of the pointer is stored in a usize.
// With this address we can create the pointer again.
// This is necessary since a pointer is not Send but a usize is
let ptr_usize = vec.as_mut_ptr() as usize;
for chunk in indices.chunks(chunk_size) {
let chunk_owned = chunk.to_vec();
let handle = thread::spawn(move || unsafe {
for str0 in chunk_owned {
let row_ptr =
(ptr_usize as *mut Complex64).offset((str0 as isize) * dim_b as isize);
let target = std::slice::from_raw_parts_mut(row_ptr, dim_b as usize);
zscal(dim_b, phase, target, 1);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
Python script
import time
import cmath
import numpy as np
import matplotlib.pyplot as plt
from ffsim._lib import apply_phase_shift_in_place
rng = np.random.default_rng(1234)
n_lst = np.arange(50, 90, 10) * 100
dim_lst = np.arange(80, 121, 20) * 100
print(f"n_lst: {n_lst}")
print(f"dim_lst: {dim_lst}")
n = 100
mean_times = {} # key is n_indices
for i, n_indices in enumerate(n_lst):
print(f"Using {n_indices} indices ({i+1}/{len(n_lst)})...", end="\r")
mean_times[n_indices] = []
for dim in dim_lst:
mat = rng.standard_normal((dim, dim)).astype(
complex
) + 1.0j * rng.standard_normal((dim, dim)).astype(complex)
phase_shift = cmath.rect(1, rng.uniform(0, np.pi))
indices = rng.choice(dim, size=n_indices, replace=False).astype(np.uint64)
time_sum = 0
for _ in range(n):
start_time = time.perf_counter()
apply_phase_shift_in_place(mat, phase_shift, indices)
time_sum += time.perf_counter() - start_time
# print(f"Took {(time_sum)/n}s per loop.")
mean_times[n_indices].append(time_sum / n)
plt.figure()
for key, vals in mean_times.items():
plt.plot(dim_lst, vals, label=f"{key} indices", marker=".")
plt.ylabel(f"Mean runtime of {n} runs [s]")
plt.xlabel(f"Dimension of matrix")
plt.yscale("log")
plt.legend()
plt.grid()
plt.savefig("perf.png", bbox_inches="tight", dpi=128) |
This function here:
ffsim/src/gates/phase_shift.rs
Line 32 in 1e55524
A straightforward attempt doesn't pass the compiler, not simply due to the use of unsafe BLAS functions, but because each thread needs to have a mutable reference to a row of the array being modified. We know that no two threads will have access to the same row, but the compiler can't tell.
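To make this concrete, here is a minimal editorial illustration with a hypothetical helper name and plain scaling instead of zscal: the sequential loop compiles because only one mutable borrow of the array is live at a time, while the parallel variant shown in the comment is the kind of code the borrow checker rejects.
use ndarray::ArrayViewMut2;
use num_complex::Complex64;
// Illustration only: why the "straightforward attempt" fails.
fn scale_rows_sequential(mut vec: ArrayViewMut2<Complex64>, phase: Complex64, indices: &[usize]) {
    for &str0 in indices {
        // Fine: one mutable borrow of `vec` at a time.
        vec.row_mut(str0).mapv_inplace(|x| x * phase);
    }
    // The direct parallel version is rejected, roughly:
    //
    //     use rayon::prelude::*;
    //     indices.into_par_iter().for_each(|&str0| {
    //         vec.row_mut(str0).mapv_inplace(|x| x * phase);
    //     });
    //
    // row_mut needs `&mut vec` inside a closure that rayon shares across
    // threads, and the compiler cannot verify that the rows are disjoint,
    // which is what motivates the filter, mask, and raw-pointer approaches
    // discussed above.
}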