Skip to content

Commit bb43635

Browse files
Included missing return type.
1 parent 907739a commit bb43635

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

alibi/prototypes/protoselect.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def _imscatterplot(x: np.ndarray,
550550
def compute_prototype_importances(summary: 'Explanation',
551551
trainset: Tuple[np.ndarray, np.ndarray],
552552
preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
553-
knn_kw: Optional[dict] = None):
553+
knn_kw: Optional[dict] = None) -> Dict[str, Optional[np.ndarray]]:
554554

555555
"""
556556
Computes the importance of each prototype. The importance of a prototype is the number of assigned
@@ -677,12 +677,15 @@ def visualize_image_prototypes(summary: 'Explanation',
677677
X_protos_ft = protos_importance['X_protos_ft'] if (protos_importance['X_protos_ft'] is not None) else X_protos
678678

679679
# compute image zoom
680-
zoom = np.log(counts)
680+
zoom = np.log(counts) # type: ignore[arg-type]
681681

682682
# compute 2D embedding
683-
protos_2d = reducer(X_protos_ft)
683+
protos_2d = reducer(X_protos_ft) # type: ignore[arg-type]
684684
x, y = protos_2d[:, 0], protos_2d[:, 1]
685685

686686
# plot images
687-
return _imscatterplot(x=x, y=y, images=X_protos, ax=ax, fig_kw=fig_kw, image_size=image_size,
687+
return _imscatterplot(x=x, y=y,
688+
images=X_protos, # type: ignore[arg-type]
689+
ax=ax, fig_kw=fig_kw,
690+
image_size=image_size,
688691
zoom=zoom, zoom_lb=zoom_lb, zoom_ub=zoom_ub)

0 commit comments

Comments
 (0)