Cannot properly save the model using TorchModuleWrapper and load the model #21350

Open
@lufeirider

Environment

pip install keras==3.7.0

Save the model test

https://colab.research.google.com/drive/1b-fkujd3At131GGX4w5oj5XRWxbZwYXB?usp=sharing

Load the model test

import keras
keras.models.load_model("/tmp/my_model/test.keras")

Error info

Traceback (most recent call last):
  File "/demo-python/test_load.py", line 3, in <module>
    keras.models.load_model("/demo-python/test_keras/my_model/test.keras")
  File "/demo-python/.venv/lib/python3.9/site-packages/keras/src/saving/saving_api.py", line 189, in load_model
    return saving_lib.load_model(
  File "/demo-python/.venv/lib/python3.9/site-packages/keras/src/saving/saving_lib.py", line 370, in load_model
    return _load_model_from_fileobj(
  File "/demo-python/.venv/lib/python3.9/site-packages/keras/src/saving/saving_lib.py", line 447, in _load_model_from_fileobj
    model = _model_from_config(
  File "/demo-python/.venv/lib/python3.9/site-packages/keras/src/saving/saving_lib.py", line 436, in _model_from_config
    model = deserialize_keras_object(
  File "/demo-python/.venv/lib/python3.9/site-packages/keras/src/saving/serialization_lib.py", line 720, in deserialize_keras_object
    raise TypeError(
TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'build_config': {'input_shape': None}, 'class_name': 'Functional', 'config': {}, 'module': 'keras.src.models.functional', 'registered_name': 'Functional'}.

Exception encountered: <class 'keras.src.utils.torch_utils.TorchModuleWrapper'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'class_name': 'TorchModuleWrapper', 'config': {'dtype': {'class_name': 'DTypePolicy', 'config': {'name': 'mixed_float16'}, 'module': 'keras', 'registered_name': None}, 'module': {'class_name': '__bytes__', 'config': {'value': 'xxx'}}, 'name': 'torch_module_wrapper', 'trainable': True}, 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'dtype': 'float16', 'keras_history': ['cast', 0, 0], 'shape': [None, 1, 28, 28]}}], 'kwargs': {}}], 'module': 'keras.layers', 'name': 'torch_module_wrapper', 'registered_name': None}.

Exception encountered: a bytes-like object is required, not 'dict'
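The last line of the error can be reproduced without Keras or PyTorch. This is a minimal sketch, assuming the failing call is `io.BytesIO(config["module"])` and that `config["module"]` arrives as the `__bytes__` dict shown in the config dump above. `module_entry` is a stand-in name, and its `'xxx'` value is the elided placeholder from the issue, not real data.

```python
import io

# Stand-in for config["module"] as printed in the error above
# (the real value is elided as 'xxx' in the issue).
module_entry = {"class_name": "__bytes__", "config": {"value": "xxx"}}

try:
    # io.BytesIO only accepts bytes-like objects, so a dict is rejected.
    io.BytesIO(module_entry)
    message = None
except TypeError as exc:
    message = str(exc)

print(message)
```

If this assumption holds, the wrapped `__bytes__` entry is reaching the layer without being converted back to `bytes` first.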

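Suggested fix

# Current:
buffer = io.BytesIO(config["module"])

# Suggested:
if isinstance(config["module"], str):
    # Strip the "\x" escapes and decode the remaining hex digits.
    buffer = io.BytesIO(bytes.fromhex(config["module"].replace("\\x", "")))
elif isinstance(config["module"], bytes):
    buffer = io.BytesIO(config["module"])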
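The suggestion above covers `str` and `bytes`, while the error shows a `dict`. The sketch below is one possible way to handle all three cases in a single place. `module_buffer` is a hypothetical helper, not part of Keras. How the inner `'value'` string encodes the bytes is not stated in the issue, so this assumes the escaped-hex form (`\xNN` for every byte) used in the suggestion.

```python
import io


def module_buffer(value):
    """Return an io.BytesIO for a serialized module value (hypothetical helper)."""
    # Unwrap the {'class_name': '__bytes__', 'config': {'value': ...}} form
    # seen in the error message.
    if isinstance(value, dict) and value.get("class_name") == "__bytes__":
        value = value["config"]["value"]
    if isinstance(value, (bytes, bytearray)):
        return io.BytesIO(value)
    if isinstance(value, str):
        # Assumption: every byte is written as \xNN; a string that mixes in
        # printable characters would make bytes.fromhex raise ValueError.
        return io.BytesIO(bytes.fromhex(value.replace("\\x", "")))
    raise TypeError(
        f"Unsupported type for config['module']: {type(value).__name__}"
    )


# Usage: all three input forms yield the same two bytes.
raw = module_buffer(b"\x80\x04").getvalue()
from_str = module_buffer("\\x80\\x04").getvalue()
from_dict = module_buffer(
    {"class_name": "__bytes__", "config": {"value": "\\x80\\x04"}}
).getvalue()
```

Raising a `TypeError` for anything else keeps an unexpected config shape from failing later with a less direct message.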
