Skip to content

Commit 20b1866

Browse files
jaycee-licopybara-github
authored andcommitted
fix: AttributeError for TorchModelSerializer.deserialize in torch >=2.3.0
PiperOrigin-RevId: 631215839
1 parent 195c77e commit 20b1866

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

vertexai/preview/_workflow/serialization_engine/serializers.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def _is_valid_gcs_path(path: str) -> bool:
144144

145145
def _load_torch_model(path: str, map_location: "torch.device") -> "torch.nn.Module":
146146
import torch
147+
147148
try:
148149
return torch.load(path, map_location=map_location)
149150
except Exception:
@@ -434,7 +435,9 @@ class TorchModelSerializer(serializers_base.Serializer):
434435
serializers_base.SerializationMetadata(serializer="TorchModelSerializer")
435436
)
436437

437-
def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) -> str:
438+
def serialize(
439+
self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs
440+
) -> str:
438441
"""Serializes a torch.nn.Module to a gcs path.
439442
440443
Args:
@@ -450,6 +453,7 @@ def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) ->
450453
ValueError: if `gcs_path` is not a valid GCS uri.
451454
"""
452455
import torch
456+
453457
del kwargs
454458
if not _is_valid_gcs_path(gcs_path):
455459
raise ValueError(f"Invalid gcs path: {gcs_path}")
@@ -500,11 +504,18 @@ def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.nn.Module":
500504
except ImportError as e:
501505
raise ImportError("torch is not installed.") from e
502506

503-
map_location = (
504-
torch._GLOBAL_DEVICE_CONTEXT.device
505-
if torch._GLOBAL_DEVICE_CONTEXT
506-
else None
507-
)
507+
# Get the default device in the local torch environment.
508+
# If `set_default_device` hasn't been called, _GLOBAL_DEVICE_CONTEXT
509+
# should be None, then we set map_location to None as well.
510+
map_location = None
511+
# In torch 2.3.0, get_default_device is introduced
512+
if hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device_context") and hasattr(
513+
torch, "get_default_device"
514+
):
515+
map_location = torch.get_default_device()
516+
# For older versions, we get default device from _GLOBAL_DEVICE_CONTEXT
517+
elif hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device"):
518+
map_location = torch._GLOBAL_DEVICE_CONTEXT.device
508519

509520
if serialized_gcs_path.startswith("gs://"):
510521
with tempfile.NamedTemporaryFile() as temp_file:
@@ -731,7 +742,9 @@ class TorchDataLoaderSerializer(serializers_base.Serializer):
731742
serializers_base.SerializationMetadata(serializer="TorchDataLoaderSerializer")
732743
)
733744

734-
def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: str):
745+
def _serialize_to_local(
746+
self, to_serialize: "torch.utils.data.DataLoader", path: str
747+
):
735748
"""Serializes a torch.utils.data.DataLoader to a local path.
736749
737750
Args:
@@ -778,6 +791,7 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path:
778791
# for default batch sampler we store batch_size, drop_last, and sampler object
779792
# but not batch sampler object.
780793
import torch
794+
781795
if isinstance(to_serialize.batch_sampler, torch.utils.data.BatchSampler):
782796
pass_through_args["batch_size"] = to_serialize.batch_size
783797
pass_through_args["drop_last"] = to_serialize.drop_last
@@ -797,7 +811,9 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path:
797811
with open(f"{path}/pass_through_args.json", "w") as f:
798812
json.dump(pass_through_args, f)
799813

800-
def serialize(self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs) -> str:
814+
def serialize(
815+
self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs
816+
) -> str:
801817
"""Serializes a torch.utils.data.DataLoader to a gcs path.
802818
803819
Args:
@@ -883,7 +899,9 @@ def _deserialize_from_local(self, path: str) -> "torch.utils.data.DataLoader":
883899

884900
return torch.utils.data.DataLoader(**kwargs)
885901

886-
def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.utils.data.DataLoader":
902+
def deserialize(
903+
self, serialized_gcs_path: str, **kwargs
904+
) -> "torch.utils.data.DataLoader":
887905
"""Deserialize a torch.utils.data.DataLoader given the gcs path.
888906
889907
Args:

0 commit comments

Comments
 (0)