@@ -137,18 +137,15 @@ def test_explain(mock_ale_explainer, features, input_dim, batch_size, custom_gri
137
137
@pytest .mark .parametrize ('extrapolate_constant_min' , (0.1 , 1.0 ))
138
138
@pytest .mark .parametrize ('constant_value' , (5. ,))
139
139
@pytest .mark .parametrize ('feature' , (1 ,))
140
- @pytest .mark .parametrize ('custom_grid' , (True , False ))
141
140
def test_constant_feature (extrapolate_constant , extrapolate_constant_perc , extrapolate_constant_min ,
142
- constant_value , feature , custom_grid ):
141
+ constant_value , feature ):
143
142
X = np .random .normal (size = (100 , 2 ))
144
143
X [:, feature ] = constant_value
145
144
predict = lambda x : x .sum (axis = 1 ) # dummy predictor # noqa
146
- feature_grid_points = np .random .normal ((1 , )) if custom_grid else None
147
145
148
146
q , ale , ale0 = ale_num (predictor = predict ,
149
147
X = X ,
150
148
feature = feature ,
151
- feature_grid_points = feature_grid_points ,
152
149
extrapolate_constant = extrapolate_constant ,
153
150
extrapolate_constant_perc = extrapolate_constant_perc ,
154
151
extrapolate_constant_min = extrapolate_constant_min )
@@ -159,3 +156,59 @@ def test_constant_feature(extrapolate_constant, extrapolate_constant_perc, extra
159
156
assert_allclose (q , np .array ([constant_value ]))
160
157
assert_allclose (ale , np .array ([[0. ]]))
161
158
assert_allclose (ale0 , np .array ([0. ]))
159
+
160
+
161
+ @pytest .mark .parametrize ('num_bins' , [1 , 3 , 5 , 7 , 15 ])
162
+ @pytest .mark .parametrize ('perc_bins' , [0.1 , 0.2 , 0.5 , 0.7 , 0.9 , 1.0 ])
163
+ @pytest .mark .parametrize ('size_data' , [1 , 5 , 10 , 50 , 100 ])
164
+ @pytest .mark .parametrize ('outside_grid' , [False , True ])
165
+ def test_grid_points_stress (num_bins , perc_bins , size_data , outside_grid ):
166
+ np .random .seed (0 )
167
+ eps = 1
168
+
169
+ # define the grid between [-10, 10] having `num_bins` bins
170
+ grid = np .unique (np .random .uniform (- 10 , 10 , size = num_bins + 1 ))
171
+
172
+ # select specific bins to sample the data from grid defined above.
173
+ # the number of bins is controlled by the percentage of bins given by `perc_bins`
174
+ nbins = int (np .ceil (num_bins * perc_bins ))
175
+ bins = np .sort (np .random .choice (num_bins , size = nbins , replace = False ))
176
+
177
+ # generate data
178
+ X = []
179
+ selected_bins = []
180
+
181
+ for i in range (size_data ):
182
+ # select a bin at random and mark it as selected
183
+ bin = np .random .choice (bins , size = 1 )
184
+ selected_bins .append (bin .item ())
185
+
186
+ # define offset to ensure that the value is sampled within the bin
187
+ # (i.e. avoid edge cases where the data might land on the grid point)
188
+ # the ALE implementation should work even in that case, only the process of constructing
189
+ # the expected values might require additional logic
190
+ offset = 0.1 * (grid [bin + 1 ] - grid [bin ])
191
+ X .append (np .random .uniform (low = grid [bin ] + offset , high = grid [bin + 1 ] - offset ).item ())
192
+
193
+ # add values outside the grid to test that the grid is extended
194
+ if outside_grid :
195
+ X = X + [grid [0 ] - eps , grid [- 1 ] + eps ]
196
+
197
+ # construct dataset, define dummy predictor, and get grid values used by ale
198
+ X = np .array (X ).reshape (- 1 , 1 )
199
+ predict = lambda x : x .sum (axis = 1 ) # noqa
200
+ q , _ , _ = ale_num (predictor = predict , X = X , feature = 0 , feature_grid_points = grid )
201
+
202
+ # construct expected grid by merging selected bins
203
+ if outside_grid :
204
+ # add first and last bin corresponding to min and max value.
205
+ # This requires incrementing all the previous values by 1
206
+ selected_bins = np .array (selected_bins + [- 1 , num_bins ]) + 1
207
+
208
+ # update grid point to include the min and max
209
+ grid = np .insert (grid , 0 , grid [0 ] - eps )
210
+ grid = np .insert (grid , len (grid ), grid [- 1 ] + eps )
211
+
212
+ selected_bins = np .unique (selected_bins )
213
+ expected_q = np .array ([grid [selected_bins [0 ]]] + [grid [b + 1 ] for b in selected_bins ])
214
+ np .testing .assert_allclose (q , expected_q )
0 commit comments