24
24
from google .cloud import aiplatform
25
25
from google .cloud .aiplatform import initializer
26
26
from google .cloud .aiplatform import utils
27
+ from google .cloud .aiplatform .preview .vertex_ray .predict .util import constants
27
28
from google .cloud .aiplatform .preview .vertex_ray .predict .util import (
28
29
predict_utils ,
29
30
)
@@ -44,6 +45,7 @@ def register_tensorflow(
44
45
artifact_uri : Optional [str ] = None ,
45
46
_model : Optional [Union ["tf.keras.Model" , Callable [[], "tf.keras.Model" ]]] = None ,
46
47
display_name : Optional [str ] = None ,
48
+ tensorflow_version : Optional [str ] = None ,
47
49
** kwargs ,
48
50
) -> aiplatform .Model :
49
51
"""Uploads a Ray Tensorflow Checkpoint as Tensorflow Model to Model Registry.
@@ -79,6 +81,11 @@ def create_model():
79
81
display_name (str):
80
82
Optional. The display name of the Model. The name can be up to 128
81
83
characters long and can be consist of any UTF-8 characters.
84
+ tensorflow_version (str):
85
+ Optional. The version of the Tensorflow serving container.
86
+ Supported versions:
87
+ https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
88
+ If the version is not specified, the latest version is used.
82
89
**kwargs:
83
90
Any kwargs will be passed to aiplatform.Model registration.
84
91
@@ -89,6 +96,9 @@ def create_model():
89
96
Raises:
90
97
ValueError: Invalid Argument.
91
98
"""
99
+
100
+ if tensorflow_version is None :
101
+ tensorflow_version = constants ._TENSORFLOW_VERSION
92
102
artifact_uri = artifact_uri or initializer .global_config .staging_bucket
93
103
predict_utils .validate_artifact_uri (artifact_uri )
94
104
prefix = "ray-on-vertex-registered-tensorflow-model"
@@ -99,10 +109,16 @@ def create_model():
99
109
)
100
110
tf_model = _get_tensorflow_model_from (checkpoint , model = _model )
101
111
model_dir = os .path .join (artifact_uri , prefix )
102
- tf_model .save (model_dir )
112
+ try :
113
+ import tensorflow as tf
114
+
115
+ tf .saved_model .save (tf_model , model_dir )
116
+ except ImportError :
117
+ logging .warning ("TensorFlow must be installed to save the trained model." )
103
118
return aiplatform .Model .upload_tensorflow_saved_model (
104
119
saved_model_dir = model_dir ,
105
120
display_name = display_model_name ,
121
+ tensorflow_version = tensorflow_version ,
106
122
** kwargs ,
107
123
)
108
124
@@ -139,13 +155,13 @@ def _get_tensorflow_model_from(
139
155
140
156
return checkpoint .get_model (model )
141
157
142
- # get_model() signature changed in future versions
143
158
try :
144
- from tensorflow import keras
159
+ import tensorflow as tf
145
160
146
161
try :
147
- return keras . models . load_model (checkpoint .path )
162
+ return tf . saved_model . load (checkpoint .path )
148
163
except OSError :
149
- return keras .models .load_model ("gs://" + checkpoint .path )
164
+ return tf .saved_model .load ("gs://" + checkpoint .path )
165
+
150
166
except ImportError :
151
167
logging .warning ("TensorFlow must be installed to load the trained model." )
0 commit comments