
Commit 69a3cac

Modularized computation of prototypes importance scores (#826)

* Modularized computation of prototype importance scores from the visualization function.
* Updated .gitignore. Changed notebook kernel.
* Reverted .pymon and .gitignore.
* Duplicate docstrings and reverted example.
* Replaced prototypes_* with prototype_*.
* Avoid duplication when plotting the prototypes.
* Updated example notebook.
* Included missing return type.
1 parent 19fd682 commit 69a3cac

File tree

4 files changed: +162, -99 lines


alibi/api/defaults.py (+2, -2)

@@ -278,8 +278,8 @@
 """
 
 DEFAULT_DATA_PROTOSELECT = {"prototypes": None,
-                            "prototypes_indices": None,
-                            "prototypes_labels": None}  # type: dict
+                            "prototype_indices": None,
+                            "prototype_labels": None}  # type: dict
 """
 Default ProtoSelect data.
 """

alibi/prototypes/protoselect.py (+105, -42)

@@ -1,18 +1,19 @@
 import logging
-import numpy as np
-import matplotlib.pyplot as plt
-from matplotlib.offsetbox import OffsetImage, AnnotationBbox
-
-from tqdm import tqdm
 from copy import deepcopy
-from typing import Callable, Optional, Dict, List, Union, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.offsetbox import AnnotationBbox, OffsetImage
+from skimage.transform import resize
 from sklearn.model_selection import KFold
 from sklearn.neighbors import KNeighborsClassifier
-from skimage.transform import resize
+from tqdm import tqdm
 
+from alibi.api.defaults import (DEFAULT_DATA_PROTOSELECT,
+                                DEFAULT_META_PROTOSELECT)
+from alibi.api.interfaces import Explanation, FitMixin, Summariser
 from alibi.utils.distance import batch_compute_kernel_matrix
-from alibi.api.interfaces import Summariser, Explanation, FitMixin
-from alibi.api.defaults import DEFAULT_META_PROTOSELECT, DEFAULT_DATA_PROTOSELECT
 from alibi.utils.kernel import EuclideanDistance
 
 logger = logging.getLogger(__name__)
@@ -226,10 +227,10 @@ def _build_summary(self, protos: Dict[int, List[int]]) -> Explanation:
         Helper method to build the summary as an `Explanation` object.
         """
         data = deepcopy(DEFAULT_DATA_PROTOSELECT)
-        data['prototypes_indices'] = np.concatenate(list(protos.values())).astype(np.int32)
-        data['prototypes_labels'] = np.concatenate([[self.label_inv_mapping[l]] * len(protos[l])
-                                                    for l in protos]).astype(np.int32)  # noqa: E741
-        data['prototypes'] = self.Z[data['prototypes_indices']]
+        data['prototype_indices'] = np.concatenate(list(protos.values())).astype(np.int32)
+        data['prototype_labels'] = np.concatenate([[self.label_inv_mapping[l]] * len(protos[l])
+                                                   for l in protos]).astype(np.int32)  # noqa: E741
+        data['prototypes'] = self.Z[data['prototype_indices']]
         return Explanation(meta=self.meta, data=data)
 
 
@@ -262,7 +263,7 @@ def _helper_protoselect_euclidean_1knn(summariser: ProtoSelect,
     summary = summariser.summarise(num_prototypes=num_prototypes)
 
     # train 1-knn classifier
-    X_protos, y_protos = summary.data['prototypes'], summary.data['prototypes_labels']
+    X_protos, y_protos = summary.data['prototypes'], summary.data['prototype_labels']
     if len(X_protos) == 0:
         return None
 
@@ -546,6 +547,79 @@ def _imscatterplot(x: np.ndarray,
     return ax
 
 
+def compute_prototype_importances(summary: 'Explanation',
+                                  trainset: Tuple[np.ndarray, np.ndarray],
+                                  preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+                                  knn_kw: Optional[dict] = None) -> Dict[str, Optional[np.ndarray]]:
+
+    """
+    Computes the importance of each prototype. The importance of a prototype is the number of assigned
+    training instances correctly classified according to the 1-KNN classifier
+    (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
+
+    Parameters
+    ----------
+    summary
+        An `Explanation` object produced by a call to the
+        :py:meth:`alibi.prototypes.protoselect.ProtoSelect.summarise` method.
+    trainset
+        Tuple, `(X_train, y_train)`, consisting of the training data instances with the corresponding labels.
+    preprocess_fn
+        Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
+    knn_kw
+        Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
+        set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
+        If the `metric` is not specified, it will be set by default to ``'euclidean'``.
+        See parameters description:
+        https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
+
+    Returns
+    -------
+    A dictionary containing:
+
+     - ``'prototype_indices'`` - an array of the prototype indices.
+
+     - ``'prototype_importances'`` - an array of prototype importances.
+
+     - ``'X_protos'`` - an array of raw prototypes.
+
+     - ``'X_protos_ft'`` - an optional array of preprocessed prototypes. If the ``preprocess_fn=None``, \
+     no preprocessing is applied and ``None`` is returned instead.
+    """
+    if knn_kw is None:
+        knn_kw = {}
+
+    if knn_kw.get('metric') is None:
+        knn_kw.update({'metric': 'euclidean'})
+        logger.warning("KNN metric was not specified. Automatically setting `metric='euclidean'`.")
+
+    X_train, y_train = trainset
+    X_protos = summary.data['prototypes']
+    y_protos = summary.data['prototype_labels']
+
+    # preprocess the dataset
+    X_train_ft = _batch_preprocessing(X=X_train, preprocess_fn=preprocess_fn) \
+        if (preprocess_fn is not None) else X_train
+    X_protos_ft = _batch_preprocessing(X=X_protos, preprocess_fn=preprocess_fn) \
+        if (preprocess_fn is not None) else X_protos
+
+    # train knn classifier
+    knn = KNeighborsClassifier(n_neighbors=1, **knn_kw)
+    knn = knn.fit(X=X_protos_ft, y=y_protos)
+
+    # get neighbors indices for each training instance
+    neigh_idx = knn.kneighbors(X=X_train_ft, n_neighbors=1, return_distance=False).reshape(-1)
+
+    # compute how many correct labeled instances each prototype covers
+    idx, counts = np.unique(neigh_idx[y_protos[neigh_idx] == y_train], return_counts=True)
+    return {
+        'prototype_indices': idx,
+        'prototype_importances': counts,
+        'X_protos': X_protos[idx],
+        'X_protos_ft': None if (preprocess_fn is None) else X_protos_ft[idx]
+    }
+
+
 def visualize_image_prototypes(summary: 'Explanation',
                                trainset: Tuple[np.ndarray, np.ndarray],
                                reducer: Callable[[np.ndarray], np.ndarray],
@@ -560,7 +634,6 @@ def visualize_image_prototypes(summary: 'Explanation',
     Plot the images of the prototypes at the location given by the `reducer` representation.
     The size of each prototype is proportional to the logarithm of the number of assigned training instances correctly
     classified according to the 1-KNN classifier (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
-
     Parameters
     ----------
     summary
@@ -573,7 +646,7 @@ def visualize_image_prototypes(summary: 'Explanation',
        input instances if ``preprocess_fn=None``. If the `preprocess_fn` is specified, the reducer will be called
        on the feature representation obtained after passing the input instances through the `preprocess_fn`.
     preprocess_fn
-        Preprocessor function.
+        Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
     knn_kw
        Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
        set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
@@ -592,37 +665,27 @@ def visualize_image_prototypes(summary: 'Explanation',
     zoom_ub
        Zoom upper bound. The zoom will be scaled linearly between `[zoom_lb, zoom_ub]`.
     """
-    if knn_kw is None:
-        knn_kw = {}
-    if knn_kw.get('metric') is None:
-        knn_kw.update({'metric': 'euclidean'})
-        logger.warning("KNN metric was not specified. Automatically setting `metric='euclidean'`.")
-
-    X_train, y_train = trainset
-    X_protos = summary.data['prototypes']
-    y_protos = summary.data['prototypes_labels']
-
-    # preprocess the dataset
-    X_train_ft = _batch_preprocessing(X=X_train, preprocess_fn=preprocess_fn) \
-        if (preprocess_fn is not None) else X_train
-    X_protos_ft = _batch_preprocessing(X=X_protos, preprocess_fn=preprocess_fn) \
-        if (preprocess_fn is not None) else X_protos
-
-    # train knn classifier
-    knn = KNeighborsClassifier(n_neighbors=1, **knn_kw)
-    knn = knn.fit(X=X_protos_ft, y=y_protos)
+    # compute how many correct labeled instances each prototype covers
+    protos_importance = compute_prototype_importances(summary=summary,
+                                                      trainset=trainset,
+                                                      preprocess_fn=preprocess_fn,
+                                                      knn_kw=knn_kw)
 
-    # get neighbors indices for each training instance
-    neigh_idx = knn.kneighbors(X=X_train_ft, n_neighbors=1, return_distance=False).reshape(-1)
+    # unpack values
+    counts = protos_importance['prototype_importances']
+    X_protos = protos_importance['X_protos']
+    X_protos_ft = protos_importance['X_protos_ft'] if (protos_importance['X_protos_ft'] is not None) else X_protos
 
-    # compute how many correct labeled instances each prototype covers
-    idx, counts = np.unique(neigh_idx[y_protos[neigh_idx] == y_train], return_counts=True)
-    zoom = np.log(counts)
+    # compute image zoom
+    zoom = np.log(counts)  # type: ignore[arg-type]
 
     # compute 2D embedding
-    protos_2d = reducer(X_protos_ft[idx])
+    protos_2d = reducer(X_protos_ft)  # type: ignore[arg-type]
     x, y = protos_2d[:, 0], protos_2d[:, 1]
 
     # plot images
-    return _imscatterplot(x=x, y=y, images=X_protos, ax=ax, fig_kw=fig_kw, image_size=image_size,
+    return _imscatterplot(x=x, y=y,
+                          images=X_protos,  # type: ignore[arg-type]
+                          ax=ax, fig_kw=fig_kw,
+                          image_size=image_size,
                           zoom=zoom, zoom_lb=zoom_lb, zoom_ub=zoom_ub)
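A minimal usage sketch of the new helper on toy tabular data. The toy dataset, the `eps=0.5` radius and `num_prototypes=10` are assumptions made for illustration; `ProtoSelect`, `EuclideanDistance`, `compute_prototype_importances` and its return keys come from this module:

    import numpy as np

    from alibi.prototypes.protoselect import ProtoSelect, compute_prototype_importances
    from alibi.utils.kernel import EuclideanDistance

    # toy training set standing in for real data
    rng = np.random.default_rng(0)
    X_train = rng.normal(size=(200, 16)).astype(np.float32)
    y_train = rng.integers(0, 3, size=200)

    # select up to 10 prototypes with a euclidean kernel; eps=0.5 is an arbitrary radius for this sketch
    summariser = ProtoSelect(kernel_distance=EuclideanDistance(), eps=0.5)
    summariser = summariser.fit(X=X_train, y=y_train)
    summary = summariser.summarise(num_prototypes=10)

    # importance of a prototype = number of training instances it classifies correctly as their 1-KNN neighbour
    importances = compute_prototype_importances(summary=summary,
                                                trainset=(X_train, y_train),
                                                knn_kw={'metric': 'euclidean'})
    for i, c in zip(importances['prototype_indices'], importances['prototype_importances']):
        print(f"prototype {i}: covers {c} correctly classified training instances")

After this change, `visualize_image_prototypes` delegates to the same helper and uses `np.log` of the returned counts as the zoom factor for each plotted prototype image, so the visualization and the standalone importance computation can no longer drift apart.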

alibi/prototypes/tests/test_protoselect.py (+3, -3)

@@ -32,8 +32,8 @@ def test_protoselect(n_classes, ft_factor, kernel_distance, num_prototypes, eps)
     # get prototypes
     summary = summariser.summarise(num_prototypes=num_prototypes)
     protos = summary.prototypes
-    protos_indices = summary.prototypes_indices
-    protos_labels = summary.prototypes_labels
+    protos_indices = summary.prototype_indices
+    protos_labels = summary.prototype_labels
 
     assert len(protos) == len(protos_indices) == len(protos_labels)
     assert len(protos) <= num_prototypes
@@ -130,7 +130,7 @@ def test_relabeling(n_samples, n_classes):
     assert np.array_equal(internal_labels, np.arange(len(provided_labels)))
 
     # check if the prototypes labels are labels with the provided labels
-    assert np.all(np.isin(np.unique(summary.data['prototypes_labels']), provided_labels))
+    assert np.all(np.isin(np.unique(summary.data['prototype_labels']), provided_labels))
 
 
 def test_size_match():