Skip to content

Commit fc6d568

Browse files
Speed up functions by eliminating needless copies and asarrays, jit-ing functions
1 parent 9d06733 commit fc6d568

File tree

1 file changed

+111
-27
lines changed

1 file changed

+111
-27
lines changed

stingray/fourier.py

Lines changed: 111 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
sum_if_not_none_or_initialize,
2222
fix_segment_size_to_integer_samples,
2323
rebin_data,
24+
njit,
2425
)
2526

2627

@@ -1359,7 +1360,66 @@ def fun(times, gti, segment_size):
13591360
yield cts
13601361

13611362

1362-
def center_pds_on_f(freqs, powers, f0, nbins=100):
1363+
@njit()
def _safe_array_slice_indices(input_size, input_center_idx, nbins):
    """Calculate the indices needed to extract an n-bin slice of an array, centered at an index.

    Let us say we have an array of size ``input_size`` and we want to extract a slice of
    ``nbins`` centered at index ``input_center_idx``. We should be robust when the slice goes
    beyond the edges of the input array, possibly leaving missing values in the output array.
    This function calculates the indices needed to extract the slice from the input array, and
    the indices in the output array that will be filled.

    In the most common case, the slice is entirely contained within the input array, so that the
    output slice will just be ``[0:nbins]`` and the input slice
    ``[input_center_idx - nbins // 2: input_center_idx - nbins // 2 + nbins]``.

    The two returned ranges always have the same length, so they can be used both for
    slice assignment and for explicit element-by-element loops.

    Parameters
    ----------
    input_size : int
        Input array size
    input_center_idx : int
        Index of the center of the slice
    nbins : int
        Number of bins to extract

    Returns
    -------
    input_slice : list
        ``[start, stop]`` indices to extract the slice from the input array
    output_slice : list
        ``[start, stop]`` indices to fill the output array

    Examples
    --------
    >>> _safe_array_slice_indices(input_size=10, input_center_idx=5, nbins=3)
    ([4, 7], [0, 3])

    If the slice goes beyond the right edge: the output slice will only cover
    the first two bins of the output array, and up to the end of the input array.
    >>> _safe_array_slice_indices(input_size=6, input_center_idx=5, nbins=3)
    ([4, 6], [0, 2])

    If the slice goes beyond the left edge: the first bin of the output array is
    left unfilled.
    >>> _safe_array_slice_indices(input_size=10, input_center_idx=1, nbins=4)
    ([0, 3], [1, 4])

    If the input array is shorter than the requested slice on both sides, the
    input slice is clipped on both edges.
    >>> _safe_array_slice_indices(input_size=3, input_center_idx=1, nbins=10)
    ([0, 3], [4, 7])

    """

    # Window that we would like to read, in input-array coordinates. It may
    # start before 0 and/or end after ``input_size``.
    minbin = input_center_idx - nbins // 2
    maxbin = minbin + nbins

    # Clip the window to the valid range of the input array. Clipping *both*
    # edges matters: a window can overflow on the left and on the right at the
    # same time when the input array is shorter than ``nbins``.
    input_start = max(minbin, 0)
    input_stop = min(maxbin, input_size)
    # If the window does not overlap the array at all, return empty ranges
    # instead of ranges with negative length.
    input_stop = max(input_stop, input_start)

    # The output array covers [minbin, maxbin), so the output indices are the
    # clipped input indices shifted by ``-minbin``.
    input_slice = [input_start, input_stop]
    output_slice = [input_start - minbin, input_stop - minbin]

    return input_slice, output_slice
1419+
1420+
1421+
@njit()
1422+
def extract_pds_slice_around_freq(freqs, powers, f0, nbins=100):
13631423
"""Extract a slice of PDS around a given frequency.
13641424
13651425
This function extracts a slice of the power spectrum around a given frequency.
@@ -1385,27 +1445,57 @@ def center_pds_on_f(freqs, powers, f0, nbins=100):
13851445
>>> freqs = np.arange(1, 100) * 0.1
13861446
>>> powers = 10 / freqs
13871447
>>> f0 = 0.3
1388-
>>> p = center_pds_on_f(freqs, powers, f0)
1448+
>>> p = extract_pds_slice_around_freq(freqs, powers, f0)
13891449
>>> assert np.isnan(p[0])
13901450
>>> assert not np.any(np.isnan(p[48:]))
13911451
"""
13921452
powers = np.asarray(powers)
13931453
chunk = np.zeros(nbins) + np.nan
1394-
fchunk = np.zeros(nbins)
1454+
# fchunk = np.zeros(nbins)
13951455

13961456
start_f_idx = np.searchsorted(freqs, f0)
13971457

1398-
minbin = start_f_idx - nbins // 2
1399-
maxbin = minbin + nbins
1458+
input_slice, output_slice = _safe_array_slice_indices(powers.size, start_f_idx, nbins)
1459+
chunk[output_slice[0] : output_slice[1]] = powers[input_slice[0] : input_slice[1]]
1460+
return chunk
14001461

