Skip to content

Commit f12842c

Browse files
authored
Fix edge case in AnchorTabular where no samples satisfying the anchor… (#742)
Fix edge case in AnchorTabular where no samples satisfying the anchor exist in the train data
1 parent 711bc16 commit f12842c

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

alibi/explainers/anchors/anchor_tabular.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ def perturbation(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.
290290
[allowed_rows[feat] for feat in uniq_feat_ids],
291291
np.intersect1d),
292292
)
293+
if partial_anchor_rows == []:
294+
# edge case - if there are no rows at all then `partial_anchor_rows` is the empty list, but it should
295+
# be a list of an empty array to not cause an error in calculating coverage (which will be 0)
296+
partial_anchor_rows = [np.array([], dtype=int)]
293297
nb_partial_anchors = np.array([len(n_records) for n_records in
294298
reversed(partial_anchor_rows)]) # reverse required for np.searchsorted later
295299
coverage = nb_partial_anchors[0] / self.n_records # since we sorted, the correct coverage is first not last
@@ -383,7 +387,14 @@ def replace_features(self, samples: np.ndarray, allowed_rows: Dict[int, Any], un
383387
requested_samples = num_samples
384388
start, n_anchor_feats = 0, len(partial_anchor_rows)
385389
uniq_feat_ids = list(reversed(uniq_feat_ids))
386-
start_idx = np.nonzero(nb_partial_anchors)[0][0] # skip anchors with no samples in the database
390+
391+
try:
392+
start_idx = np.nonzero(nb_partial_anchors)[0][0] # skip anchors with no samples in the database
393+
except IndexError:
394+
# there are no samples in the database, need to break out of the function
395+
# and go straight to treating unknown features
396+
return
397+
387398
end_idx = np.searchsorted(np.cumsum(nb_partial_anchors), num_samples)
388399

389400
# replace partial anchors with partial anchors drawn from the training dataset

0 commit comments

Comments
 (0)