@@ -550,7 +550,7 @@ def _imscatterplot(x: np.ndarray,
550
550
def compute_prototype_importances (summary : 'Explanation' ,
551
551
trainset : Tuple [np .ndarray , np .ndarray ],
552
552
preprocess_fn : Optional [Callable [[np .ndarray ], np .ndarray ]] = None ,
553
- knn_kw : Optional [dict ] = None ):
553
+ knn_kw : Optional [dict ] = None ) -> Dict [ str , Optional [ np . ndarray ]] :
554
554
555
555
"""
556
556
Computes the importance of each prototype. The importance of a prototype is the number of assigned
@@ -677,12 +677,15 @@ def visualize_image_prototypes(summary: 'Explanation',
677
677
X_protos_ft = protos_importance ['X_protos_ft' ] if (protos_importance ['X_protos_ft' ] is not None ) else X_protos
678
678
679
679
# compute image zoom
680
- zoom = np .log (counts )
680
+ zoom = np .log (counts ) # type: ignore[arg-type]
681
681
682
682
# compute 2D embedding
683
- protos_2d = reducer (X_protos_ft )
683
+ protos_2d = reducer (X_protos_ft ) # type: ignore[arg-type]
684
684
x , y = protos_2d [:, 0 ], protos_2d [:, 1 ]
685
685
686
686
# plot images
687
- return _imscatterplot (x = x , y = y , images = X_protos , ax = ax , fig_kw = fig_kw , image_size = image_size ,
687
+ return _imscatterplot (x = x , y = y ,
688
+ images = X_protos , # type: ignore[arg-type]
689
+ ax = ax , fig_kw = fig_kw ,
690
+ image_size = image_size ,
688
691
zoom = zoom , zoom_lb = zoom_lb , zoom_ub = zoom_ub )
0 commit comments