Skip to content

Commit cbeecb7

Browse files
authored
Fix keras dtype importing and unpin for CI (tensorflow#6857)
Keras' output format was slightly changed in keras-team/keras#19711; for non-input layers dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras. Alternative to tensorflow#6855
1 parent 5f8b019 commit cbeecb7

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
- name: 'Install TensorFlow'
7777
run: |
7878
python -m pip install -U pip
79-
pip install "${TENSORFLOW_VERSION}" keras-nightly==3.3.3.dev2024051503
79+
pip install "${TENSORFLOW_VERSION}"
8080
if: matrix.tf_version_id != 'notf'
8181
- name: 'Install Python dependencies'
8282
run: |

tensorboard/plugins/graph/keras_util.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,19 @@ def keras_model_to_graph_def(keras_layer):
258258
node_def.attr["keras_class"].s = keras_cls_name
259259

260260
dtype_or_policy = layer_config.get("dtype")
261-
# Skip dtype processing if this is a dict, since it's presumably a instance of
262-
# tf/keras/mixed_precision/Policy rather than a single dtype.
261+
dtype = None
262+
# If this is a dict, try and extract the dtype string from
263+
# `config.name`; keras will export like this for non-input layers. If
264+
# we can't find `config.name`, we skip it as it's presumably a instance
265+
# of tf/keras/mixed_precision/Policy rather than a single dtype.
263266
# TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
264-
if dtype_or_policy is not None and not isinstance(
265-
dtype_or_policy, dict
266-
):
267-
tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
267+
if isinstance(dtype_or_policy, dict):
268+
if "config" in dtype_or_policy:
269+
dtype = dtype_or_policy.get("config").get("name")
270+
elif dtype_or_policy is not None:
271+
dtype = dtype_or_policy
272+
if dtype is not None:
273+
tf_dtype = dtypes.as_dtype(dtype)
268274
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
269275
if layer.get("inbound_nodes") is not None:
270276
for name, size, index in _get_inbound_nodes(layer):

0 commit comments

Comments
 (0)