Skip to content

Commit be0b7e1

Browse files
authored
feat: Support custom containers in CustomJob.from_local_script (#1483)
* feat: update from_local_script() * fix from_local_script() * fix: update unit tests * fix lint failed * fix: add system test * fix: from_local_script * fix: from_local_script * fix: system test
1 parent 46aa9b5 commit be0b7e1

File tree

5 files changed

+296
-45
lines changed

5 files changed

+296
-45
lines changed

google/cloud/aiplatform/jobs.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,12 @@ def from_local_script(
12651265
script_path (str):
12661266
Required. Local path to training script.
12671267
container_uri (str):
1268-
Required: Uri of the training container image to use for custom job.
1268+
Required. Uri of the training container image to use for custom job.
1269+
Supports images in Artifact Registry, Container Registry, or Docker Hub.
1270+
Vertex AI provides a wide range of executor images with pre-installed
1271+
packages to meet users' various use cases. See the list of `pre-built containers
1272+
for training <https://cloud.google.com/vertex-ai/docs/training/pre-built-containers>`_.
1273+
If not using an image from this list, please make sure Python 3 (python3) and pip3 are installed in your container.
12691274
args (Optional[Sequence[str]]):
12701275
Optional. Command line arguments to be passed to the Python task.
12711276
requirements (Sequence[str]):
@@ -1388,7 +1393,10 @@ def from_local_script(
13881393
spec["container_spec"] = {
13891394
"image_uri": reduction_server_container_uri,
13901395
}
1391-
else:
1396+
# Check if the container is pre-built.
1397+
elif ("docker.pkg.dev/vertex-ai/" in container_uri) or (
1398+
"gcr.io/cloud-aiplatform/" in container_uri
1399+
):
13921400
spec["python_package_spec"] = {
13931401
"executor_image_uri": container_uri,
13941402
"python_module": python_packager.module_name,
@@ -1403,6 +1411,30 @@ def from_local_script(
14031411
{"name": key, "value": value}
14041412
for key, value in environment_variables.items()
14051413
]
1414+
else:
1415+
command = [
1416+
"sh",
1417+
"-c",
1418+
"pip install --upgrade pip && "
1419+
+ f"pip3 install -q --user {package_gcs_uri} && ".replace(
1420+
"gs://", "/gcs/"
1421+
)
1422+
+ f"python3 -m {python_packager.module_name}",
1423+
]
1424+
1425+
spec["container_spec"] = {
1426+
"image_uri": container_uri,
1427+
"command": command,
1428+
}
1429+
1430+
if args:
1431+
spec["container_spec"]["args"] = args
1432+
1433+
if environment_variables:
1434+
spec["container_spec"]["env"] = [
1435+
{"name": key, "value": value}
1436+
for key, value in environment_variables.items()
1437+
]
14061438

14071439
return cls(
14081440
display_name=display_name,
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2022 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import os
19+
20+
import pytest
21+
22+
from google.cloud import aiplatform
23+
from google.cloud.aiplatform.compat.types import job_state as gca_job_state
24+
from tests.system.aiplatform import e2e_base
25+
26+
# Image URI containing "gcr.io/cloud-aiplatform/", one of the two substrings
# that from_local_script treats as a pre-built container.
_PREBUILT_CONTAINER_IMAGE = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest"
27+
# Image URI matching neither pre-built substring, so from_local_script falls
# through to its custom-container branch.
_CUSTOM_CONTAINER_IMAGE = "python:3.8"
28+
29+
# Absolute directory of this test module; the training script path below is
# resolved relative to it.
_DIR_NAME = os.path.dirname(os.path.abspath(__file__))
30+
_LOCAL_TRAINING_SCRIPT_PATH = os.path.join(
31+
_DIR_NAME, "test_resources/custom_job_script.py"
32+
)
33+
34+
35+
@pytest.mark.usefixtures(
    "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
)
class TestCustomJob(e2e_base.TestEndToEnd):
    """End-to-end tests for ``aiplatform.CustomJob.from_local_script``.

    Each test builds a CustomJob from the local training script, runs it with
    a different container image, and asserts that the job reaches
    ``JOB_STATE_SUCCEEDED``.
    """

    _temp_prefix = "temp-vertex-sdk-custom-job"

    def _run_local_script_job(self, shared_state, container_uri):
        """Create and run a CustomJob from the local script, then assert success.

        Args:
            shared_state: Mutable dict shared between the tests of this class.
                ``staging_bucket_name`` is read from it, and the created job
                is recorded under ``resources``.
            container_uri: Image URI passed to ``from_local_script``.
        """
        aiplatform.init(
            project=e2e_base._PROJECT,
            location=e2e_base._LOCATION,
            staging_bucket=shared_state["staging_bucket_name"],
        )

        display_name = self._make_display_name("custom-job")

        custom_job = aiplatform.CustomJob.from_local_script(
            display_name=display_name,
            script_path=_LOCAL_TRAINING_SCRIPT_PATH,
            container_uri=container_uri,
        )
        custom_job.run()

        # setdefault keeps the tests independent of execution order: neither
        # test relies on the other having created the list, and neither wipes
        # entries recorded earlier. The list is presumably consumed by the
        # tear_down_resources fixture -- confirm in e2e_base.
        shared_state.setdefault("resources", []).append(custom_job)

        assert custom_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED

    def test_from_local_script_prebuilt_container(self, shared_state):
        """Run the local script on the pre-built training image."""
        self._run_local_script_job(shared_state, _PREBUILT_CONTAINER_IMAGE)

    def test_from_local_script_custom_container(self, shared_state):
        """Run the local script on the custom (non pre-built) image."""
        self._run_local_script_job(shared_state, _CUSTOM_CONTAINER_IMAGE)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2022 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# Entire script body: writes one fixed marker line to stdout.
# NOTE(review): presumably this is the test_resources/custom_job_script.py
# payload used by the CustomJob system tests -- confirm the file path.
print("Test CustomJob script.")

0 commit comments

Comments
 (0)