Skip to content

Commit ae7d0b9

Browse files
authored
Fixes regression for issue tensorflow#5548: Avoid attempting to convert dtypes from "mixed precision" policy types. (tensorflow#6859)
Following up on PR tensorflow#6857, which seems to have introduced a regression for issue tensorflow#5548. This change gracefully degrades the functionality so that it no longer crashes on an error (as was the case before the recent change in tensorflow#6857), but it might not be a proper fix for the scenario described in that issue.
1 parent cbeecb7 commit ae7d0b9

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

tensorboard/plugins/graph/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ py_library(
114114
deps = [
115115
"//tensorboard/compat/proto:protos_all_py_pb2",
116116
"//tensorboard/compat/tensorflow_stub",
117+
"//tensorboard/util:tb_logging",
117118
],
118119
)
119120

tensorboard/plugins/graph/keras_util.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
"""
4444
from tensorboard.compat.proto.graph_pb2 import GraphDef
4545
from tensorboard.compat.tensorflow_stub import dtypes
46+
from tensorboard.util import tb_logging
47+
48+
49+
logger = tb_logging.get_logger()
4650

4751

4852
def _walk_layers(keras_layer):
@@ -259,19 +263,34 @@ def keras_model_to_graph_def(keras_layer):
259263

260264
dtype_or_policy = layer_config.get("dtype")
261265
dtype = None
266+
has_unsupported_value = False
262267
# If this is a dict, try and extract the dtype string from
263-
# `config.name`; keras will export like this for non-input layers. If
264-
# we can't find `config.name`, we skip it as it's presumably a instance
265-
# of tf/keras/mixed_precision/Policy rather than a single dtype.
266-
# TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
267-
if isinstance(dtype_or_policy, dict):
268-
if "config" in dtype_or_policy:
269-
dtype = dtype_or_policy.get("config").get("name")
268+
# `config.name`. Keras will export like this for non-input layers and
269+
# some other cases (e.g. tf/keras/mixed_precision/Policy, as described
270+
# in issue #5548).
271+
if isinstance(dtype_or_policy, dict) and "config" in dtype_or_policy:
272+
dtype = dtype_or_policy.get("config").get("name")
270273
elif dtype_or_policy is not None:
271274
dtype = dtype_or_policy
275+
272276
if dtype is not None:
273-
tf_dtype = dtypes.as_dtype(dtype)
274-
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
277+
try:
278+
tf_dtype = dtypes.as_dtype(dtype)
279+
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
280+
except TypeError:
281+
has_unsupported_value = True
282+
elif dtype_or_policy is not None:
283+
has_unsupported_value = True
284+
285+
if has_unsupported_value:
286+
# There's at least one known case when this happens, which is when
287+
# mixed precision dtype policies are used, as described in issue
288+
# #5548. (See https://keras.io/api/mixed_precision/).
289+
# There might be a better way to handle this, but here we are.
290+
logger.warning(
291+
"Unsupported dtype value in graph model config (json):\n%s",
292+
dtype_or_policy,
293+
)
275294
if layer.get("inbound_nodes") is not None:
276295
for name, size, index in _get_inbound_nodes(layer):
277296
inbound_name = _scoped_name(name_scope, name)

tensorboard/plugins/graph/keras_util_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,20 @@ def test_keras_model_to_graph_def_functional_multiple_inbound_nodes_from_same_no
10431043

10441044
self.assertGraphDefToModel(expected_proto, model)
10451045

1046+
def test__keras_model_to_graph_def__does_not_crash_with_mixed_precision_dtype_policy(
1047+
self,
1048+
):
1049+
# See https://keras.io/api/mixed_precision/ for more info.
1050+
# Test to avoid regression on issue #5548
1051+
first_layer = tf.keras.layers.Dense(
1052+
1, input_shape=(1,), dtype="mixed_float16"
1053+
)
1054+
model = tf.keras.Sequential([first_layer])
1055+
1056+
model_config = json.loads(model.to_json())
1057+
# This line should not raise errors:
1058+
keras_util.keras_model_to_graph_def(model_config)
1059+
10461060

10471061
class _DoublingLayer(tf.keras.layers.Layer):
10481062
def call(self, inputs):

0 commit comments

Comments
 (0)