21
21
sum_if_not_none_or_initialize ,
22
22
fix_segment_size_to_integer_samples ,
23
23
rebin_data ,
24
+ njit ,
24
25
)
25
26
26
27
@@ -1359,7 +1360,66 @@ def fun(times, gti, segment_size):
1359
1360
yield cts
1360
1361
1361
1362
1362
- def center_pds_on_f (freqs , powers , f0 , nbins = 100 ):
1363
+ @njit ()
1364
+ def _safe_array_slice_indices (input_size , input_center_idx , nbins ):
1365
+ """Calculate the indices needed to extract a n-bin slice of an array, centered at an index.
1366
+
1367
+ Let us say we have an array of size ``input_size`` and we want to extract a slice of
1368
+ ``nbins`` centered at index ``input_center_idx``. We should be robust when the slice goes
1369
+ beyond the edges of the input array, possibly leaving missing values in the output array.
1370
+ This function calculates the indices needed to extract the slice from the input array, and
1371
+ the indices in the output array that will be filled.
1372
+
1373
+ In the most common case, the slice is entirely contained within the input array, so that the
1374
+ output slice will just be ``[0:nbins]`` and the input slice
1375
+ ``[input_center_idx - nbins // 2: input_center_idx - nbins // 2 + nbins]``.
1376
+
1377
+ Parameters
1378
+ ----------
1379
+ input_size : int
1380
+ Input array size
1381
+ center_idx : int
1382
+ Index of the center of the slice
1383
+ nbins : int
1384
+ Number of bins to extract
1385
+
1386
+ Returns
1387
+ -------
1388
+ input_slice : list
1389
+ Indices to extract the slice from the input array
1390
+ output_slice : list
1391
+ Indices to fill the output array
1392
+
1393
+ Examples
1394
+ --------
1395
+ >>> _safe_array_slice_indices(input_size=10, input_center_idx=5, nbins=3)
1396
+ ([4, 7], [0, 3])
1397
+
1398
+ If the slice goes beyond the right edge: the output slice will only cover
1399
+ the first two bins of the output array, and up to the end of the input array.
1400
+ >>> _safe_array_slice_indices(input_size=6, input_center_idx=5, nbins=3)
1401
+ ([4, 6], [0, 2])
1402
+
1403
+ """
1404
+
1405
+ minbin = input_center_idx - nbins // 2
1406
+ maxbin = minbin + nbins
1407
+
1408
+ if minbin < 0 :
1409
+ output_slice = [- minbin , min (nbins , input_size - minbin )]
1410
+ input_slice = [0 , minbin + nbins ]
1411
+ elif maxbin > input_size :
1412
+ output_slice = [0 , nbins - (maxbin - input_size )]
1413
+ input_slice = [minbin , input_size ]
1414
+ else :
1415
+ output_slice = [0 , nbins ]
1416
+ input_slice = [minbin , maxbin ]
1417
+
1418
+ return input_slice , output_slice
1419
+
1420
+
1421
+ @njit ()
1422
+ def extract_pds_slice_around_freq (freqs , powers , f0 , nbins = 100 ):
1363
1423
"""Extract a slice of PDS around a given frequency.
1364
1424
1365
1425
This function extracts a slice of the power spectrum around a given frequency.
@@ -1385,27 +1445,57 @@ def center_pds_on_f(freqs, powers, f0, nbins=100):
1385
1445
>>> freqs = np.arange(1, 100) * 0.1
1386
1446
>>> powers = 10 / freqs
1387
1447
>>> f0 = 0.3
1388
- >>> p = center_pds_on_f (freqs, powers, f0)
1448
+ >>> p = extract_pds_slice_around_freq (freqs, powers, f0)
1389
1449
>>> assert np.isnan(p[0])
1390
1450
>>> assert not np.any(np.isnan(p[48:]))
1391
1451
"""
1392
1452
powers = np .asarray (powers )
1393
1453
chunk = np .zeros (nbins ) + np .nan
1394
- fchunk = np .zeros (nbins )
1454
+ # fchunk = np.zeros(nbins)
1395
1455
1396
1456
start_f_idx = np .searchsorted (freqs , f0 )
1397
1457
1398
- minbin = start_f_idx - nbins // 2
1399
- maxbin = minbin + nbins
1458
+ input_slice , output_slice = _safe_array_slice_indices (powers .size , start_f_idx , nbins )
1459
+ chunk [output_slice [0 ] : output_slice [1 ]] = powers [input_slice [0 ] : input_slice [1 ]]
1460
+ return chunk
1400
1461
1401
- if minbin < 0 :
1402
- chunk [- minbin : min (nbins , powers .size - minbin )] = powers [: minbin + nbins ]
1403
- elif maxbin > powers .size :
1404
- chunk [: nbins - (maxbin - powers .size )] = powers [minbin :]
1405
- else :
1406
- chunk [:] = powers [minbin :maxbin ]
1407
1462
1408
- return chunk
1463
+ @njit ()
1464
+ def _shift_and_average_core (input_array_list , weight_list , center_indices , nbins ):
1465
+ """Core function to shift_and_add, JIT-compiled for your convenience.
1466
+
1467
+ Parameters
1468
+ ----------
1469
+ input_array_list : list of np.array
1470
+ List of input arrays
1471
+ weight_list : list of float
1472
+ List of weights for each input array
1473
+ center_indices : list of int
1474
+ Central indices of the slice of each input array to be summed
1475
+ nbins : int
1476
+ Number of bins to extract around the central index of each input array
1477
+
1478
+ Returns
1479
+ -------
1480
+ output_array : np.array
1481
+ Average of the input arrays, weighted by the weights
1482
+ sum_of_weights : np.array
1483
+ Sum of the weights at each output bin
1484
+ """
1485
+ input_size = input_array_list [0 ].size
1486
+ output_array = np .zeros (nbins )
1487
+ sum_of_weights = np .zeros (nbins )
1488
+ for idx , array , weight in zip (center_indices , input_array_list , weight_list ):
1489
+ input_slice , output_slice = _safe_array_slice_indices (input_size , idx , nbins )
1490
+
1491
+ for i in range (input_slice [1 ] - input_slice [0 ]):
1492
+ output_array [output_slice [0 ] + i ] += array [input_slice [0 ] + i ] * weight
1493
+
1494
+ sum_of_weights [output_slice [0 ] + i ] += weight
1495
+
1496
+ output_array = output_array / sum_of_weights
1497
+
1498
+ return output_array , sum_of_weights
1409
1499
1410
1500
1411
1501
def shift_and_add (freqs , power_list , f0_list , nbins = 100 , rebin = None , df = None , M = None ):
@@ -1442,9 +1532,7 @@ def shift_and_add(freqs, power_list, f0_list, nbins=100, rebin=None, df=None, M=
1442
1532
Returns
1443
1533
-------
1444
1534
f : np.array
1445
- Array of output frequencies. This will be centered on the mean of the
1446
- input ``f0_list``, and have the same frequency resolution as the original
1447
- frequency array.
1535
+ Array of output frequencies
1448
1536
p : np.array
1449
1537
Array of output powers
1450
1538
n : np.array
@@ -1460,32 +1548,28 @@ def shift_and_add(freqs, power_list, f0_list, nbins=100, rebin=None, df=None, M=
1460
1548
>>> assert np.array_equal(p, [2. , 2. , 5. , 2. , 1.5])
1461
1549
>>> assert np.allclose(f, [0.05, 0.15, 0.25, 0.35, 0.45])
1462
1550
"""
1463
- final_powers = np .zeros (nbins )
1551
+
1552
+ # Check if the input list of power contains numpy arrays
1553
+ if not hasattr (power_list [0 ], "size" ):
1554
+ power_list = np .asarray (power_list )
1555
+ # input_size = np.size(power_list[0])
1464
1556
freqs = np .asarray (freqs )
1465
1557
1466
- mid_idx = np .searchsorted (freqs , np .mean (f0_list ))
1558
+ # mid_idx = np.searchsorted(freqs, np.mean(f0_list))
1467
1559
if M is None :
1468
1560
M = 1
1469
1561
if not isinstance (M , Iterable ):
1470
1562
M = np .ones (len (power_list )) * M
1471
1563
1472
- count = np .zeros (nbins )
1473
- for f0 , powers , m in zip (f0_list , power_list , M ):
1474
- idx = np .searchsorted (freqs , f0_list )
1564
+ center_f_indices = np .searchsorted (freqs , f0_list )
1475
1565
1476
- powers = np .asarray (powers ) * m
1477
- new_power = center_pds_on_f (freqs , powers , f0 , nbins = nbins )
1478
- bad = np .isnan (new_power )
1479
- new_power [bad ] = 0.0
1480
- final_powers += new_power
1481
- count += np .array (~ bad , dtype = int ) * m
1566
+ final_powers , count = _shift_and_average_core (power_list , M , center_f_indices , nbins )
1482
1567
1483
1568
if df is None :
1484
1569
df = freqs [1 ] - freqs [0 ]
1485
1570
1486
1571
final_freqs = np .arange (- nbins // 2 , nbins // 2 + 1 )[:nbins ] * df
1487
1572
final_freqs = final_freqs - (final_freqs [0 ] + final_freqs [- 1 ]) / 2 + np .mean (f0_list )
1488
- final_powers = final_powers / count
1489
1573
1490
1574
if rebin is not None :
1491
1575
_ , count , _ , _ = rebin_data (final_freqs , count , rebin * df )
0 commit comments