Skip to content

Commit 5db36d9

Browse files
ALE custom grid points update. (#731)
* Loosen lower/upper bound conditions to allow custom grid points instead of clamping to min/max value. Included test. * Updated docs for min_bin_points. * Minor grammar correction.
1 parent dc55aba commit 5db36d9

File tree

2 files changed

+85
-15
lines changed

2 files changed

+85
-15
lines changed

alibi/explainers/ale.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def explain(self,
111111
Features for which to calculate ALE.
112112
min_bin_points
113113
Minimum number of points each discretized interval should contain to ensure more precise
114-
ALE estimation.
114+
ALE estimation. Only relevant for adaptive grid points (i.e., features without an entry in the
115+
`grid_points` dictionary).
115116
grid_points
116117
Custom grid points. Must be a `dict` where the keys are feature indices and the values are
117118
monotonically increasing `numpy` arrays defining the grid points for each feature.
@@ -138,8 +139,8 @@ def explain(self,
138139
139140
- Grid points outside the feature range. Consider the following example: `O O O X X O X O X O O`, \
140141
where 3 grid-points are smaller than the minimum value in `f`, and 2 grid-points are larger than the maximum \
141-
value in `f`. Grid-points outside the feature value range are clipped between the minimum and maximum \
142-
values of `f`. The grid-points considered will be: `(O|X) X O X O (X|O)`.
142+
value in `f`. The empty leading and ending bins are removed. The grid-points considered
143+
will be: `O X X O X O X O`.
143144
144145
- Grid points that do not cover the entire feature range. Consider the following example: \
145146
`X X O X X O X O X X X X X`. Two auxiliary grid-points are added which correspond to the value of the minimum \
@@ -408,7 +409,7 @@ def ale_num(
408409
Custom grid points. A `numpy` array defining the grid points for the given features.
409410
min_bin_points
410411
Minimum number of points each discretized interval should contain to ensure more precise
411-
ALE estimation.
412+
ALE estimation. Only relevant for adaptive grid points (i.e., a feature for which ``feature_grid_points=None``).
412413
check_feature_resolution
413414
Refer to :class:`ALE` documentation.
414415
low_resolution_threshold
@@ -445,12 +446,24 @@ def ale_num(
445446
fvals = np.sort(feature_grid_points)
446447

447448
if min_val > fvals[0]:
448-
logger.warning(f'Feature {feature} grid-points contain lower values than the minimum feature value. '
449-
'Automatically lower bound clipping the grid-points values.')
449+
# select the greatest grid point that is less than or equal to the minimum feature value
450+
min_idx = np.where(fvals <= min_val)[0][-1]
451+
min_val = fvals[min_idx]
452+
453+
if min_idx != 0:
454+
logger.warning(f'The leading bins of feature {feature} defined by the grid-points do not contain '
455+
'any feature values. Automatically removing the empty leading bins to ensure that '
456+
'each bin contains at least one feature value.')
450457

451458
if max_val < fvals[-1]:
452-
logger.warning(f'Feature {feature} grid-points contain larger values than the maximum feature value. '
453-
'Automatically upper bound clipping the grid-points values.')
459+
# select the smallest grid point that is larger than or equal to the maximum feature value
460+
max_idx = np.where(fvals >= max_val)[0][0]
461+
max_val = fvals[max_idx]
462+
463+
if max_idx != len(fvals) - 1:
464+
logger.warning(f'The ending bins of feature {feature} defined by the grid-points do not contain '
465+
'any feature values. Automatically removing the empty ending bins to ensure that '
466+
'each bin contains at least one feature value.')
454467

455468
# clip the values and remove duplicates
456469
fvals = np.unique(np.clip(fvals, a_min=min_val, a_max=max_val))
@@ -469,13 +482,17 @@ def ale_num(
469482

470483
# check how many feature values are in each bin
471484
indices = np.searchsorted(fvals, X[:, feature], side="left")
472-
interval_n = np.bincount(indices) # number of points in each interval
485+
# put the smallest data point in the first interval
486+
indices[indices == 0] = 1
487+
# count the number of points in each interval without considering the first bin,
488+
# because the first bin will always contain 0 (see line above)
489+
interval_n = np.bincount(indices)[1:]
473490

474491
if np.any(interval_n == 0):
475-
fvals = np.delete(fvals, np.where(interval_n == 0)[0])
492+
fvals = np.delete(fvals, np.where(interval_n == 0)[0] + 1) # +1 because we don't consider the first bin
476493
logger.warning(f'Some bins of feature {feature} defined by the grid-points do not contain '
477494
'any feature values. Automatically merging consecutive bins to ensure that '
478-
'each bin contains at least one value.')
495+
'each bin contains at least one feature value.')
479496

480497
# if the feature is constant, calculate the ALE on a small interval surrounding the feature value
481498
if len(fvals) == 1:

alibi/explainers/tests/test_ale.py

+57-4
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,15 @@ def test_explain(mock_ale_explainer, features, input_dim, batch_size, custom_gri
137137
@pytest.mark.parametrize('extrapolate_constant_min', (0.1, 1.0))
138138
@pytest.mark.parametrize('constant_value', (5.,))
139139
@pytest.mark.parametrize('feature', (1,))
140-
@pytest.mark.parametrize('custom_grid', (True, False))
141140
def test_constant_feature(extrapolate_constant, extrapolate_constant_perc, extrapolate_constant_min,
142-
constant_value, feature, custom_grid):
141+
constant_value, feature):
143142
X = np.random.normal(size=(100, 2))
144143
X[:, feature] = constant_value
145144
predict = lambda x: x.sum(axis=1) # dummy predictor # noqa
146-
feature_grid_points = np.random.normal((1, )) if custom_grid else None
147145

148146
q, ale, ale0 = ale_num(predictor=predict,
149147
X=X,
150148
feature=feature,
151-
feature_grid_points=feature_grid_points,
152149
extrapolate_constant=extrapolate_constant,
153150
extrapolate_constant_perc=extrapolate_constant_perc,
154151
extrapolate_constant_min=extrapolate_constant_min)
@@ -159,3 +156,59 @@ def test_constant_feature(extrapolate_constant, extrapolate_constant_perc, extra
159156
assert_allclose(q, np.array([constant_value]))
160157
assert_allclose(ale, np.array([[0.]]))
161158
assert_allclose(ale0, np.array([0.]))
159+
160+
161+
@pytest.mark.parametrize('num_bins', [1, 3, 5, 7, 15])
162+
@pytest.mark.parametrize('perc_bins', [0.1, 0.2, 0.5, 0.7, 0.9, 1.0])
163+
@pytest.mark.parametrize('size_data', [1, 5, 10, 50, 100])
164+
@pytest.mark.parametrize('outside_grid', [False, True])
165+
def test_grid_points_stress(num_bins, perc_bins, size_data, outside_grid):
166+
np.random.seed(0)
167+
eps = 1
168+
169+
# define the grid between [-10, 10] having `num_bins` bins
170+
grid = np.unique(np.random.uniform(-10, 10, size=num_bins + 1))
171+
172+
# select specific bins to sample the data from grid defined above.
173+
# the number of bins is controlled by the percentage of bins given by `perc_bins`
174+
nbins = int(np.ceil(num_bins * perc_bins))
175+
bins = np.sort(np.random.choice(num_bins, size=nbins, replace=False))
176+
177+
# generate data
178+
X = []
179+
selected_bins = []
180+
181+
for i in range(size_data):
182+
# select a bin at random and mark it as selected
183+
bin = np.random.choice(bins, size=1)
184+
selected_bins.append(bin.item())
185+
186+
# define offset to ensure that the value is sampled within the bin
187+
# (i.e. avoid edge cases where the data might land on the grid point)
188+
# the ALE implementation should work even in that case, only the process of constructing
189+
# the expected values might require additional logic
190+
offset = 0.1 * (grid[bin + 1] - grid[bin])
191+
X.append(np.random.uniform(low=grid[bin] + offset, high=grid[bin + 1] - offset).item())
192+
193+
# add values outside the grid to test that the grid is extended
194+
if outside_grid:
195+
X = X + [grid[0] - eps, grid[-1] + eps]
196+
197+
# construct dataset, define dummy predictor, and get grid values used by ale
198+
X = np.array(X).reshape(-1, 1)
199+
predict = lambda x: x.sum(axis=1) # noqa
200+
q, _, _ = ale_num(predictor=predict, X=X, feature=0, feature_grid_points=grid)
201+
202+
# construct expected grid by merging selected bins
203+
if outside_grid:
204+
# add first and last bin corresponding to min and max value.
205+
# This requires incrementing all the previous values by 1
206+
selected_bins = np.array(selected_bins + [-1, num_bins]) + 1
207+
208+
# update grid point to include the min and max
209+
grid = np.insert(grid, 0, grid[0] - eps)
210+
grid = np.insert(grid, len(grid), grid[-1] + eps)
211+
212+
selected_bins = np.unique(selected_bins)
213+
expected_q = np.array([grid[selected_bins[0]]] + [grid[b + 1] for b in selected_bins])
214+
np.testing.assert_allclose(q, expected_q)

0 commit comments

Comments
 (0)