Skip to content

Commit 1341e2c

Browse files
yinghsienwucopybara-github
authored andcommitted
fix: Register TensorFlow models from Ray checkpoints for more recent TensorFlow version, addressing the deprecation of SavedModel format in keras 3
PiperOrigin-RevId: 628562509
1 parent 9809a3a commit 1341e2c

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/register.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.cloud import aiplatform
2525
from google.cloud.aiplatform import initializer
2626
from google.cloud.aiplatform import utils
27+
from google.cloud.aiplatform.preview.vertex_ray.predict.util import constants
2728
from google.cloud.aiplatform.preview.vertex_ray.predict.util import (
2829
predict_utils,
2930
)
@@ -44,6 +45,7 @@ def register_tensorflow(
4445
artifact_uri: Optional[str] = None,
4546
_model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
4647
display_name: Optional[str] = None,
48+
tensorflow_version: Optional[str] = None,
4749
**kwargs,
4850
) -> aiplatform.Model:
4951
"""Uploads a Ray Tensorflow Checkpoint as Tensorflow Model to Model Registry.
@@ -79,6 +81,11 @@ def create_model():
7981
display_name (str):
8082
Optional. The display name of the Model. The name can be up to 128
8183
characters long and can be consist of any UTF-8 characters.
84+
tensorflow_version (str):
85+
Optional. The version of the Tensorflow serving container.
86+
Supported versions:
87+
https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
88+
If the version is not specified, the latest version is used.
8289
**kwargs:
8390
Any kwargs will be passed to aiplatform.Model registration.
8491
@@ -89,6 +96,9 @@ def create_model():
8996
Raises:
9097
ValueError: Invalid Argument.
9198
"""
99+
100+
if tensorflow_version is None:
101+
tensorflow_version = constants._TENSORFLOW_VERSION
92102
artifact_uri = artifact_uri or initializer.global_config.staging_bucket
93103
predict_utils.validate_artifact_uri(artifact_uri)
94104
prefix = "ray-on-vertex-registered-tensorflow-model"
@@ -99,10 +109,16 @@ def create_model():
99109
)
100110
tf_model = _get_tensorflow_model_from(checkpoint, model=_model)
101111
model_dir = os.path.join(artifact_uri, prefix)
102-
tf_model.save(model_dir)
112+
try:
113+
import tensorflow as tf
114+
115+
tf.saved_model.save(tf_model, model_dir)
116+
except ImportError:
117+
logging.warning("TensorFlow must be installed to save the trained model.")
103118
return aiplatform.Model.upload_tensorflow_saved_model(
104119
saved_model_dir=model_dir,
105120
display_name=display_model_name,
121+
tensorflow_version=tensorflow_version,
106122
**kwargs,
107123
)
108124

@@ -139,13 +155,13 @@ def _get_tensorflow_model_from(
139155

140156
return checkpoint.get_model(model)
141157

142-
# get_model() signature changed in future versions
143158
try:
144-
from tensorflow import keras
159+
import tensorflow as tf
145160

146161
try:
147-
return keras.models.load_model(checkpoint.path)
162+
return tf.saved_model.load(checkpoint.path)
148163
except OSError:
149-
return keras.models.load_model("gs://" + checkpoint.path)
164+
return tf.saved_model.load("gs://" + checkpoint.path)
165+
150166
except ImportError:
151167
logging.warning("TensorFlow must be installed to load the trained model.")

google/cloud/aiplatform/preview/vertex_ray/predict/util/constants.py

+5
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,8 @@
2121
_PICKLE_EXTENTION = ".pkl"
2222

2323
_XGBOOST_VERSION = "1.6"
24+
# TensorFlow 2.13 requires typing_extensions<4.6 and will cause errors in Ray.
25+
# https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/tools/pip_package/setup.py#L100
26+
# 2.13 is the latest supported version of Vertex prebuilt prediction container.
27+
# Set 2.12 as default here since 2.13 cause errors.
28+
_TENSORFLOW_VERSION = "2.12"

google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/register.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ def register_xgboost(
7777
Optional. The display name of the Model. The name can be up to 128
7878
characters long and can be consist of any UTF-8 characters.
7979
xgboost_version (str): Optional. The version of the XGBoost serving container.
80-
Supported versions: ["0.82", "0.90", "1.1", "1.2", "1.3", "1.4", "1.6", "1.7"].
81-
If the version is not specified, the latest version is used.
80+
Supported versions:
81+
https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
82+
If the version is not specified, the version 1.6 is used.
8283
**kwargs:
8384
Any kwargs will be passed to aiplatform.Model registration.
8485

0 commit comments

Comments
 (0)