1
1
import logging
2
- import numpy as np
3
- import matplotlib .pyplot as plt
4
- from matplotlib .offsetbox import OffsetImage , AnnotationBbox
5
-
6
- from tqdm import tqdm
7
2
from copy import deepcopy
8
- from typing import Callable , Optional , Dict , List , Union , Tuple
3
+ from typing import Callable , Dict , List , Optional , Tuple , Union
4
+
5
+ import matplotlib .pyplot as plt
6
+ import numpy as np
7
+ from matplotlib .offsetbox import AnnotationBbox , OffsetImage
8
+ from skimage .transform import resize
9
9
from sklearn .model_selection import KFold
10
10
from sklearn .neighbors import KNeighborsClassifier
11
- from skimage . transform import resize
11
+ from tqdm import tqdm
12
12
13
+ from alibi .api .defaults import (DEFAULT_DATA_PROTOSELECT ,
14
+ DEFAULT_META_PROTOSELECT )
15
+ from alibi .api .interfaces import Explanation , FitMixin , Summariser
13
16
from alibi .utils .distance import batch_compute_kernel_matrix
14
- from alibi .api .interfaces import Summariser , Explanation , FitMixin
15
- from alibi .api .defaults import DEFAULT_META_PROTOSELECT , DEFAULT_DATA_PROTOSELECT
16
17
from alibi .utils .kernel import EuclideanDistance
17
18
18
19
logger = logging .getLogger (__name__ )
@@ -226,10 +227,10 @@ def _build_summary(self, protos: Dict[int, List[int]]) -> Explanation:
226
227
Helper method to build the summary as an `Explanation` object.
227
228
"""
228
229
data = deepcopy (DEFAULT_DATA_PROTOSELECT )
229
- data ['prototypes_indices ' ] = np .concatenate (list (protos .values ())).astype (np .int32 )
230
- data ['prototypes_labels ' ] = np .concatenate ([[self .label_inv_mapping [l ]] * len (protos [l ])
231
- for l in protos ]).astype (np .int32 ) # noqa: E741
232
- data ['prototypes' ] = self .Z [data ['prototypes_indices ' ]]
230
+ data ['prototype_indices ' ] = np .concatenate (list (protos .values ())).astype (np .int32 )
231
+ data ['prototype_labels ' ] = np .concatenate ([[self .label_inv_mapping [l ]] * len (protos [l ])
232
+ for l in protos ]).astype (np .int32 ) # noqa: E741
233
+ data ['prototypes' ] = self .Z [data ['prototype_indices ' ]]
233
234
return Explanation (meta = self .meta , data = data )
234
235
235
236
@@ -262,7 +263,7 @@ def _helper_protoselect_euclidean_1knn(summariser: ProtoSelect,
262
263
summary = summariser .summarise (num_prototypes = num_prototypes )
263
264
264
265
# train 1-knn classifier
265
- X_protos , y_protos = summary .data ['prototypes' ], summary .data ['prototypes_labels ' ]
266
+ X_protos , y_protos = summary .data ['prototypes' ], summary .data ['prototype_labels ' ]
266
267
if len (X_protos ) == 0 :
267
268
return None
268
269
@@ -546,6 +547,79 @@ def _imscatterplot(x: np.ndarray,
546
547
return ax
547
548
548
549
550
def compute_prototype_importances(summary: 'Explanation',
                                  trainset: Tuple[np.ndarray, np.ndarray],
                                  preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
                                  knn_kw: Optional[dict] = None) -> Dict[str, Optional[np.ndarray]]:
    """
    Computes the importance of each prototype. The importance of a prototype is the number of assigned
    training instances correctly classified according to the 1-KNN classifier
    (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).

    Parameters
    ----------
    summary
        An `Explanation` object produced by a call to the
        :py:meth:`alibi.prototypes.protoselect.ProtoSelect.summarise` method.
    trainset
        Tuple, `(X_train, y_train)`, consisting of the training data instances with the corresponding labels.
    preprocess_fn
        Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
    knn_kw
        Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
        set automatically to 1 (any user supplied value is overridden), but the `metric` has to be specified
        according to the kernel distance used. If the `metric` is not specified, it will be set by default
        to ``'euclidean'``. The dictionary passed by the caller is not modified.
        See parameters description:
        https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html

    Returns
    -------
    A dictionary containing:

     - ``'prototype_indices'`` - an array of the prototype indices. These are positions into \
     ``summary.data['prototypes']``; only prototypes that are the nearest neighbor of at least one \
     correctly classified training instance are included.

     - ``'prototype_importances'`` - an array of prototype importances, aligned with ``'prototype_indices'``.

     - ``'X_protos'`` - an array of raw prototypes, aligned with ``'prototype_indices'``.

     - ``'X_protos_ft'`` - an optional array of preprocessed prototypes. If the ``preprocess_fn=None``, \
     no preprocessing is applied and ``None`` is returned instead.
    """
    # Work on a private copy so that the defaults filled in below never leak into the caller's dictionary.
    knn_kw = {} if knn_kw is None else dict(knn_kw)

    if knn_kw.get('metric') is None:
        knn_kw['metric'] = 'euclidean'
        logger.warning("KNN metric was not specified. Automatically setting `metric='euclidean'`.")

    # The importance definition relies on a 1-NN classifier. Forcing the value here (instead of passing
    # `n_neighbors=1` next to `**knn_kw`) avoids a duplicate keyword `TypeError` when the caller supplied one.
    knn_kw['n_neighbors'] = 1

    X_train, y_train = trainset
    X_protos = summary.data['prototypes']
    y_protos = summary.data['prototype_labels']

    # preprocess the dataset
    if preprocess_fn is not None:
        X_train_ft = _batch_preprocessing(X=X_train, preprocess_fn=preprocess_fn)
        X_protos_ft = _batch_preprocessing(X=X_protos, preprocess_fn=preprocess_fn)
    else:
        X_train_ft, X_protos_ft = X_train, X_protos

    # train knn classifier on the prototypes only
    knn = KNeighborsClassifier(**knn_kw)
    knn = knn.fit(X=X_protos_ft, y=y_protos)

    # get neighbors indices for each training instance (index of the closest prototype)
    neigh_idx = knn.kneighbors(X=X_train_ft, n_neighbors=1, return_distance=False).reshape(-1)

    # compute how many correct labeled instances each prototype covers; prototypes covering none are dropped
    correctly_classified = y_protos[neigh_idx] == y_train
    idx, counts = np.unique(neigh_idx[correctly_classified], return_counts=True)

    return {
        'prototype_indices': idx,
        'prototype_importances': counts,
        'X_protos': X_protos[idx],
        'X_protos_ft': None if (preprocess_fn is None) else X_protos_ft[idx]
    }
621
+
622
+
549
623
def visualize_image_prototypes (summary : 'Explanation' ,
550
624
trainset : Tuple [np .ndarray , np .ndarray ],
551
625
reducer : Callable [[np .ndarray ], np .ndarray ],
@@ -560,7 +634,6 @@ def visualize_image_prototypes(summary: 'Explanation',
560
634
Plot the images of the prototypes at the location given by the `reducer` representation.
561
635
The size of each prototype is proportional to the logarithm of the number of assigned training instances correctly
562
636
classified according to the 1-KNN classifier (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
563
-
564
637
Parameters
565
638
----------
566
639
summary
@@ -573,7 +646,7 @@ def visualize_image_prototypes(summary: 'Explanation',
573
646
input instances if ``preprocess_fn=None``. If the `preprocess_fn` is specified, the reducer will be called
574
647
on the feature representation obtained after passing the input instances through the `preprocess_fn`.
575
648
preprocess_fn
576
- Preprocessor function.
649
+ Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
577
650
knn_kw
578
651
Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
579
652
set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
@@ -592,37 +665,27 @@ def visualize_image_prototypes(summary: 'Explanation',
592
665
zoom_ub
593
666
Zoom upper bound. The zoom will be scaled linearly between `[zoom_lb, zoom_ub]`.
594
667
"""
595
- if knn_kw is None :
596
- knn_kw = {}
597
- if knn_kw .get ('metric' ) is None :
598
- knn_kw .update ({'metric' : 'euclidean' })
599
- logger .warning ("KNN metric was not specified. Automatically setting `metric='euclidean'`." )
600
-
601
- X_train , y_train = trainset
602
- X_protos = summary .data ['prototypes' ]
603
- y_protos = summary .data ['prototypes_labels' ]
604
-
605
- # preprocess the dataset
606
- X_train_ft = _batch_preprocessing (X = X_train , preprocess_fn = preprocess_fn ) \
607
- if (preprocess_fn is not None ) else X_train
608
- X_protos_ft = _batch_preprocessing (X = X_protos , preprocess_fn = preprocess_fn ) \
609
- if (preprocess_fn is not None ) else X_protos
610
-
611
- # train knn classifier
612
- knn = KNeighborsClassifier (n_neighbors = 1 , ** knn_kw )
613
- knn = knn .fit (X = X_protos_ft , y = y_protos )
668
+ # compute how many correct labeled instances each prototype covers
669
+ protos_importance = compute_prototype_importances (summary = summary ,
670
+ trainset = trainset ,
671
+ preprocess_fn = preprocess_fn ,
672
+ knn_kw = knn_kw )
614
673
615
- # get neighbors indices for each training instance
616
- neigh_idx = knn .kneighbors (X = X_train_ft , n_neighbors = 1 , return_distance = False ).reshape (- 1 )
674
+ # unpack values
675
+ counts = protos_importance ['prototype_importances' ]
676
+ X_protos = protos_importance ['X_protos' ]
677
+ X_protos_ft = protos_importance ['X_protos_ft' ] if (protos_importance ['X_protos_ft' ] is not None ) else X_protos
617
678
618
- # compute how many correct labeled instances each prototype covers
619
- idx , counts = np .unique (neigh_idx [y_protos [neigh_idx ] == y_train ], return_counts = True )
620
- zoom = np .log (counts )
679
+ # compute image zoom
680
+ zoom = np .log (counts ) # type: ignore[arg-type]
621
681
622
682
# compute 2D embedding
623
- protos_2d = reducer (X_protos_ft [ idx ])
683
+ protos_2d = reducer (X_protos_ft ) # type: ignore[arg-type]
624
684
x , y = protos_2d [:, 0 ], protos_2d [:, 1 ]
625
685
626
686
# plot images
627
- return _imscatterplot (x = x , y = y , images = X_protos , ax = ax , fig_kw = fig_kw , image_size = image_size ,
687
+ return _imscatterplot (x = x , y = y ,
688
+ images = X_protos , # type: ignore[arg-type]
689
+ ax = ax , fig_kw = fig_kw ,
690
+ image_size = image_size ,
628
691
zoom = zoom , zoom_lb = zoom_lb , zoom_ub = zoom_ub )
0 commit comments