@@ -290,6 +290,10 @@ def perturbation(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.
290
290
[allowed_rows [feat ] for feat in uniq_feat_ids ],
291
291
np .intersect1d ),
292
292
)
293
+ if partial_anchor_rows == []:
294
+ # edge case - if there are no rows at all then `partial_anchor_rows` is the empty list, but it should
295
+ # be a list of an empty array to not cause an error in calculating coverage (which will be 0)
296
+ partial_anchor_rows = [np .array ([], dtype = int )]
293
297
nb_partial_anchors = np .array ([len (n_records ) for n_records in
294
298
reversed (partial_anchor_rows )]) # reverse required for np.searchsorted later
295
299
coverage = nb_partial_anchors [0 ] / self .n_records # since we sorted, the correct coverage is first not last
@@ -383,7 +387,14 @@ def replace_features(self, samples: np.ndarray, allowed_rows: Dict[int, Any], un
383
387
requested_samples = num_samples
384
388
start , n_anchor_feats = 0 , len (partial_anchor_rows )
385
389
uniq_feat_ids = list (reversed (uniq_feat_ids ))
386
- start_idx = np .nonzero (nb_partial_anchors )[0 ][0 ] # skip anchors with no samples in the database
390
+
391
+ try :
392
+ start_idx = np .nonzero (nb_partial_anchors )[0 ][0 ] # skip anchors with no samples in the database
393
+ except IndexError :
394
+ # there are no samples in the database, need to break out of the function
395
+ # and go straight to treating unknown features
396
+ return
397
+
387
398
end_idx = np .searchsorted (np .cumsum (nb_partial_anchors ), num_samples )
388
399
389
400
# replace partial anchors with partial anchors drawn from the training dataset
0 commit comments