Skip to content

Commit 87dd5c0

Browse files
feat: Support custom predictor Docker image builds on non-x86 architectures (#2115)
* Enforce Linux `x86_64` Docker image builds Under the assumption that only machine types with x86 processors are supported for prediction and custom training. * Extract Docker build platform arg to method parameter While currently only x86 processors are supported for prediction and custom training, this will allow users to control this behavior should that ever change in the future. Additionally, it allows users to, e.g., override the `TARGETOS` component of the `TARGETPLATFORM`. * Change platform default arg to `None` To enforce the flag is set by users as opposed to providing a universal default. - See: #2115 (comment) * Make platform configurable from `LocalModel.build_cpr_model()` To enable the flag to be set by users (e.g., to build images on non-x86 architectures). * Fix docstring for `platform` param Resolves: - #2115 (comment) - #2115 (comment) * Test platform parameter in `test_build_cpr_model_upload_and_deploy()` Resolves: - #2115 (review) * Fix tests Resolves (partially): - #2115 (comment) * Test specifying platform in local model builds Resolves (partially): - #2115 (comment) * Test other platform strings in local model builds --------- Co-authored-by: Chun-Hsiang Wang <[email protected]>
1 parent ba9a314 commit 87dd5c0

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

google/cloud/aiplatform/docker_utils/build.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def build_image(
418418
pip_command: str = "pip",
419419
python_command: str = "python",
420420
no_cache: bool = True,
421+
platform: Optional[str] = None,
421422
**kwargs,
422423
) -> Image:
423424
"""Builds a Docker image.
@@ -459,6 +460,10 @@ def build_image(
459460
reduces the image building time. See
460461
https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#leverage-build-cache
461462
for more details.
463+
platform (str):
464+
Optional. The target platform for the Docker image build. See
465+
https://docs.docker.com/build/building/multi-platform/#building-multi-platform-images
466+
for more details.
462467
**kwargs:
463468
Other arguments to pass to underlying method that generates the Dockerfile.
464469
@@ -472,9 +477,14 @@ def build_image(
472477

473478
tag_options = ["-t", output_image_name]
474479
cache_args = ["--no-cache"] if no_cache else []
480+
platform_args = ["--platform", platform] if platform is not None else []
475481

476482
command = (
477-
["docker", "build"] + cache_args + tag_options + ["--rm", "-f-", host_workdir]
483+
["docker", "build"]
484+
+ cache_args
485+
+ platform_args
486+
+ tag_options
487+
+ ["--rm", "-f-", host_workdir]
478488
)
479489

480490
requirements_relative_path = _get_relative_path_to_workdir(

google/cloud/aiplatform/prediction/local_model.py

+7
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def build_cpr_model(
246246
requirements_path: Optional[str] = None,
247247
extra_packages: Optional[List[str]] = None,
248248
no_cache: bool = False,
249+
platform: Optional[str] = None,
249250
) -> "LocalModel":
250251
"""Builds a local model from a custom predictor.
251252
@@ -274,6 +275,7 @@ def build_cpr_model(
274275
predictor=$CUSTOM_PREDICTOR_CLASS,
275276
requirements_path="./user_src_dir/requirements.txt",
276277
extra_packages=["./user_src_dir/user_code/custom_package.tar.gz"],
278+
platform="linux/amd64", # i.e., if you're building on a non-x86 machine
277279
)
278280
279281
In the built image, user provided files will be copied as follows:
@@ -340,6 +342,10 @@ def build_cpr_model(
340342
reduces the image building time. See
341343
https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#leverage-build-cache
342344
for more details.
345+
platform (str):
346+
Optional. The target platform for the Docker image build. See
347+
https://docs.docker.com/build/building/multi-platform/#building-multi-platform-images
348+
for more details.
343349
344350
Returns:
345351
local model: Instantiated representation of the local model.
@@ -391,6 +397,7 @@ def build_cpr_model(
391397
pip_command="pip3" if is_prebuilt_prediction_image else "pip",
392398
python_command="python3" if is_prebuilt_prediction_image else "python",
393399
no_cache=no_cache,
400+
platform=platform,
394401
)
395402

396403
container_spec = gca_model_compat.ModelContainerSpec(

tests/system/aiplatform/test_prediction_cpr.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class TestPredictionCpr(e2e_base.TestEndToEnd):
4949

5050
_temp_prefix = "temp-vertex-sdk-e2e-prediction-cpr"
5151

52-
def test_build_cpr_model_upload_and_deploy(self, shared_state, caplog):
52+
@pytest.mark.parametrize("platform", [None, "linux/amd64"])
53+
def test_build_cpr_model_upload_and_deploy(self, shared_state, caplog, platform):
5354
"""Creates a CPR model from custom predictor, uploads it and deploys."""
5455

5556
caplog.set_level(logging.INFO)
@@ -61,6 +62,7 @@ def test_build_cpr_model_upload_and_deploy(self, shared_state, caplog):
6162
_IMAGE_URI,
6263
predictor=SklearnPredictor,
6364
requirements_path=os.path.join(_USER_CODE_DIR, _REQUIREMENTS_FILE),
65+
platform=platform,
6466
)
6567

6668
with local_model.deploy_to_local_endpoint(

tests/unit/aiplatform/test_prediction.py

+70
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,7 @@ class {predictor_class}:
13041304
pip_command="pip",
13051305
python_command="python",
13061306
no_cache=False,
1307+
platform=None,
13071308
)
13081309

13091310
def test_build_cpr_model_fails_handler_is_none(
@@ -1418,6 +1419,7 @@ class {handler_class}:
14181419
pip_command="pip",
14191420
python_command="python",
14201421
no_cache=False,
1422+
platform=None,
14211423
)
14221424

14231425
def test_build_cpr_model_with_custom_handler_and_predictor_is_none(
@@ -1472,6 +1474,7 @@ class {handler_class}:
14721474
pip_command="pip",
14731475
python_command="python",
14741476
no_cache=False,
1477+
platform=None,
14751478
)
14761479

14771480
def test_build_cpr_model_creates_and_get_localmodel_base_is_prebuilt(
@@ -1527,6 +1530,7 @@ class {predictor_class}:
15271530
pip_command="pip3",
15281531
python_command="python3",
15291532
no_cache=False,
1533+
platform=None,
15301534
)
15311535

15321536
def test_build_cpr_model_creates_and_get_localmodel_with_requirements_path(
@@ -1584,6 +1588,7 @@ class {predictor_class}:
15841588
pip_command="pip",
15851589
python_command="python",
15861590
no_cache=False,
1591+
platform=None,
15871592
)
15881593

15891594
def test_build_cpr_model_creates_and_get_localmodel_with_extra_packages(
@@ -1641,6 +1646,7 @@ class {predictor_class}:
16411646
pip_command="pip",
16421647
python_command="python",
16431648
no_cache=False,
1649+
platform=None,
16441650
)
16451651

16461652
def test_build_cpr_model_creates_and_get_localmodel_no_cache(
@@ -1695,6 +1701,70 @@ class {predictor_class}:
16951701
pip_command="pip",
16961702
python_command="python",
16971703
no_cache=no_cache,
1704+
platform=None,
1705+
)
1706+
1707+
@pytest.mark.parametrize(
1708+
"platform",
1709+
[
1710+
None,
1711+
"linux/amd64",
1712+
"some_arbitrary_platform_value_that_will_by_validated_by_docker_build_command",
1713+
],
1714+
)
1715+
def test_build_cpr_model_creates_and_get_localmodel_platform(
1716+
self,
1717+
tmp_path,
1718+
inspect_source_from_class_mock_predictor_only,
1719+
is_prebuilt_prediction_container_uri_is_false_mock,
1720+
build_image_mock,
1721+
platform,
1722+
):
1723+
src_dir = tmp_path / _TEST_SRC_DIR
1724+
src_dir.mkdir()
1725+
predictor = src_dir / _TEST_PREDICTOR_FILE
1726+
predictor.write_text(
1727+
textwrap.dedent(
1728+
"""
1729+
class {predictor_class}:
1730+
pass
1731+
"""
1732+
).format(predictor_class=_TEST_PREDICTOR_CLASS)
1733+
)
1734+
my_predictor = self._load_module(_TEST_PREDICTOR_CLASS, str(predictor))
1735+
1736+
local_model = LocalModel.build_cpr_model(
1737+
str(src_dir), _TEST_OUTPUT_IMAGE, predictor=my_predictor, platform=platform
1738+
)
1739+
1740+
assert local_model.serving_container_spec.image_uri == _TEST_OUTPUT_IMAGE
1741+
assert local_model.serving_container_spec.predict_route == DEFAULT_PREDICT_ROUTE
1742+
assert local_model.serving_container_spec.health_route == DEFAULT_HEALTH_ROUTE
1743+
inspect_source_from_class_mock_predictor_only.assert_called_once_with(
1744+
my_predictor, str(src_dir)
1745+
)
1746+
is_prebuilt_prediction_container_uri_is_false_mock.assert_called_once_with(
1747+
_DEFAULT_BASE_IMAGE
1748+
)
1749+
build_image_mock.assert_called_once_with(
1750+
_DEFAULT_BASE_IMAGE,
1751+
str(src_dir),
1752+
_TEST_OUTPUT_IMAGE,
1753+
python_module=_DEFAULT_PYTHON_MODULE,
1754+
requirements_path=None,
1755+
extra_requirements=_DEFAULT_SDK_REQUIREMENTS,
1756+
extra_packages=None,
1757+
exposed_ports=[DEFAULT_HTTP_PORT],
1758+
environment_variables={
1759+
"HANDLER_MODULE": _DEFAULT_HANDLER_MODULE,
1760+
"HANDLER_CLASS": _DEFAULT_HANDLER_CLASS,
1761+
"PREDICTOR_MODULE": f"{_TEST_SRC_DIR}.{_TEST_PREDICTOR_FILE_STEM}",
1762+
"PREDICTOR_CLASS": _TEST_PREDICTOR_CLASS,
1763+
},
1764+
pip_command="pip",
1765+
python_command="python",
1766+
no_cache=False,
1767+
platform=platform,
16981768
)
16991769

17001770
def test_deploy_to_local_endpoint(

0 commit comments

Comments
 (0)