|
43 | 43 | """
|
44 | 44 | from tensorboard.compat.proto.graph_pb2 import GraphDef
|
45 | 45 | from tensorboard.compat.tensorflow_stub import dtypes
|
| 46 | +from tensorboard.util import tb_logging |
| 47 | + |
| 48 | + |
| 49 | +logger = tb_logging.get_logger() |
46 | 50 |
|
47 | 51 |
|
48 | 52 | def _walk_layers(keras_layer):
|
@@ -259,19 +263,34 @@ def keras_model_to_graph_def(keras_layer):
|
259 | 263 |
|
260 | 264 | dtype_or_policy = layer_config.get("dtype")
|
261 | 265 | dtype = None
|
| 266 | + has_unsupported_value = False |
262 | 267 | # If this is a dict, try and extract the dtype string from
|
263 |
| - # `config.name`; keras will export like this for non-input layers. If |
264 |
| - # we can't find `config.name`, we skip it as it's presumably a instance |
265 |
| - # of tf/keras/mixed_precision/Policy rather than a single dtype. |
266 |
| - # TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype. |
267 |
| - if isinstance(dtype_or_policy, dict): |
268 |
| - if "config" in dtype_or_policy: |
269 |
| - dtype = dtype_or_policy.get("config").get("name") |
| 268 | + # `config.name`. Keras will export like this for non-input layers and |
| 269 | + # some other cases (e.g. tf/keras/mixed_precision/Policy, as described |
| 270 | + # in issue #5548). |
| 271 | + if isinstance(dtype_or_policy, dict) and "config" in dtype_or_policy: |
| 272 | + dtype = dtype_or_policy.get("config").get("name") |
270 | 273 | elif dtype_or_policy is not None:
|
271 | 274 | dtype = dtype_or_policy
|
| 275 | + |
272 | 276 | if dtype is not None:
|
273 |
| - tf_dtype = dtypes.as_dtype(dtype) |
274 |
| - node_def.attr["dtype"].type = tf_dtype.as_datatype_enum |
| 277 | + try: |
| 278 | + tf_dtype = dtypes.as_dtype(dtype) |
| 279 | + node_def.attr["dtype"].type = tf_dtype.as_datatype_enum |
| 280 | + except TypeError: |
| 281 | + has_unsupported_value = True |
| 282 | + elif dtype_or_policy is not None: |
| 283 | + has_unsupported_value = True |
| 284 | + |
| 285 | + if has_unsupported_value: |
| 286 | + # There's at least one known case when this happens, which is when |
| 287 | + # mixed precision dtype policies are used, as described in issue |
| 288 | + # #5548. (See https://keras.io/api/mixed_precision/). |
| 289 | + # There might be a better way to handle this, but here we are. |
| 290 | + logger.warning( |
| 291 | + "Unsupported dtype value in graph model config (json):\n%s", |
| 292 | + dtype_or_policy, |
| 293 | + ) |
275 | 294 | if layer.get("inbound_nodes") is not None:
|
276 | 295 | for name, size, index in _get_inbound_nodes(layer):
|
277 | 296 | inbound_name = _scoped_name(name_scope, name)
|
|
0 commit comments