Skip to content

Commit bb9a2d7

Browse files
authored
Merge pull request #58 from eli5-org/modern-keras-try2
Support modern Keras 3.x
2 parents 2bf2f39 + 3cf2e42 commit bb9a2d7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+610
-464
lines changed

CHANGES.rst

Lines changed: 6 additions & 0 deletions

README.rst

Lines changed: 6 additions & 0 deletions

constraints-test.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

docs/source/_notebooks/keras-image-classifiers.rst

Lines changed: 179 additions & 114 deletions
Binary file not shown.
Binary file not shown.

docs/source/libraries/keras.rst

Lines changed: 1 addition & 3 deletions

docs/update-notebooks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ mv ../notebooks/keras-image-classifiers.rst \
7777
rm -r source/_notebooks/keras-image-classifiers_files
7878
mv ../notebooks/keras-image-classifiers_files/ \
7979
source/_notebooks/
80-
sed -i 's&.. image:: keras-image-classifiers_files/&.. image:: ../_notebooks/keras-image-classifiers_files/&g' \
80+
sed -i '' 's/image:: keras-image-classifiers_files/image:: ..\/_notebooks\/keras-image-classifiers_files/g' \
8181
source/_notebooks/keras-image-classifiers.rst
8282

8383

eli5/keras/explain_prediction.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
# -*- coding: utf-8 -*-
2-
from __future__ import absolute_import
3-
from typing import Union, Optional, Callable, Tuple, List, TYPE_CHECKING
1+
from typing import Union, Optional, Callable, TYPE_CHECKING
42
if TYPE_CHECKING:
53
import PIL
64

75
import numpy as np
8-
import keras
9-
import keras.backend as K
106
from keras.models import Model
117
from keras.layers import Layer
128
from keras.layers import (
@@ -30,13 +26,12 @@
3026

3127
# note that keras.models.Sequential subclasses keras.models.Model
3228
@explain_prediction.register(Model)
33-
def explain_prediction_keras(model, # type: Model
34-
doc, # type: np.ndarray
35-
targets=None, # type: Optional[list]
36-
layer=None, # type: Optional[Union[int, str, Layer]]
29+
def explain_prediction_keras(model: Model,
30+
doc: np.ndarray,
31+
targets: Optional[list] = None,
32+
layer: Optional[Union[int, str, Layer]] = None,
3733
image=None,
38-
):
39-
# type: (...) -> Explanation
34+
) -> Explanation:
4035
"""
4136
Explain the prediction of a Keras classifier with the Grad-CAM technique.
4237
@@ -133,7 +128,7 @@ def explain_prediction_keras_not_supported(model, doc):
133128

134129
def explain_prediction_keras_image(model,
135130
doc,
136-
image=None, # type: Optional['PIL.Image.Image']
131+
image: Optional['PIL.Image.Image'] = None,
137132
targets=None,
138133
layer=None,
139134
):
@@ -204,23 +199,20 @@ def explain_prediction_keras_image(model,
204199
)
205200

206201

207-
def _maybe_image(model, doc):
208-
# type: (Model, np.ndarray) -> bool
202+
def _maybe_image(model: Model, doc: np.ndarray) -> bool:
209203
"""Decide whether we are dealing with a image-based explanation
210204
based on heuristics on ``model`` and ``doc``."""
211205
return _maybe_image_input(doc) and _maybe_image_model(model)
212206

213207

214-
def _maybe_image_input(doc):
215-
# type: (np.ndarray) -> bool
208+
def _maybe_image_input(doc: np.ndarray) -> bool:
216209
"""Decide whether ``doc`` represents an image input."""
217210
rank = len(doc.shape)
218211
# image with channels or without (spatial only)
219212
return rank == 4 or rank == 3
220213

221214

222-
def _maybe_image_model(model):
223-
# type: (Model) -> bool
215+
def _maybe_image_model(model: Model) -> bool:
224216
"""Decide whether ``model`` is used for images."""
225217
# FIXME: replace try-except with something else
226218
try:
@@ -239,22 +231,19 @@ def _maybe_image_model(model):
239231
)
240232

241233

242-
def _is_possible_image_model_layer(model, layer):
243-
# type: (Model, Layer) -> bool
234+
def _is_possible_image_model_layer(model: Model, layer: Layer) -> bool:
244235
"""Check that the given ``layer`` is usually used for images."""
245236
return isinstance(layer, image_model_layers)
246237

247238

248-
def _extract_image(doc):
249-
# type: (np.ndarray) -> 'PIL.Image.Image'
239+
def _extract_image(doc: np.ndarray) -> 'PIL.Image.Image':
250240
"""Convert ``doc`` tensor to image."""
251241
im_arr, = doc # rank 4 batch -> rank 3 single image
252242
image = array_to_img(im_arr)
253243
return image
254244

255245

256-
def _validate_doc(model, doc):
257-
# type: (Model, np.ndarray) -> None
246+
def _validate_doc(model: Model, doc: np.ndarray) -> None:
258247
"""
259248
Check that the input ``doc`` is suitable for ``model``.
260249
"""
@@ -277,8 +266,7 @@ def _validate_doc(model, doc):
277266
'input: {}, doc: {}'.format(input_sh, doc_sh))
278267

279268

280-
def _get_activation_layer(model, layer):
281-
# type: (Model, Union[None, int, str, Layer]) -> Layer
269+
def _get_activation_layer(model: Model, layer: Union[None, int, str, Layer]) -> Layer:
282270
"""
283271
Get an instance of the desired activation layer in ``model``,
284272
as specified by ``layer``.
@@ -306,8 +294,7 @@ def _get_activation_layer(model, layer):
306294
raise ValueError('Can not perform Grad-CAM on the retrieved activation layer')
307295

308296

309-
def _search_layer_backwards(model, condition):
310-
# type: (Model, Callable[[Model, Layer], bool]) -> Layer
297+
def _search_layer_backwards(model: Model, condition: Callable[[Model, Layer], bool]) -> Layer:
311298
"""
312299
Search for a layer in ``model``, backwards (starting from the output layer),
313300
checking if the layer is suitable with the callable ``condition``,
@@ -321,8 +308,7 @@ def _search_layer_backwards(model, condition):
321308
raise ValueError('Could not find a suitable target layer automatically.')
322309

323310

324-
def _is_suitable_activation_layer(model, layer):
325-
# type: (Model, Layer) -> bool
311+
def _is_suitable_activation_layer(model: Model, layer: Layer) -> bool:
326312
"""
327313
Check whether the layer ``layer`` matches what is required
328314
by ``model`` to do Grad-CAM on ``layer``.
@@ -337,6 +323,8 @@ def _is_suitable_activation_layer(model, layer):
337323
# check layer name
338324

339325
# a check that asks "can we resize this activation layer over the image?"
340-
rank = len(layer.output_shape)
326+
# Use the tensor shape of the layer's output
327+
output_shape = layer.output.shape
328+
rank = len(output_shape)
341329
required_rank = len(model.input_shape)
342330
return rank == required_rank

0 commit comments

Comments
 (0)