Skip to content

Commit f5be0b5

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add Persistent Resource ID parameter to Custom Job form_local_script, run, and submit methods.
PiperOrigin-RevId: 622310810
1 parent 8c6ddf5 commit f5be0b5

File tree

2 files changed

+140
-2
lines changed

2 files changed

+140
-2
lines changed

google/cloud/aiplatform/jobs.py

+38
Original file line numberDiff line numberDiff line change
@@ -1923,6 +1923,7 @@ def from_local_script(
19231923
labels: Optional[Dict[str, str]] = None,
19241924
encryption_spec_key_name: Optional[str] = None,
19251925
staging_bucket: Optional[str] = None,
1926+
persistent_resource_id: Optional[str] = None,
19261927
) -> "CustomJob":
19271928
"""Configures a custom job from a local script.
19281929
@@ -2026,6 +2027,13 @@ def from_local_script(
20262027
staging_bucket (str):
20272028
Optional. Bucket for produced custom job artifacts. Overrides
20282029
staging_bucket set in aiplatform.init.
2030+
persistent_resource_id (str):
2031+
Optional. The ID of the PersistentResource in the same Project
2032+
and Location. If this is specified, the job will be run on
2033+
existing machines held by the PersistentResource instead of
2034+
on-demand short-live machines. The network, CMEK, and node pool
2035+
configs on the job should be consistent with those on the
2036+
PersistentResource, otherwise, the job will be rejected.
20292037
20302038
Raises:
20312039
RuntimeError: If staging bucket was not set using aiplatform.init
@@ -2171,6 +2179,7 @@ def from_local_script(
21712179
labels=labels,
21722180
encryption_spec_key_name=encryption_spec_key_name,
21732181
staging_bucket=staging_bucket,
2182+
persistent_resource_id=persistent_resource_id,
21742183
)
21752184

21762185
if enable_autolog:
@@ -2191,6 +2200,7 @@ def run(
21912200
sync: bool = True,
21922201
create_request_timeout: Optional[float] = None,
21932202
disable_retries: bool = False,
2203+
persistent_resource_id: Optional[str] = None,
21942204
) -> None:
21952205
"""Run this configured CustomJob.
21962206
@@ -2252,6 +2262,13 @@ def run(
22522262
Indicates if the job should retry for internal errors after the
22532263
job starts running. If True, overrides
22542264
`restart_job_on_worker_restart` to False.
2265+
persistent_resource_id (str):
2266+
Optional. The ID of the PersistentResource in the same Project
2267+
and Location. If this is specified, the job will be run on
2268+
existing machines held by the PersistentResource instead of
2269+
on-demand short-live machines. The network, CMEK, and node pool
2270+
configs on the job should be consistent with those on the
2271+
PersistentResource, otherwise, the job will be rejected.
22552272
"""
22562273
network = network or initializer.global_config.network
22572274
service_account = service_account or initializer.global_config.service_account
@@ -2268,6 +2285,7 @@ def run(
22682285
sync=sync,
22692286
create_request_timeout=create_request_timeout,
22702287
disable_retries=disable_retries,
2288+
persistent_resource_id=persistent_resource_id,
22712289
)
22722290

22732291
@base.optional_sync()
@@ -2284,6 +2302,7 @@ def _run(
22842302
sync: bool = True,
22852303
create_request_timeout: Optional[float] = None,
22862304
disable_retries: bool = False,
2305+
persistent_resource_id: Optional[str] = None,
22872306
) -> None:
22882307
"""Helper method to ensure network synchronization and to run the configured CustomJob.
22892308
@@ -2343,6 +2362,13 @@ def _run(
23432362
Indicates if the job should retry for internal errors after the
23442363
job starts running. If True, overrides
23452364
`restart_job_on_worker_restart` to False.
2365+
persistent_resource_id (str):
2366+
Optional. The ID of the PersistentResource in the same Project
2367+
and Location. If this is specified, the job will be run on
2368+
existing machines held by the PersistentResource instead of
2369+
on-demand short-live machines. The network, CMEK, and node pool
2370+
configs on the job should be consistent with those on the
2371+
PersistentResource, otherwise, the job will be rejected.
23462372
"""
23472373
self.submit(
23482374
service_account=service_account,
@@ -2355,6 +2381,7 @@ def _run(
23552381
tensorboard=tensorboard,
23562382
create_request_timeout=create_request_timeout,
23572383
disable_retries=disable_retries,
2384+
persistent_resource_id=persistent_resource_id,
23582385
)
23592386

23602387
self._block_until_complete()
@@ -2372,6 +2399,7 @@ def submit(
23722399
tensorboard: Optional[str] = None,
23732400
create_request_timeout: Optional[float] = None,
23742401
disable_retries: bool = False,
2402+
persistent_resource_id: Optional[str] = None,
23752403
) -> None:
23762404
"""Submit the configured CustomJob.
23772405
@@ -2428,6 +2456,13 @@ def submit(
24282456
Indicates if the job should retry for internal errors after the
24292457
job starts running. If True, overrides
24302458
`restart_job_on_worker_restart` to False.
2459+
persistent_resource_id (str):
2460+
Optional. The ID of the PersistentResource in the same Project
2461+
and Location. If this is specified, the job will be run on
2462+
existing machines held by the PersistentResource instead of
2463+
on-demand short-live machines. The network, CMEK, and node pool
2464+
configs on the job should be consistent with those on the
2465+
PersistentResource, otherwise, the job will be rejected.
24312466
24322467
Raises:
24332468
ValueError:
@@ -2464,6 +2499,9 @@ def submit(
24642499
if tensorboard:
24652500
self._gca_resource.job_spec.tensorboard = tensorboard
24662501

2502+
if persistent_resource_id:
2503+
self._gca_resource.job_spec.persistent_resource_id = persistent_resource_id
2504+
24672505
# TODO(b/275105711) Update implementation after experiment/run in the proto
24682506
if experiment:
24692507
# short-term solution to set experiment/experimentRun in SDK

tests/unit/aiplatform/test_custom_job_persistent_resource.py

+102-2
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@
2222
from google.cloud import aiplatform
2323
from google.cloud.aiplatform import jobs
2424
from google.cloud.aiplatform.compat.services import job_service_client_v1
25+
from google.cloud.aiplatform.compat.types import (
26+
custom_job as gca_custom_job_compat,
27+
)
2528
from google.cloud.aiplatform.compat.types import custom_job_v1
2629
from google.cloud.aiplatform.compat.types import encryption_spec_v1
2730
from google.cloud.aiplatform.compat.types import io_v1
28-
from google.cloud.aiplatform.compat.types import job_state_v1 as gca_job_state_compat
31+
from google.cloud.aiplatform.compat.types import (
32+
job_state_v1 as gca_job_state_compat,
33+
)
2934
import constants as test_constants
3035
import pytest
3136

@@ -71,6 +76,11 @@
7176

7277
_TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS
7378

79+
_TEST_PYTHON_PACKAGE_SPEC = gca_custom_job_compat.PythonPackageSpec(
80+
executor_image_uri=_TEST_PREBUILT_CONTAINER_IMAGE,
81+
package_uris=[test_constants.TrainingJobConstants._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
82+
python_module=test_constants.TrainingJobConstants._TEST_MODULE_NAME,
83+
)
7484

7585
# Persistent Resource
7686
_TEST_PERSISTENT_RESOURCE_ID = "test-persistent-resource-1"
@@ -212,7 +222,6 @@ def test_submit_custom_job_with_persistent_resource(
212222
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
213223
base_output_dir=_TEST_BASE_OUTPUT_DIR,
214224
labels=_TEST_LABELS,
215-
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
216225
)
217226

218227
job.submit(
@@ -222,6 +231,7 @@ def test_submit_custom_job_with_persistent_resource(
222231
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
223232
create_request_timeout=None,
224233
disable_retries=_TEST_DISABLE_RETRIES,
234+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
225235
)
226236

227237
job.wait_for_resource_creation()
@@ -243,3 +253,93 @@ def test_submit_custom_job_with_persistent_resource(
243253
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
244254
)
245255
assert job.network == _TEST_NETWORK
256+
257+
@pytest.mark.parametrize("sync", [True, False])
258+
def test_run_custom_job_with_persistent_resource(
259+
self, create_custom_job_mock, get_custom_job_mock, sync
260+
):
261+
262+
aiplatform.init(
263+
project=_TEST_PROJECT,
264+
location=_TEST_LOCATION,
265+
staging_bucket=_TEST_STAGING_BUCKET,
266+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
267+
)
268+
269+
job = jobs.CustomJob(
270+
display_name=_TEST_DISPLAY_NAME,
271+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
272+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
273+
labels=_TEST_LABELS,
274+
)
275+
276+
job.run(
277+
service_account=_TEST_SERVICE_ACCOUNT,
278+
network=_TEST_NETWORK,
279+
timeout=_TEST_TIMEOUT,
280+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
281+
create_request_timeout=None,
282+
disable_retries=_TEST_DISABLE_RETRIES,
283+
sync=sync,
284+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
285+
)
286+
287+
job.wait_for_resource_creation()
288+
289+
assert job.resource_name == _TEST_CUSTOM_JOB_NAME
290+
291+
job.wait()
292+
293+
expected_custom_job = _get_custom_job_proto()
294+
295+
create_custom_job_mock.assert_called_once_with(
296+
parent=_TEST_PARENT,
297+
custom_job=expected_custom_job,
298+
timeout=None,
299+
)
300+
301+
assert job.job_spec == expected_custom_job.job_spec
302+
assert (
303+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
304+
)
305+
assert job.network == _TEST_NETWORK
306+
307+
@pytest.mark.usefixtures("mock_python_package_to_gcs")
308+
@pytest.mark.parametrize("sync", [True, False])
309+
def test_from_local_script_custom_job_with_persistent_resource(
310+
self, create_custom_job_mock, get_custom_job_mock, sync
311+
):
312+
313+
aiplatform.init(
314+
project=_TEST_PROJECT,
315+
location=_TEST_LOCATION,
316+
staging_bucket=_TEST_STAGING_BUCKET,
317+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
318+
)
319+
320+
job = jobs.CustomJob.from_local_script(
321+
display_name=_TEST_DISPLAY_NAME,
322+
script_path=test_constants.TrainingJobConstants._TEST_LOCAL_SCRIPT_FILE_NAME,
323+
container_uri=_TEST_PREBUILT_CONTAINER_IMAGE,
324+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
325+
labels=_TEST_LABELS,
326+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
327+
)
328+
329+
assert (
330+
job.job_spec.worker_pool_specs[0].python_package_spec
331+
== _TEST_PYTHON_PACKAGE_SPEC
332+
)
333+
334+
job.run(sync=sync)
335+
336+
job.wait_for_resource_creation()
337+
338+
assert job.resource_name == _TEST_CUSTOM_JOB_NAME
339+
340+
job.wait()
341+
342+
assert job.job_spec.persistent_resource_id == _TEST_PERSISTENT_RESOURCE_ID
343+
assert (
344+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
345+
)

0 commit comments

Comments
 (0)