1401-
if minbin < 0:
1402-
chunk[-minbin : min(nbins, powers.size - minbin)] = powers[: minbin + nbins]
1403-
elif maxbin > powers.size:
1404-
chunk[: nbins - (maxbin - powers.size)] = powers[minbin:]
1405-
else:
1406-
chunk[:] = powers[minbin:maxbin]
14071462

1408-
return chunk
1463+
@njit()
def _shift_and_average_core(input_array_list, weight_list, center_indices, nbins):
    """Core function to shift_and_add, JIT-compiled for your convenience.

    For each input array, an ``nbins``-wide window around the corresponding
    central index is multiplied by the array's weight and accumulated into the
    output. Parts of the window falling outside the input array do not
    contribute to the sum nor to the weights of the corresponding output bins.

    Parameters
    ----------
    input_array_list : list of np.array
        List of input arrays
    weight_list : list of float
        List of weights for each input array
    center_indices : list of int
        Central indices of the slice of each input array to be summed
    nbins : int
        Number of bins to extract around the central index of each input array

    Returns
    -------
    output_array : np.array
        Average of the input arrays, weighted by the weights. Bins that received
        no contribution have zero total weight, so the division is 0/0 there
        (NaN under NumPy floating-point semantics).
    sum_of_weights : np.array
        Sum of the weights at each output bin
    """
    output_array = np.zeros(nbins)
    sum_of_weights = np.zeros(nbins)
    for idx, array, weight in zip(center_indices, input_array_list, weight_list):
        # Use the size of *this* array rather than the first one's, so that
        # arrays of different lengths cannot be indexed out of bounds.
        input_slice, output_slice = _safe_array_slice_indices(array.size, idx, nbins)

        # Only loop over bins that exist on both sides. JIT-compiled code does
        # not check array bounds by default, so reading past the end of
        # ``array`` would silently pick up garbage.
        n_valid = min(input_slice[1] - input_slice[0], output_slice[1] - output_slice[0])
        n_valid = min(n_valid, array.size - input_slice[0])

        for i in range(n_valid):
            output_array[output_slice[0] + i] += array[input_slice[0] + i] * weight
            sum_of_weights[output_slice[0] + i] += weight

    output_array = output_array / sum_of_weights

    return output_array, sum_of_weights
14091499

14101500

14111501
def shift_and_add(freqs, power_list, f0_list, nbins=100, rebin=None, df=None, M=None):
@@ -1442,9 +1532,7 @@ def shift_and_add(freqs, power_list, f0_list, nbins=100, rebin=None, df=None, M=
14421532
Returns
14431533
-------
14441534
f : np.array
1445-
Array of output frequencies. This will be centered on the mean of the
1446-
input ``f0_list``, and have the same frequency resolution as the original
1447-
frequency array.
1535+
Array of output frequencies
14481536
p : np.array
14491537
Array of output powers
14501538
n : np.array
@@ -1460,32 +1548,28 @@ def shift_and_add(freqs, power_list, f0_list, nbins=100, rebin=None, df=None, M=
14601548
>>> assert np.array_equal(p, [2. , 2. , 5. , 2. , 1.5])
14611549
>>> assert np.allclose(f, [0.05, 0.15, 0.25, 0.35, 0.45])
14621550
"""
1463-
final_powers = np.zeros(nbins)
1551+
1552+
# Check if the input list of power contains numpy arrays
1553+
if not hasattr(power_list[0], "size"):
1554+
power_list = np.asarray(power_list)
1555+
# input_size = np.size(power_list[0])
14641556
freqs = np.asarray(freqs)
14651557

1466-
mid_idx = np.searchsorted(freqs, np.mean(f0_list))
1558+
# mid_idx = np.searchsorted(freqs, np.mean(f0_list))
14671559
if M is None:
14681560
M = 1
14691561
if not isinstance(M, Iterable):
14701562
M = np.ones(len(power_list)) * M
14711563

1472-
count = np.zeros(nbins)
1473-
for f0, powers, m in zip(f0_list, power_list, M):
1474-
idx = np.searchsorted(freqs, f0_list)
1564+
center_f_indices = np.searchsorted(freqs, f0_list)
14751565

1476-
powers = np.asarray(powers) * m
1477-
new_power = center_pds_on_f(freqs, powers, f0, nbins=nbins)
1478-
bad = np.isnan(new_power)
1479-
new_power[bad] = 0.0
1480-
final_powers += new_power
1481-
count += np.array(~bad, dtype=int) * m
1566+
final_powers, count = _shift_and_average_core(power_list, M, center_f_indices, nbins)
14821567

14831568
if df is None:
14841569
df = freqs[1] - freqs[0]
14851570

14861571
final_freqs = np.arange(-nbins // 2, nbins // 2 + 1)[:nbins] * df
14871572
final_freqs = final_freqs - (final_freqs[0] + final_freqs[-1]) / 2 + np.mean(f0_list)
1488-
final_powers = final_powers / count
14891573

14901574
if rebin is not None:
14911575
_, count, _, _ = rebin_data(final_freqs, count, rebin * df)

0 commit comments

Comments
 (0)