@@ -25,7 +25,7 @@
 import pickle
 import shutil
 import tempfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, TYPE_CHECKING
 import uuid
 
 from google.cloud.aiplatform.utils import gcs_utils
@@ -48,17 +48,18 @@
 
 SERIALIZATION_METADATA_FRAMEWORK_KEY = "framework"
 
-try:
-    from tensorflow import keras
-    import tensorflow as tf
+if TYPE_CHECKING:
+    try:
+        from tensorflow import keras
+        import tensorflow as tf
 
-    KerasModel = keras.models.Model
-    TFDataset = tf.data.Dataset
-except ImportError:
-    keras = None
-    tf = None
-    KerasModel = Any
-    TFDataset = Any
+        KerasModel = keras.models.Model
+        TFDataset = tf.data.Dataset
+    except ImportError:
+        keras = None
+        tf = None
+        KerasModel = Any
+        TFDataset = Any
 
 try:
     import torch
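Note on the hunk above: the `if TYPE_CHECKING:` guard is only evaluated by static type checkers, so TensorFlow is no longer imported when this module loads; that is why later hunks switch the annotations to string forward references and import `tf`/`keras` inside the methods that need them. A minimal, self-contained sketch of the same pattern (the `describe_dataset` helper is illustrative, not part of this module):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers (mypy/pytype); never executed at runtime,
        # so the heavy optional dependency is not imported on module load.
        import tensorflow as tf


    def describe_dataset(ds: "tf.data.Dataset") -> str:  # string annotation, no runtime import
        # Deferred import: TensorFlow is required only if this is actually called.
        import tensorflow as tf

        return f"tf {tf.__version__}, element_spec={ds.element_spec}"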
@@ -184,7 +185,7 @@ class KerasModelSerializer(serializers_base.Serializer):
     )
 
     def serialize(
-        self, to_serialize: KerasModel, gcs_path: str, **kwargs
+        self, to_serialize: "keras.models.Model", gcs_path: str, **kwargs  # noqa: F821
     ) -> str:  # pytype: disable=invalid-annotation
         """Serializes a tensorflow.keras.models.Model to a gcs path.
 
@@ -232,7 +233,9 @@ def serialize(
         to_serialize.save(gcs_path, save_format=save_format)
         return gcs_path
 
-    def deserialize(self, serialized_gcs_path: str, **kwargs) -> KerasModel:
+    def deserialize(
+        self, serialized_gcs_path: str, **kwargs
+    ) -> "keras.models.Model":  # noqa: F821
         """Deserialize a tensorflow.keras.models.Model given the gcs file name.
 
         Args:
@@ -335,6 +338,7 @@ def deserialize(self, serialized_gcs_path: str, **kwargs):
         Raises:
             ValueError: if `serialized_gcs_path` is not a valid GCS uri.
         """
+        from tensorflow import keras
 
         if not _is_valid_gcs_path(serialized_gcs_path):
            raise ValueError(f"Invalid gcs path: {serialized_gcs_path}")
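The single added line above moves the `keras` import into the method body, so TensorFlow is loaded only when a Keras model is actually deserialized. A hedged sketch of the same lazy-import idea with an explicit error path (`load_keras_model` is a made-up helper, not from this file):

    def load_keras_model(path: str):
        # Import inside the function: importing the surrounding module stays cheap
        # and does not require TensorFlow to be installed.
        try:
            from tensorflow import keras
        except ImportError as e:
            raise ImportError(
                "TensorFlow must be installed to deserialize Keras models."
            ) from e

        return keras.models.load_model(path)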
@@ -922,8 +926,12 @@ class TFDatasetSerializer(serializers_base.Serializer):
         serializers_base.SerializationMetadata(serializer="TFDatasetSerializer")
     )
 
-    def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str:
+    def serialize(
+        self, to_serialize: "tf.data.Dataset", gcs_path: str, **kwargs  # noqa: F821
+    ) -> str:  # noqa: F821
         del kwargs
+        import tensorflow as tf
+
         if not _is_valid_gcs_path(gcs_path):
             raise ValueError(f"Invalid gcs path: {gcs_path}")
         TFDatasetSerializer._metadata.dependencies = (
@@ -936,8 +944,12 @@ def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str:
             tf.data.experimental.save(to_serialize, gcs_path)
         return gcs_path
 
-    def deserialize(self, serialized_gcs_path: str, **kwargs) -> TFDataset:
+    def deserialize(
+        self, serialized_gcs_path: str, **kwargs
+    ) -> "tf.data.Dataset":  # noqa: F821
         del kwargs
+        import tensorflow as tf
+
         try:
             deserialized = tf.data.Dataset.load(serialized_gcs_path)
         except AttributeError:
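For context, the `try/except AttributeError` in the surrounding lines papers over a TensorFlow API rename: newer releases expose `tf.data.Dataset.save`/`load`, older ones only the `tf.data.experimental` spellings. A rough round-trip sketch under that assumption (the temporary path and toy dataset are illustrative):

    import tempfile

    import tensorflow as tf  # assumed installed where the (de)serialization runs

    path = tempfile.mkdtemp()
    ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

    try:
        ds.save(path)  # newer TF spelling
    except AttributeError:
        tf.data.experimental.save(ds, path)  # older fallback, as in the diff

    try:
        restored = tf.data.Dataset.load(path)
    except AttributeError:
        restored = tf.data.experimental.load(path)

    print(list(restored.as_numpy_iterator()))  # [1, 2, 3]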
@@ -1180,6 +1192,11 @@ def serialize(
         return gcs_path
 
     def _get_tfio_verison(self):
+        import tensorflow as tf
+
+        if version.Version(tf.__version__) < version.Version("2.13.0"):
+            raise ValueError("TensorFlow version < 2.13.0 is not supported.")
+
         major, minor, _ = version.Version(tf.__version__).release
         tf_version = f"{major}.{minor}"
 
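The version gate above leans on `packaging.version` (already available in this module as `version`, presumably via `from packaging import version`), which compares release components numerically; comparing raw `tf.__version__` strings lexicographically would mis-order releases such as 2.9 and 2.13. A tiny illustration:

    from packaging import version

    # Numeric, component-wise comparison: 2.9 < 2.13, as expected.
    assert version.Version("2.9.0") < version.Version("2.13.0")

    # Plain string comparison is lexicographic and gets this wrong ("9" > "1").
    assert not ("2.9.0" < "2.13.0")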
@@ -1277,7 +1294,7 @@ def _deserialize_tensorflow(
         serialized_gcs_path: str,
         batch_size: Optional[int] = None,
         target_col: Optional[str] = None,
-    ) -> TFDataset:
+    ) -> "tf.data.Dataset":  # noqa: F821
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
 
         serialized_gcs_path is a folder containing one or more parquet files.
@@ -1287,6 +1304,11 @@ def _deserialize_tensorflow(
         target_col = target_col.encode("ASCII") if target_col else b"target"
 
         # Deserialization at remote environment
+        import tensorflow as tf
+
+        if version.Version(tf.__version__) < version.Version("2.13.0"):
+            raise ValueError("TensorFlow version < 2.13.0 is not supported.")
+
         try:
             import tensorflow_io as tfio
         except ImportError as e: