
Commit f294ba8

yinghsienwu authored and copybara-github committed

fix: Improve import time by moving TensorFlow to lazy import

PiperOrigin-RevId: 613018264

1 parent 2690e72, commit f294ba8

File tree

4 files changed (+50, -31 lines)

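The change is the standard lazy-import pattern: TensorFlow is no longer imported when the vertexai modules themselves are imported, but inside the functions that actually use it, so the cost is paid only on first use (repeat calls are cheap because Python caches modules in sys.modules). A minimal sketch of the pattern with made-up names, not the repository code:

# Module-level import: every importer of this file pays TensorFlow's import cost up front.
# import tensorflow as tf


def save_model(model, path):
    # Lazy import: TensorFlow is loaded the first time this function runs;
    # later calls find it already cached in sys.modules.
    import tensorflow as tf

    tf.keras.models.save_model(model, path)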

setup.py (+2, -1)

@@ -142,7 +142,8 @@
         "pytest-asyncio",
         "pytest-xdist",
         "scikit-learn",
-        "tensorflow >= 2.3.0, <= 2.12.0",
+        # Lazy import requires > 2.12.0
+        "tensorflow == 2.13.0",
         # TODO(jayceeli) torch 2.1.0 has conflict with pyfakefs, will check if
         # future versions fix this issue
         "torch >= 2.0.0, < 2.1.0",

vertexai/preview/_workflow/executor/training_script.py (-10)

@@ -33,16 +33,6 @@
 from vertexai.preview.developer import remote_specs
 
 
-try:
-    # This line ensures a tensorflow model to be loaded by cloudpickle correctly
-    # We put it in a try clause since not all models are tensorflow and if it is
-    # a tensorflow model, the dependency should've been installed and therefore
-    # import should work.
-    import tensorflow as tf  # noqa: F401
-except ImportError:
-    pass
-
-
 os.environ["_IS_VERTEX_REMOTE_TRAINING"] = "True"
 
 print(constants._START_EXECUTION_MSG)
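training_script.py previously imported TensorFlow eagerly so that cloudpickle could load TensorFlow models; after this commit that import happens lazily inside the serializers when a TensorFlow object is actually handled. Since the commit's stated goal is import time, a rough before/after check is simply timing the top-level import. The module name vertexai is assumed here purely for illustration:

import time

start = time.perf_counter()
import vertexai  # noqa: E402  # assumed entry point; time whichever module you care about
print(f"import took {time.perf_counter() - start:.2f} s")

CPython's -X importtime option gives a per-module breakdown if a finer view is needed.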

vertexai/preview/_workflow/serialization_engine/serializers.py (+38, -16)

@@ -25,7 +25,7 @@
 import pickle
 import shutil
 import tempfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, TYPE_CHECKING
 import uuid
 
 from google.cloud.aiplatform.utils import gcs_utils
@@ -48,17 +48,18 @@
 
 SERIALIZATION_METADATA_FRAMEWORK_KEY = "framework"
 
-try:
-    from tensorflow import keras
-    import tensorflow as tf
+if TYPE_CHECKING:
+    try:
+        from tensorflow import keras
+        import tensorflow as tf
 
-    KerasModel = keras.models.Model
-    TFDataset = tf.data.Dataset
-except ImportError:
-    keras = None
-    tf = None
-    KerasModel = Any
-    TFDataset = Any
+        KerasModel = keras.models.Model
+        TFDataset = tf.data.Dataset
+    except ImportError:
+        keras = None
+        tf = None
+        KerasModel = Any
+        TFDataset = Any
 
 try:
     import torch
@@ -184,7 +185,7 @@ class KerasModelSerializer(serializers_base.Serializer):
     )
 
     def serialize(
-        self, to_serialize: KerasModel, gcs_path: str, **kwargs
+        self, to_serialize: "keras.models.Model", gcs_path: str, **kwargs  # noqa: F821
     ) -> str:  # pytype: disable=invalid-annotation
         """Serializes a tensorflow.keras.models.Model to a gcs path.
 
@@ -232,7 +233,9 @@ def serialize(
         to_serialize.save(gcs_path, save_format=save_format)
         return gcs_path
 
-    def deserialize(self, serialized_gcs_path: str, **kwargs) -> KerasModel:
+    def deserialize(
+        self, serialized_gcs_path: str, **kwargs
+    ) -> "keras.models.Model":  # noqa: F821
         """Deserialize a tensorflow.keras.models.Model given the gcs file name.
 
         Args:
@@ -335,6 +338,7 @@ def deserialize(self, serialized_gcs_path: str, **kwargs):
         Raises:
             ValueError: if `serialized_gcs_path` is not a valid GCS uri.
         """
+        from tensorflow import keras
 
         if not _is_valid_gcs_path(serialized_gcs_path):
             raise ValueError(f"Invalid gcs path: {serialized_gcs_path}")
@@ -922,8 +926,12 @@ class TFDatasetSerializer(serializers_base.Serializer):
         serializers_base.SerializationMetadata(serializer="TFDatasetSerializer")
     )
 
-    def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str:
+    def serialize(
+        self, to_serialize: "tf.data.Dataset", gcs_path: str, **kwargs  # noqa: F821
+    ) -> str:  # noqa: F821
         del kwargs
+        import tensorflow as tf
+
         if not _is_valid_gcs_path(gcs_path):
             raise ValueError(f"Invalid gcs path: {gcs_path}")
         TFDatasetSerializer._metadata.dependencies = (
@@ -936,8 +944,12 @@ def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str:
         tf.data.experimental.save(to_serialize, gcs_path)
         return gcs_path
 
-    def deserialize(self, serialized_gcs_path: str, **kwargs) -> TFDataset:
+    def deserialize(
+        self, serialized_gcs_path: str, **kwargs
+    ) -> "tf.data.Dataset":  # noqa: F821
         del kwargs
+        import tensorflow as tf
+
         try:
             deserialized = tf.data.Dataset.load(serialized_gcs_path)
         except AttributeError:
@@ -1180,6 +1192,11 @@ def serialize(
         return gcs_path
 
     def _get_tfio_verison(self):
+        import tensorflow as tf
+
+        if tf.__version__ < "2.13.0":
+            raise ValueError("TensorFlow version < 2.13.0 is not supported.")
+
         major, minor, _ = version.Version(tf.__version__).release
         tf_version = f"{major}.{minor}"
 
@@ -1277,7 +1294,7 @@ def _deserialize_tensorflow(
         serialized_gcs_path: str,
         batch_size: Optional[int] = None,
         target_col: Optional[str] = None,
-    ) -> TFDataset:
+    ) -> "tf.data.Dataset":  # noqa: F821
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
 
         serialized_gcs_path is a folder containing one or more parquet files.
@@ -1287,6 +1304,11 @@ def _deserialize_tensorflow(
         target_col = target_col.encode("ASCII") if target_col else b"target"
 
         # Deserialization at remote environment
+        import tensorflow as tf
+
+        if tf.__version__ < "2.13.0":
+            raise ValueError("TensorFlow version < 2.13.0 is not supported.")
+
         try:
             import tensorflow_io as tfio
         except ImportError as e:
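The serializer edits combine two pieces: the TensorFlow imports move under typing.TYPE_CHECKING (evaluated by type checkers only, skipped at runtime) with string annotations plus # noqa: F821 suppressions, and the functions that genuinely need TensorFlow import it locally. A condensed sketch of that combination, using a hypothetical helper name rather than the repository's methods:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers; never executed at runtime, so no import cost.
    import tensorflow as tf


def load_dataset(path: str) -> "tf.data.Dataset":  # noqa: F821
    # Deferred to the first call instead of module import time.
    import tensorflow as tf

    return tf.data.Dataset.load(path)

The string annotation keeps the signature informative for tooling even though tf is not a runtime name at module level, which is why the # noqa: F821 (undefined name) comments appear throughout the diff.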

vertexai/preview/developer/remote_specs.py (+10, -4)

@@ -34,10 +34,6 @@
     serializers,
 )
 
-try:
-    import tensorflow as tf
-except ImportError:
-    pass
 try:
     import torch
 except ImportError:
@@ -763,6 +759,11 @@ def _get_keras_distributed_strategy(enable_distributed: bool, accelerator_count:
     Returns:
         A tf.distribute.Strategy.
     """
+    import tensorflow as tf
+
+    if tf.__version__ < "2.13.0":
+        raise ValueError("TensorFlow version < 2.13.0 is not supported.")
+
     if enable_distributed:
         cluster_spec = _get_cluster_spec()
         # Multiple workers, use tf.distribute.MultiWorkerMirroredStrategy().
@@ -793,6 +794,11 @@ def _set_keras_distributed_strategy(model: Any, strategy: Any):
         A tf.distribute.Strategy.
     """
     # Clone and compile model within scope of chosen strategy.
+    import tensorflow as tf
+
+    if tf.__version__ < "2.13.0":
+        raise ValueError("TensorFlow version < 2.13.0 is not supported.")
+
     with strategy.scope():
         cloned_model = tf.keras.models.clone_model(model)
         cloned_model.compile_from_config(model.get_compile_config())
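One detail worth noting: the added guards compare tf.__version__ with "2.13.0" as plain strings, which happens to work for current releases but orders versions lexicographically (for example "2.9.0" sorts after "2.13.0"). serializers.py already uses packaging's version.Version for the tensorflow-io lookup; a sketch of the same guard written with a numeric comparison, as a hypothetical alternative rather than what the commit does:

from packaging import version


def _require_tf_2_13():
    # Lazy import keeps TensorFlow off the module-import path.
    import tensorflow as tf

    # Version objects compare numerically, so "2.9.0" < "2.13.0" as expected.
    if version.Version(tf.__version__) < version.Version("2.13.0"):
        raise ValueError("TensorFlow version < 2.13.0 is not supported.")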
