Skip to content

Commit 3c0fd87

Browse files
mmschlkCopilot
andauthored
fixes bug with a TabPFN model not working as intended after explanations (#401)
* works on fixing the bug * reduced size of tabpfn model * fixes #396 * documents fix in CHANGELOG.md * Update tests/tests_imputer/test_tabpfn_imputer.py Co-authored-by: Copilot <[email protected]> * fixed code-quality checks --------- Co-authored-by: Copilot <[email protected]>
1 parent 4b487a8 commit 3c0fd87

File tree

5 files changed

+78
-52
lines changed

5 files changed

+78
-52
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- suppress a ``RuntimeWarning`` in ``Regression`` approximators ``solve_regression()``method when the solver is not able to find good interim solutions for the regression problem.
2222
#### Bug Fixes
2323
- fixed a bug in the `shapiq.waterfall_plot` function that caused the plot to not display correctly resulting in cutoff y_ticks. Additionally, the file was renamed from `watefall.py` to `waterfall.py` to match the function name [#377](https://github.com/mmschlk/shapiq/pull/377)
24+
- fixes a bug with `TabPFNExplainer`, where the model was not able to be used for predictions after it was explained. This was due to the model being fitted on a subset of features, which caused inconsistencies in the model's predictions after explanation. The fix includes that after each call to the `TabPFNImputer.value_function`, the tabpfn model is fitted on the whole dataset (without omitting features). This means that the original model can be used for predictions after it has been explained. [#396](https://github.com/mmschlk/shapiq/issues/396).
2425

2526
### v1.2.3 (2025-03-24)
2627
- substantially improves the runtime of all `Regression` approximators by a) a faster pre-computation of the regression matrices and b) a faster computation of the weighted least squares regression [#340](https://github.com/mmschlk/shapiq/issues/340)

shapiq/games/imputer/tabpfn_imputer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
accept the model and the data point as input and return the model's predictions. If
7878
the model is instantiated via a ``shapiq.Explainer`` object, this function is
7979
automatically set to the model's prediction function. Defaults to ``None``.
80+
8081
"""
8182
self.x_train = x_train
8283
self.y_train = y_train
@@ -136,4 +137,6 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray:
136137
self.model.fit(x_train_coal, self.y_train)
137138
pred = float(self.predict(x_explain_coal))
138139
output[i] = pred
140+
# refit the model on the full training data to ensure it is in a consistent state
141+
self.model.fit(self.x_train, self.y_train)
139142
return output

tests/fixtures/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def tabpfn_classification_problem(
185185

186186
data, labels = background_clf_dataset_binary_small
187187
data, x_test, labels, _ = train_test_split(data, labels, random_state=42, train_size=8)
188-
model = tabpfn.TabPFNClassifier()
188+
model = tabpfn.TabPFNClassifier(n_estimators=1, fit_mode="low_memory")
189189
model.fit(data, labels)
190190
return model, data, labels, x_test
191191

@@ -199,7 +199,7 @@ def tabpfn_regression_problem(
199199

200200
data, labels = background_reg_dataset_small
201201
data, x_test, labels, _ = train_test_split(data, labels, random_state=42, train_size=8)
202-
model = tabpfn.TabPFNRegressor()
202+
model = tabpfn.TabPFNRegressor(n_estimators=1, fit_mode="low_memory")
203203
model.fit(data, labels)
204204
return model, data, labels, x_test
205205

tests/tests_explainer/test_explainer_tabpfn.py

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,55 +11,77 @@
1111

1212
@skip_if_no_tabpfn
1313
@pytest.mark.external_libraries
14-
def test_tabpfn_explainer_clf(tabpfn_classification_problem):
15-
"""Test the TabPFNExplainer class for classification problems."""
16-
import tabpfn
14+
class TestTabPFNExplainer:
15+
"""Tests for the TabPFNExplainer class."""
1716

18-
# setup
19-
model, data, labels, x_test = tabpfn_classification_problem
20-
x_explain = x_test[0]
21-
assert isinstance(model, tabpfn.TabPFNClassifier)
22-
if model.n_features_in_ == data.shape[1]:
23-
model.fit(data, labels)
24-
assert model.n_features_in_ == data.shape[1]
17+
def test_tabpfn_explainer_clf(self, tabpfn_classification_problem):
18+
"""Test the TabPFNExplainer class for classification problems."""
19+
import tabpfn
2520

26-
explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
27-
explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
28-
assert isinstance(explanation, InteractionValues)
21+
# setup
22+
model, data, labels, x_test = tabpfn_classification_problem
23+
x_explain = x_test[0]
24+
assert isinstance(model, tabpfn.TabPFNClassifier)
25+
if model.n_features_in_ == data.shape[1]:
26+
model.fit(data, labels)
27+
assert model.n_features_in_ == data.shape[1]
2928

30-
# test that bare explainer gets turned into TabPFNExplainer
31-
explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
32-
assert isinstance(explainer, TabPFNExplainer)
29+
explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
30+
explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
31+
assert isinstance(explanation, InteractionValues)
3332

34-
# test that TabularExplainer works as well
35-
with pytest.warns(UserWarning):
36-
explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
37-
assert isinstance(explainer, TabularExplainer)
33+
# test that bare explainer gets turned into TabPFNExplainer
34+
explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
35+
assert isinstance(explainer, TabPFNExplainer)
36+
37+
# test that TabularExplainer works as well
38+
with pytest.warns(UserWarning):
39+
explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
40+
assert isinstance(explainer, TabularExplainer)
41+
42+
def test_tabpfn_explainer_reg(self, tabpfn_regression_problem):
43+
"""Test the TabPFNExplainer class for regression problems."""
44+
import tabpfn
45+
46+
# setup
47+
model, data, labels, x_test = tabpfn_regression_problem
48+
x_explain = x_test[0]
49+
assert isinstance(model, tabpfn.TabPFNRegressor)
50+
if model.n_features_in_ == data.shape[1]:
51+
model.fit(data, labels)
52+
assert model.n_features_in_ == data.shape[1]
53+
54+
explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
55+
explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
56+
assert isinstance(explanation, InteractionValues)
57+
58+
# test that bare explainer gets turned into TabPFNExplainer
59+
explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
60+
assert isinstance(explainer, TabPFNExplainer)
61+
62+
# test that TabularExplainer works as well
63+
with pytest.warns(UserWarning):
64+
explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
65+
assert isinstance(explainer, TabularExplainer)
3866

3967

4068
@skip_if_no_tabpfn
4169
@pytest.mark.external_libraries
42-
def test_tabpfn_explainer_reg(tabpfn_regression_problem):
43-
"""Test the TabPFNExplainer class for regression problems."""
44-
import tabpfn
45-
46-
# setup
47-
model, data, labels, x_test = tabpfn_regression_problem
48-
x_explain = x_test[0]
49-
assert isinstance(model, tabpfn.TabPFNRegressor)
50-
if model.n_features_in_ == data.shape[1]:
51-
model.fit(data, labels)
52-
assert model.n_features_in_ == data.shape[1]
53-
54-
explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
55-
explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
56-
assert isinstance(explanation, InteractionValues)
57-
58-
# test that bare explainer gets turned into TabPFNExplainer
59-
explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
60-
assert isinstance(explainer, TabPFNExplainer)
61-
62-
# test that TabularExplainer works as well
63-
with pytest.warns(UserWarning):
64-
explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
65-
assert isinstance(explainer, TabularExplainer)
70+
class TestTabPFNExplainerBugFixes:
71+
"""Tests for bug fixes conducted in the TabPFNExplainer."""
72+
73+
def test_after_explanation_prediction(self, tabpfn_regression_problem):
74+
"""Tests that the model can be used for prediction after explanation.
75+
76+
This bug was raised in issue [#396](https://github.com/mmschlk/shapiq/issues/396)
77+
"""
78+
model, data, labels, x_test = tabpfn_regression_problem
79+
x_explain = x_test[0]
80+
81+
_ = model.predict(x_explain.reshape(1, -1))
82+
83+
explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
84+
explainer.explain(x=x_explain, budget=3)
85+
assert model.n_features_in_ == data.shape[1]
86+
87+
model.predict(x_explain.reshape(1, -1)) # should not raise an error

tests/tests_imputer/test_tabpfn_imputer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def test_tabpfn_imputer(tabpfn_classification_problem):
3636
imputer.fit(x=x_test[0])
3737

3838
# test the imputer
39-
imputer(np.asarray([True, True, True])) # 3 features should now been fitted
40-
assert model.n_features_in_ == 3
41-
imputer(np.asarray([True, True, False])) # 2 features should now been fitted
42-
assert model.n_features_in_ == 2
43-
imputer(np.asarray([False, True, False])) # 1 feature should now been fitted
44-
assert model.n_features_in_ == 1
39+
out_1 = imputer(np.asarray([True, True, True])) # 3 features should now been fitted
40+
out_2 = imputer(np.asarray([True, True, False])) # 2 features should now been fitted
41+
out_3 = imputer(np.asarray([False, True, False])) # 1 feature should now been fitted
42+
assert out_1 != out_2
43+
assert out_1 != out_3
44+
assert out_2 != out_3
4545

4646

4747
@skip_if_no_tabpfn

0 commit comments

Comments
 (0)