@@ -144,6 +144,7 @@ def _is_valid_gcs_path(path: str) -> bool:
144
144
145
145
def _load_torch_model (path : str , map_location : "torch.device" ) -> "torch.nn.Module" :
146
146
import torch
147
+
147
148
try :
148
149
return torch .load (path , map_location = map_location )
149
150
except Exception :
@@ -434,7 +435,9 @@ class TorchModelSerializer(serializers_base.Serializer):
434
435
serializers_base .SerializationMetadata (serializer = "TorchModelSerializer" )
435
436
)
436
437
437
- def serialize (self , to_serialize : "torch.nn.Module" , gcs_path : str , ** kwargs ) -> str :
438
+ def serialize (
439
+ self , to_serialize : "torch.nn.Module" , gcs_path : str , ** kwargs
440
+ ) -> str :
438
441
"""Serializes a torch.nn.Module to a gcs path.
439
442
440
443
Args:
@@ -450,6 +453,7 @@ def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) ->
450
453
ValueError: if `gcs_path` is not a valid GCS uri.
451
454
"""
452
455
import torch
456
+
453
457
del kwargs
454
458
if not _is_valid_gcs_path (gcs_path ):
455
459
raise ValueError (f"Invalid gcs path: { gcs_path } " )
@@ -500,11 +504,18 @@ def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.nn.Module":
500
504
except ImportError as e :
501
505
raise ImportError ("torch is not installed." ) from e
502
506
503
- map_location = (
504
- torch ._GLOBAL_DEVICE_CONTEXT .device
505
- if torch ._GLOBAL_DEVICE_CONTEXT
506
- else None
507
- )
507
+ # Get the default device in the local torch environment.
508
+ # If `set_default_device` hasn't been called, _GLOBAL_DEVICE_CONTEXT
509
+ # should be None, then we set map_location to None as well.
510
+ map_location = None
511
+ # In torch 2.3.0, get_default_device is introduced
512
+ if hasattr (torch ._GLOBAL_DEVICE_CONTEXT , "device_context" ) and hasattr (
513
+ torch , "get_default_device"
514
+ ):
515
+ map_location = torch .get_default_device ()
516
+ # For older versions, we get default device from _GLOBAL_DEVICE_CONTEXT
517
+ elif hasattr (torch ._GLOBAL_DEVICE_CONTEXT , "device" ):
518
+ map_location = torch ._GLOBAL_DEVICE_CONTEXT .device
508
519
509
520
if serialized_gcs_path .startswith ("gs://" ):
510
521
with tempfile .NamedTemporaryFile () as temp_file :
@@ -731,7 +742,9 @@ class TorchDataLoaderSerializer(serializers_base.Serializer):
731
742
serializers_base .SerializationMetadata (serializer = "TorchDataLoaderSerializer" )
732
743
)
733
744
734
- def _serialize_to_local (self , to_serialize : "torch.utils.data.DataLoader" , path : str ):
745
+ def _serialize_to_local (
746
+ self , to_serialize : "torch.utils.data.DataLoader" , path : str
747
+ ):
735
748
"""Serializes a torch.utils.data.DataLoader to a local path.
736
749
737
750
Args:
@@ -778,6 +791,7 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path:
778
791
# for default batch sampler we store batch_size, drop_last, and sampler object
779
792
# but not batch sampler object.
780
793
import torch
794
+
781
795
if isinstance (to_serialize .batch_sampler , torch .utils .data .BatchSampler ):
782
796
pass_through_args ["batch_size" ] = to_serialize .batch_size
783
797
pass_through_args ["drop_last" ] = to_serialize .drop_last
@@ -797,7 +811,9 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path:
797
811
with open (f"{ path } /pass_through_args.json" , "w" ) as f :
798
812
json .dump (pass_through_args , f )
799
813
800
- def serialize (self , to_serialize : "torch.utils.data.DataLoader" , gcs_path : str , ** kwargs ) -> str :
814
+ def serialize (
815
+ self , to_serialize : "torch.utils.data.DataLoader" , gcs_path : str , ** kwargs
816
+ ) -> str :
801
817
"""Serializes a torch.utils.data.DataLoader to a gcs path.
802
818
803
819
Args:
@@ -883,7 +899,9 @@ def _deserialize_from_local(self, path: str) -> "torch.utils.data.DataLoader":
883
899
884
900
return torch .utils .data .DataLoader (** kwargs )
885
901
886
- def deserialize (self , serialized_gcs_path : str , ** kwargs ) -> "torch.utils.data.DataLoader" :
902
+ def deserialize (
903
+ self , serialized_gcs_path : str , ** kwargs
904
+ ) -> "torch.utils.data.DataLoader" :
887
905
"""Deserialize a torch.utils.data.DataLoader given the gcs path.
888
906
889
907
Args:
0 commit comments