Skip to content

Commit 4c28d19

Browse files
authored
Add NotFittedError and use it in AnchorTabular (#732)
* Add NotFittedError and use it in AnchorTabular * Formatting * Update explain method docstring * Make NotFittedError less generic * Update changelog
1 parent b58a170 commit 4c28d19

File tree

4 files changed

+42
-2
lines changed

4 files changed

+42
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
[Full Changelog](https://github.com/SeldonIO/alibi/compare/v0.7.0...v0.7.1)
55

66
## Added
7+
- New `exceptions.NotFittedError` exception which is raised whenever a compulsory call to a `fit` method has not been carried out. Specifically, this is now raised in `AnchorTabular.explain` when `AnchorTabular.fit` has been skipped ([#732](https://github.com/SeldonIO/alibi/pull/732)).
78

89
## Fixed
910

alibi/exceptions.py

+11
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,14 @@ class PredictorReturnTypeError(AlibiException, AlibiPredictorReturnTypeError):
4141
an unexpected or unsupported type.
4242
"""
4343
pass
44+
45+
46+
class NotFittedError(AlibiException):
47+
"""
48+
This exception is raised whenever a compulsory call to a `fit` method has not been carried out.
49+
"""
50+
51+
def __init__(self, object_name: str):
52+
super().__init__(
53+
f"This {object_name} instance is not fitted yet. Call 'fit' with appropriate arguments first."
54+
)

alibi/explainers/anchors/anchor_tabular.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from alibi.api.defaults import DEFAULT_DATA_ANCHOR, DEFAULT_META_ANCHOR
1010
from alibi.api.interfaces import Explainer, Explanation, FitMixin
11-
from alibi.exceptions import (PredictorCallError,
11+
from alibi.exceptions import (NotFittedError,
12+
PredictorCallError,
1213
PredictorReturnTypeError)
1314
from alibi.utils.discretizer import Discretizer
1415
from alibi.utils.mapping import ohe_to_ord, ord_to_ohe
@@ -647,6 +648,8 @@ def __init__(self,
647648
# update metadata
648649
self.meta['params'].update(seed=seed)
649650

651+
self._fitted = False
652+
650653
def fit(self, # type: ignore[override]
651654
train_data: np.ndarray,
652655
disc_perc: Tuple[Union[int, float], ...] = (25, 50, 75),
@@ -686,6 +689,8 @@ def fit(self, # type: ignore[override]
686689
# update metadata
687690
self.meta['params'].update(disc_perc=disc_perc)
688691

692+
self._fitted = True
693+
689694
return self
690695

691696
def _build_sampling_lookups(self, X: np.ndarray) -> None:
@@ -790,7 +795,15 @@ def explain(self,
790795
791796
.. _AnchorTabular examples:
792797
https://docs.seldon.io/projects/alibi/en/stable/methods/Anchors.html
798+
799+
Raises
800+
------
801+
:py:class:`alibi.exceptions.NotFittedError`
802+
If `fit` has not been called prior to calling `explain`.
793803
"""
804+
if not self._fitted:
805+
raise NotFittedError(self.meta["name"])
806+
794807
# transform one-hot encodings to labels if ohe == True
795808
X = ohe_to_ord(X_ohe=X.reshape(1, -1), cat_vars_ohe=self.cat_vars_ohe)[0].reshape(-1) if self.ohe else X
796809

alibi/explainers/tests/test_anchor_tabular.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytest_lazyfixture import lazy_fixture
77

88
from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR
9-
from alibi.exceptions import PredictorCallError, PredictorReturnTypeError
9+
from alibi.exceptions import NotFittedError, PredictorCallError, PredictorReturnTypeError
1010
from alibi.explainers import AnchorTabular, DistributedAnchorTabular
1111
from alibi.explainers.tests.utils import predict_fcn
1212

@@ -322,6 +322,13 @@ def bad_predictor(x: np.ndarray) -> list:
322322
return list(x)
323323

324324

325+
def good_predictor(x: np.ndarray) -> np.ndarray:
326+
"""
327+
A dummy predictor returning a vector of random binary target labels.
328+
"""
329+
return np.random.randint(low=0, high=2, size=x.shape[0])
330+
331+
325332
def test_anchor_tabular_fails_init_bad_feature_names_predictor_call():
326333
"""
327334
In this test `feature_names` is misspecified leading to an exception calling the `predictor`.
@@ -336,3 +343,11 @@ def test_anchor_tabular_fails_bad_predictor_return_type():
336343
"""
337344
with pytest.raises(PredictorReturnTypeError):
338345
explainer = AnchorTabular(bad_predictor, feature_names=['f1', 'f2', 'f3']) # noqa: F841
346+
347+
348+
def test_anchor_tabular_explain_fails_not_fitted():
349+
explainer = AnchorTabular(good_predictor, feature_names=['f1', 'f2'])
350+
with pytest.raises(NotFittedError) as err:
351+
explainer.explain(np.ones(2))
352+
expected_msg = "This AnchorTabular instance is not fitted yet. Call 'fit' with appropriate arguments first."
353+
assert str(err.value) == expected_msg

0 commit comments

Comments
 (0)