Skip to content

Commit d4deed3

Browse files
jaycee-licopybara-github
authored andcommitted
feat: Support Model Serialization in Vertex Experiments(sklearn)
PiperOrigin-RevId: 501487417
1 parent 94b2f29 commit d4deed3

14 files changed

+1830
-16
lines changed

google/cloud/aiplatform/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,17 @@
9191
log_classification_metrics = (
9292
metadata.metadata._experiment_tracker.log_classification_metrics
9393
)
94+
log_model = metadata.metadata._experiment_tracker.log_model
9495
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
9596
start_run = metadata.metadata._experiment_tracker.start_run
9697
start_execution = metadata.metadata._experiment_tracker.start_execution
9798
log = metadata.metadata._experiment_tracker.log
9899
log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
99100
end_run = metadata.metadata._experiment_tracker.end_run
100101

102+
save_model = metadata._models.save_model
103+
get_experiment_model = metadata.schema.google.artifact_schema.ExperimentModel.get
104+
101105
Experiment = metadata.experiment_resources.Experiment
102106
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
103107
Artifact = metadata.artifact.Artifact
@@ -116,11 +120,14 @@
116120
"log_params",
117121
"log_metrics",
118122
"log_classification_metrics",
123+
"log_model",
119124
"log_time_series_metrics",
120125
"get_experiment_df",
121126
"get_pipeline_df",
122127
"start_run",
123128
"start_execution",
129+
"save_model",
130+
"get_experiment_model",
124131
"Artifact",
125132
"AutoMLImageTrainingJob",
126133
"AutoMLTabularTrainingJob",

google/cloud/aiplatform/helpers/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@
2020
is_prebuilt_prediction_container_uri = (
2121
container_uri_builders.is_prebuilt_prediction_container_uri
2222
)
23+
_get_closest_match_prebuilt_container_uri = (
24+
container_uri_builders._get_closest_match_prebuilt_container_uri
25+
)
2326

2427
__all__ = (
2528
"get_prebuilt_prediction_container_uri",
2629
"is_prebuilt_prediction_container_uri",
30+
"_get_closest_match_prebuilt_container_uri",
2731
)

google/cloud/aiplatform/helpers/container_uri_builders.py

+104-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021 Google LLC
1+
# Copyright 2022 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,9 +14,11 @@
1414

1515
import re
1616
from typing import Optional
17+
import warnings
1718

18-
from google.cloud.aiplatform.constants import prediction
1919
from google.cloud.aiplatform import initializer
20+
from google.cloud.aiplatform.constants import prediction
21+
from packaging import version
2022

2123

2224
def get_prebuilt_prediction_container_uri(
@@ -122,3 +124,103 @@ def is_prebuilt_prediction_container_uri(image_uri: str) -> bool:
122124
If the image is prebuilt by Vertex AI prediction.
123125
"""
124126
return re.fullmatch(prediction.CONTAINER_URI_REGEX, image_uri) is not None
127+
128+
129+
# TODO(b/264191784) Deduplicate this method
130+
def _get_closest_match_prebuilt_container_uri(
131+
framework: str,
132+
framework_version: str,
133+
region: Optional[str] = None,
134+
accelerator: str = "cpu",
135+
) -> str:
136+
"""Return a pre-built container uri that is suitable for a specific framework and version.
137+
138+
If there is no exact match for the given version, the closest one that is
139+
higher than the input version will be used.
140+
141+
Args:
142+
framework (str):
143+
Required. The ML framework of the pre-built container. For example,
144+
`"tensorflow"`, `"xgboost"`, or `"sklearn"`
145+
framework_version (str):
146+
Required. The version of the specified ML framework as a string.
147+
region (str):
148+
Optional. AI region or multi-region. Used to select the correct
149+
Artifact Registry multi-region repository and reduce latency.
150+
Must start with `"us"`, `"asia"` or `"europe"`.
151+
Default is location set by `aiplatform.init()`.
152+
accelerator (str):
153+
Optional. The type of accelerator support provided by container. For
154+
example: `"cpu"` or `"gpu"`
155+
Default is `"cpu"`.
156+
157+
Returns:
158+
A string representing the pre-built container uri.
159+
160+
Raises:
161+
ValueError: If the framework doesn't have suitable pre-built container.
162+
"""
163+
URI_MAP = prediction._SERVING_CONTAINER_URI_MAP
164+
DOCS_URI_MESSAGE = (
165+
f"See {prediction._SERVING_CONTAINER_DOCUMENTATION_URL} "
166+
"for complete list of supported containers"
167+
)
168+
169+
# If region not provided, use initializer location
170+
region = region or initializer.global_config.location
171+
region = region.split("-", 1)[0]
172+
framework = framework.lower()
173+
174+
if not URI_MAP.get(region):
175+
raise ValueError(
176+
f"Unsupported container region `{region}`, supported regions are "
177+
f"{', '.join(URI_MAP.keys())}. "
178+
f"{DOCS_URI_MESSAGE}"
179+
)
180+
181+
if not URI_MAP[region].get(framework):
182+
raise ValueError(
183+
f"No containers found for framework `{framework}`. Supported frameworks are "
184+
f"{', '.join(URI_MAP[region].keys())} {DOCS_URI_MESSAGE}"
185+
)
186+
187+
if not URI_MAP[region][framework].get(accelerator):
188+
raise ValueError(
189+
f"{framework} containers do not support `{accelerator}` accelerator. Supported accelerators "
190+
f"are {', '.join(URI_MAP[region][framework].keys())}. {DOCS_URI_MESSAGE}"
191+
)
192+
193+
framework_version = version.Version(framework_version)
194+
available_version_list = [
195+
version.Version(available_version)
196+
for available_version in URI_MAP[region][framework][accelerator].keys()
197+
]
198+
try:
199+
closest_version = min(
200+
[
201+
available_version
202+
for available_version in available_version_list
203+
if available_version >= framework_version
204+
# manually implement Version.major for packaging < 20.0
205+
and available_version._version.release[0]
206+
== framework_version._version.release[0]
207+
]
208+
)
209+
except ValueError:
210+
raise ValueError(
211+
f"You are using `{framework}` version `{framework_version}`. "
212+
f"Vertex pre-built containers support up to `{framework}` version "
213+
f"`{max(available_version_list)}` and don't assume forward compatibility. "
214+
f"Please build your own custom container. {DOCS_URI_MESSAGE}"
215+
) from None
216+
217+
if closest_version != framework_version:
218+
warnings.warn(
219+
f"No exact match for `{framework}` version `{framework_version}`. "
220+
f"Pre-built container for `{framework}` version `{closest_version}` is used. "
221+
f"{DOCS_URI_MESSAGE}"
222+
)
223+
224+
final_uri = URI_MAP[region][framework][accelerator].get(str(closest_version))
225+
226+
return final_uri

0 commit comments

Comments
 (0)