|
1 |
| -# Copyright 2021 Google LLC |
| 1 | +# Copyright 2022 Google LLC |
2 | 2 | #
|
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License");
|
4 | 4 | # you may not use this file except in compliance with the License.
|
|
14 | 14 |
|
15 | 15 | import re
|
16 | 16 | from typing import Optional
|
| 17 | +import warnings |
17 | 18 |
|
18 |
| -from google.cloud.aiplatform.constants import prediction |
19 | 19 | from google.cloud.aiplatform import initializer
|
| 20 | +from google.cloud.aiplatform.constants import prediction |
| 21 | +from packaging import version |
20 | 22 |
|
21 | 23 |
|
22 | 24 | def get_prebuilt_prediction_container_uri(
|
@@ -122,3 +124,103 @@ def is_prebuilt_prediction_container_uri(image_uri: str) -> bool:
|
122 | 124 | If the image is prebuilt by Vertex AI prediction.
|
123 | 125 | """
|
124 | 126 | return re.fullmatch(prediction.CONTAINER_URI_REGEX, image_uri) is not None
|
| 127 | + |
| 128 | + |
| 129 | +# TODO(b/264191784) Deduplicate this method |
| 130 | +def _get_closest_match_prebuilt_container_uri( |
| 131 | + framework: str, |
| 132 | + framework_version: str, |
| 133 | + region: Optional[str] = None, |
| 134 | + accelerator: str = "cpu", |
| 135 | +) -> str: |
| 136 | + """Return a pre-built container uri that is suitable for a specific framework and version. |
| 137 | +
|
| 138 | + If there is no exact match for the given version, the closest one that is |
| 139 | + higher than the input version will be used. |
| 140 | +
|
| 141 | + Args: |
| 142 | + framework (str): |
| 143 | + Required. The ML framework of the pre-built container. For example, |
| 144 | + `"tensorflow"`, `"xgboost"`, or `"sklearn"` |
| 145 | + framework_version (str): |
| 146 | + Required. The version of the specified ML framework as a string. |
| 147 | + region (str): |
| 148 | + Optional. AI region or multi-region. Used to select the correct |
| 149 | + Artifact Registry multi-region repository and reduce latency. |
| 150 | + Must start with `"us"`, `"asia"` or `"europe"`. |
| 151 | + Default is location set by `aiplatform.init()`. |
| 152 | + accelerator (str): |
| 153 | + Optional. The type of accelerator support provided by container. For |
| 154 | + example: `"cpu"` or `"gpu"` |
| 155 | + Default is `"cpu"`. |
| 156 | +
|
| 157 | + Returns: |
| 158 | + A string representing the pre-built container uri. |
| 159 | +
|
| 160 | + Raises: |
| 161 | + ValueError: If the framework doesn't have suitable pre-built container. |
| 162 | + """ |
| 163 | + URI_MAP = prediction._SERVING_CONTAINER_URI_MAP |
| 164 | + DOCS_URI_MESSAGE = ( |
| 165 | + f"See {prediction._SERVING_CONTAINER_DOCUMENTATION_URL} " |
| 166 | + "for complete list of supported containers" |
| 167 | + ) |
| 168 | + |
| 169 | + # If region not provided, use initializer location |
| 170 | + region = region or initializer.global_config.location |
| 171 | + region = region.split("-", 1)[0] |
| 172 | + framework = framework.lower() |
| 173 | + |
| 174 | + if not URI_MAP.get(region): |
| 175 | + raise ValueError( |
| 176 | + f"Unsupported container region `{region}`, supported regions are " |
| 177 | + f"{', '.join(URI_MAP.keys())}. " |
| 178 | + f"{DOCS_URI_MESSAGE}" |
| 179 | + ) |
| 180 | + |
| 181 | + if not URI_MAP[region].get(framework): |
| 182 | + raise ValueError( |
| 183 | + f"No containers found for framework `{framework}`. Supported frameworks are " |
| 184 | + f"{', '.join(URI_MAP[region].keys())} {DOCS_URI_MESSAGE}" |
| 185 | + ) |
| 186 | + |
| 187 | + if not URI_MAP[region][framework].get(accelerator): |
| 188 | + raise ValueError( |
| 189 | + f"{framework} containers do not support `{accelerator}` accelerator. Supported accelerators " |
| 190 | + f"are {', '.join(URI_MAP[region][framework].keys())}. {DOCS_URI_MESSAGE}" |
| 191 | + ) |
| 192 | + |
| 193 | + framework_version = version.Version(framework_version) |
| 194 | + available_version_list = [ |
| 195 | + version.Version(available_version) |
| 196 | + for available_version in URI_MAP[region][framework][accelerator].keys() |
| 197 | + ] |
| 198 | + try: |
| 199 | + closest_version = min( |
| 200 | + [ |
| 201 | + available_version |
| 202 | + for available_version in available_version_list |
| 203 | + if available_version >= framework_version |
| 204 | + # manually implement Version.major for packaging < 20.0 |
| 205 | + and available_version._version.release[0] |
| 206 | + == framework_version._version.release[0] |
| 207 | + ] |
| 208 | + ) |
| 209 | + except ValueError: |
| 210 | + raise ValueError( |
| 211 | + f"You are using `{framework}` version `{framework_version}`. " |
| 212 | + f"Vertex pre-built containers support up to `{framework}` version " |
| 213 | + f"`{max(available_version_list)}` and don't assume forward compatibility. " |
| 214 | + f"Please build your own custom container. {DOCS_URI_MESSAGE}" |
| 215 | + ) from None |
| 216 | + |
| 217 | + if closest_version != framework_version: |
| 218 | + warnings.warn( |
| 219 | + f"No exact match for `{framework}` version `{framework_version}`. " |
| 220 | + f"Pre-built container for `{framework}` version `{closest_version}` is used. " |
| 221 | + f"{DOCS_URI_MESSAGE}" |
| 222 | + ) |
| 223 | + |
| 224 | + final_uri = URI_MAP[region][framework][accelerator].get(str(closest_version)) |
| 225 | + |
| 226 | + return final_uri |
0 commit comments