1
- # -*- coding: utf-8 -*-
2
- from __future__ import absolute_import
3
- from typing import Union , Optional , Callable , Tuple , List , TYPE_CHECKING
1
+ from typing import Union , Optional , Callable , TYPE_CHECKING
4
2
if TYPE_CHECKING :
5
3
import PIL
6
4
7
5
import numpy as np
8
- import keras
9
- import keras .backend as K
10
6
from keras .models import Model
11
7
from keras .layers import Layer
12
8
from keras .layers import (
30
26
31
27
# note that keras.models.Sequential subclasses keras.models.Model
32
28
@explain_prediction .register (Model )
33
- def explain_prediction_keras (model , # type : Model
34
- doc , # type : np.ndarray
35
- targets = None , # type : Optional[list]
36
- layer = None , # type : Optional[Union[int, str, Layer]]
29
+ def explain_prediction_keras (model : Model ,
30
+ doc : np .ndarray ,
31
+ targets : Optional [list ] = None ,
32
+ layer : Optional [Union [int , str , Layer ]] = None ,
37
33
image = None ,
38
- ):
39
- # type: (...) -> Explanation
34
+ ) -> Explanation :
40
35
"""
41
36
Explain the prediction of a Keras classifier with the Grad-CAM technique.
42
37
@@ -133,7 +128,7 @@ def explain_prediction_keras_not_supported(model, doc):
133
128
134
129
def explain_prediction_keras_image (model ,
135
130
doc ,
136
- image = None , # type : Optional['PIL.Image.Image']
131
+ image : Optional ['PIL.Image.Image' ] = None ,
137
132
targets = None ,
138
133
layer = None ,
139
134
):
@@ -204,23 +199,20 @@ def explain_prediction_keras_image(model,
204
199
)
205
200
206
201
207
- def _maybe_image (model , doc ):
208
- # type: (Model, np.ndarray) -> bool
202
+ def _maybe_image (model : Model , doc : np .ndarray ) -> bool :
209
203
"""Decide whether we are dealing with a image-based explanation
210
204
based on heuristics on ``model`` and ``doc``."""
211
205
return _maybe_image_input (doc ) and _maybe_image_model (model )
212
206
213
207
214
- def _maybe_image_input (doc ):
215
- # type: (np.ndarray) -> bool
208
+ def _maybe_image_input (doc : np .ndarray ) -> bool :
216
209
"""Decide whether ``doc`` represents an image input."""
217
210
rank = len (doc .shape )
218
211
# image with channels or without (spatial only)
219
212
return rank == 4 or rank == 3
220
213
221
214
222
- def _maybe_image_model (model ):
223
- # type: (Model) -> bool
215
+ def _maybe_image_model (model : Model ) -> bool :
224
216
"""Decide whether ``model`` is used for images."""
225
217
# FIXME: replace try-except with something else
226
218
try :
@@ -239,22 +231,19 @@ def _maybe_image_model(model):
239
231
)
240
232
241
233
242
- def _is_possible_image_model_layer (model , layer ):
243
- # type: (Model, Layer) -> bool
234
+ def _is_possible_image_model_layer (model : Model , layer : Layer ) -> bool :
244
235
"""Check that the given ``layer`` is usually used for images."""
245
236
return isinstance (layer , image_model_layers )
246
237
247
238
248
- def _extract_image (doc ):
249
- # type: (np.ndarray) -> 'PIL.Image.Image'
239
+ def _extract_image (doc : np .ndarray ) -> 'PIL.Image.Image' :
250
240
"""Convert ``doc`` tensor to image."""
251
241
im_arr , = doc # rank 4 batch -> rank 3 single image
252
242
image = array_to_img (im_arr )
253
243
return image
254
244
255
245
256
- def _validate_doc (model , doc ):
257
- # type: (Model, np.ndarray) -> None
246
+ def _validate_doc (model : Model , doc : np .ndarray ) -> None :
258
247
"""
259
248
Check that the input ``doc`` is suitable for ``model``.
260
249
"""
@@ -277,8 +266,7 @@ def _validate_doc(model, doc):
277
266
'input: {}, doc: {}' .format (input_sh , doc_sh ))
278
267
279
268
280
- def _get_activation_layer (model , layer ):
281
- # type: (Model, Union[None, int, str, Layer]) -> Layer
269
+ def _get_activation_layer (model : Model , layer : Union [None , int , str , Layer ]) -> Layer :
282
270
"""
283
271
Get an instance of the desired activation layer in ``model``,
284
272
as specified by ``layer``.
@@ -306,8 +294,7 @@ def _get_activation_layer(model, layer):
306
294
raise ValueError ('Can not perform Grad-CAM on the retrieved activation layer' )
307
295
308
296
309
- def _search_layer_backwards (model , condition ):
310
- # type: (Model, Callable[[Model, Layer], bool]) -> Layer
297
+ def _search_layer_backwards (model : Model , condition : Callable [[Model , Layer ], bool ]) -> Layer :
311
298
"""
312
299
Search for a layer in ``model``, backwards (starting from the output layer),
313
300
checking if the layer is suitable with the callable ``condition``,
@@ -321,8 +308,7 @@ def _search_layer_backwards(model, condition):
321
308
raise ValueError ('Could not find a suitable target layer automatically.' )
322
309
323
310
324
- def _is_suitable_activation_layer (model , layer ):
325
- # type: (Model, Layer) -> bool
311
+ def _is_suitable_activation_layer (model : Model , layer : Layer ) -> bool :
326
312
"""
327
313
Check whether the layer ``layer`` matches what is required
328
314
by ``model`` to do Grad-CAM on ``layer``.
@@ -337,6 +323,8 @@ def _is_suitable_activation_layer(model, layer):
337
323
# check layer name
338
324
339
325
# a check that asks "can we resize this activation layer over the image?"
340
- rank = len (layer .output_shape )
326
+ # Use the tensor shape of the layer's output
327
+ output_shape = layer .output .shape
328
+ rank = len (output_shape )
341
329
required_rank = len (model .input_shape )
342
330
return rank == required_rank
0 commit comments