@@ -11,55 +11,77 @@
 
 @skip_if_no_tabpfn
 @pytest.mark.external_libraries
-def test_tabpfn_explainer_clf(tabpfn_classification_problem):
-    """Test the TabPFNExplainer class for classification problems."""
-    import tabpfn
+class TestTabPFNExplainer:
+    """Tests for the TabPFNExplainer class."""
 
-    # setup
-    model, data, labels, x_test = tabpfn_classification_problem
-    x_explain = x_test[0]
-    assert isinstance(model, tabpfn.TabPFNClassifier)
-    if model.n_features_in_ == data.shape[1]:
-        model.fit(data, labels)
-    assert model.n_features_in_ == data.shape[1]
+    def test_tabpfn_explainer_clf(self, tabpfn_classification_problem):
+        """Test the TabPFNExplainer class for classification problems."""
+        import tabpfn
 
-    explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
-    explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
-    assert isinstance(explanation, InteractionValues)
+        # setup
+        model, data, labels, x_test = tabpfn_classification_problem
+        x_explain = x_test[0]
+        assert isinstance(model, tabpfn.TabPFNClassifier)
+        if model.n_features_in_ == data.shape[1]:
+            model.fit(data, labels)
+        assert model.n_features_in_ == data.shape[1]
 
-    # test that bare explainer gets turned into TabPFNExplainer
-    explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
-    assert isinstance(explainer, TabPFNExplainer)
+        explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
+        explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
+        assert isinstance(explanation, InteractionValues)
 
-    # test that TabularExplainer works as well
-    with pytest.warns(UserWarning):
-        explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
-    assert isinstance(explainer, TabularExplainer)
+        # test that bare explainer gets turned into TabPFNExplainer
+        explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
+        assert isinstance(explainer, TabPFNExplainer)
+
+        # test that TabularExplainer works as well
+        with pytest.warns(UserWarning):
+            explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
+        assert isinstance(explainer, TabularExplainer)
+
+    def test_tabpfn_explainer_reg(self, tabpfn_regression_problem):
+        """Test the TabPFNExplainer class for regression problems."""
+        import tabpfn
+
+        # setup
+        model, data, labels, x_test = tabpfn_regression_problem
+        x_explain = x_test[0]
+        assert isinstance(model, tabpfn.TabPFNRegressor)
+        if model.n_features_in_ == data.shape[1]:
+            model.fit(data, labels)
+        assert model.n_features_in_ == data.shape[1]
+
+        explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
+        explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
+        assert isinstance(explanation, InteractionValues)
+
+        # test that bare explainer gets turned into TabPFNExplainer
+        explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
+        assert isinstance(explainer, TabPFNExplainer)
+
+        # test that TabularExplainer works as well
+        with pytest.warns(UserWarning):
+            explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
+        assert isinstance(explainer, TabularExplainer)
 
 
 @skip_if_no_tabpfn
 @pytest.mark.external_libraries
-def test_tabpfn_explainer_reg(tabpfn_regression_problem):
-    """Test the TabPFNExplainer class for regression problems."""
-    import tabpfn
-
-    # setup
-    model, data, labels, x_test = tabpfn_regression_problem
-    x_explain = x_test[0]
-    assert isinstance(model, tabpfn.TabPFNRegressor)
-    if model.n_features_in_ == data.shape[1]:
-        model.fit(data, labels)
-    assert model.n_features_in_ == data.shape[1]
-
-    explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
-    explanation = explainer.explain(x=x_explain, budget=BUDGET_NR_FEATURES_SMALL)
-    assert isinstance(explanation, InteractionValues)
-
-    # test that bare explainer gets turned into TabPFNExplainer
-    explainer = Explainer(model=model, data=data, labels=labels, x_test=x_test)
-    assert isinstance(explainer, TabPFNExplainer)
-
-    # test that TabularExplainer works as well
-    with pytest.warns(UserWarning):
-        explainer = TabularExplainer(model=model, data=data, class_index=1, imputer="baseline")
-    assert isinstance(explainer, TabularExplainer)
+class TestTabPFNExplainerBugFixes:
+    """Tests for bug fixes conducted in the TabPFNExplainer."""
+
+    def test_after_explanation_prediction(self, tabpfn_regression_problem):
+        """Tests that the model can be used for prediction after explanation.
+
+        This bug was raised in issue [#396](https://github.com/mmschlk/shapiq/issues/396).
+        """
+        model, data, labels, x_test = tabpfn_regression_problem
+        x_explain = x_test[0]
+
+        _ = model.predict(x_explain.reshape(1, -1))
+
+        explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
+        explainer.explain(x=x_explain, budget=3)
+        assert model.n_features_in_ == data.shape[1]
+
+        model.predict(x_explain.reshape(1, -1))  # should not raise an error
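For context, the bug-fix test above exercises a simple end-to-end pattern: explain an instance with TabPFNExplainer, then keep using the underlying TabPFN model for ordinary prediction (the behaviour that regressed in issue #396). A minimal stand-alone sketch of that pattern, using synthetic data in place of the tabpfn_regression_problem fixture and assuming TabPFNExplainer is importable from the shapiq top-level package, could look like this:

# Minimal sketch (not part of the test suite): predict-after-explain pattern.
# Assumes shapiq exports TabPFNExplainer at the top level; data shapes and budget are illustrative.
import numpy as np
from tabpfn import TabPFNRegressor

from shapiq import TabPFNExplainer

rng = np.random.default_rng(seed=0)
data = rng.normal(size=(100, 5))        # background data for the explainer
labels = data @ rng.normal(size=5)      # synthetic regression targets
x_test = rng.normal(size=(10, 5))       # held-out points available for explanation

model = TabPFNRegressor()
model.fit(data, labels)

explainer = TabPFNExplainer(model=model, data=data, labels=labels, x_test=x_test)
explanation = explainer.explain(x=x_test[0], budget=3)

# After the fix for issue #396, the model remains usable for ordinary prediction.
prediction = model.predict(x_test[0].reshape(1, -1))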