Skip to content

Commit a076191

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for strategy in custom training jobs.
PiperOrigin-RevId: 661428427
1 parent 7404f67 commit a076191

File tree

6 files changed

+405
-3
lines changed

6 files changed

+405
-3
lines changed

google/cloud/aiplatform/jobs.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -2214,6 +2214,7 @@ def run(
22142214
create_request_timeout: Optional[float] = None,
22152215
disable_retries: bool = False,
22162216
persistent_resource_id: Optional[str] = None,
2217+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
22172218
) -> None:
22182219
"""Run this configured CustomJob.
22192220
@@ -2282,6 +2283,8 @@ def run(
22822283
on-demand short-live machines. The network, CMEK, and node pool
22832284
configs on the job should be consistent with those on the
22842285
PersistentResource, otherwise, the job will be rejected.
2286+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
2287+
Optional. Indicates the job scheduling strategy.
22852288
"""
22862289
network = network or initializer.global_config.network
22872290
service_account = service_account or initializer.global_config.service_account
@@ -2299,6 +2302,7 @@ def run(
22992302
create_request_timeout=create_request_timeout,
23002303
disable_retries=disable_retries,
23012304
persistent_resource_id=persistent_resource_id,
2305+
scheduling_strategy=scheduling_strategy,
23022306
)
23032307

23042308
@base.optional_sync()
@@ -2316,6 +2320,7 @@ def _run(
23162320
create_request_timeout: Optional[float] = None,
23172321
disable_retries: bool = False,
23182322
persistent_resource_id: Optional[str] = None,
2323+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
23192324
) -> None:
23202325
"""Helper method to ensure network synchronization and to run the configured CustomJob.
23212326
@@ -2382,6 +2387,8 @@ def _run(
23822387
on-demand short-live machines. The network, CMEK, and node pool
23832388
configs on the job should be consistent with those on the
23842389
PersistentResource, otherwise, the job will be rejected.
2390+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
2391+
Optional. Indicates the job scheduling strategy.
23852392
"""
23862393
self.submit(
23872394
service_account=service_account,
@@ -2395,6 +2402,7 @@ def _run(
23952402
create_request_timeout=create_request_timeout,
23962403
disable_retries=disable_retries,
23972404
persistent_resource_id=persistent_resource_id,
2405+
scheduling_strategy=scheduling_strategy,
23982406
)
23992407

24002408
self._block_until_complete()
@@ -2413,6 +2421,7 @@ def submit(
24132421
create_request_timeout: Optional[float] = None,
24142422
disable_retries: bool = False,
24152423
persistent_resource_id: Optional[str] = None,
2424+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
24162425
) -> None:
24172426
"""Submit the configured CustomJob.
24182427
@@ -2476,6 +2485,8 @@ def submit(
24762485
on-demand short-live machines. The network, CMEK, and node pool
24772486
configs on the job should be consistent with those on the
24782487
PersistentResource, otherwise, the job will be rejected.
2488+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
2489+
Optional. Indicates the job scheduling strategy.
24792490
24802491
Raises:
24812492
ValueError:
@@ -2498,12 +2509,18 @@ def submit(
24982509
if network:
24992510
self._gca_resource.job_spec.network = network
25002511

2501-
if timeout or restart_job_on_worker_restart or disable_retries:
2512+
if (
2513+
timeout
2514+
or restart_job_on_worker_restart
2515+
or disable_retries
2516+
or scheduling_strategy
2517+
):
25022518
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
25032519
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
25042520
timeout=timeout,
25052521
restart_job_on_worker_restart=restart_job_on_worker_restart,
25062522
disable_retries=disable_retries,
2523+
strategy=scheduling_strategy,
25072524
)
25082525

25092526
if enable_web_access:
@@ -2868,6 +2885,7 @@ def run(
28682885
sync: bool = True,
28692886
create_request_timeout: Optional[float] = None,
28702887
disable_retries: bool = False,
2888+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
28712889
) -> None:
28722890
"""Run this configured CustomJob.
28732891
@@ -2916,6 +2934,8 @@ def run(
29162934
Indicates if the job should retry for internal errors after the
29172935
job starts running. If True, overrides
29182936
`restart_job_on_worker_restart` to False.
2937+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
2938+
Optional. Indicates the job scheduling strategy.
29192939
"""
29202940
network = network or initializer.global_config.network
29212941
service_account = service_account or initializer.global_config.service_account
@@ -2930,6 +2950,7 @@ def run(
29302950
sync=sync,
29312951
create_request_timeout=create_request_timeout,
29322952
disable_retries=disable_retries,
2953+
scheduling_strategy=scheduling_strategy,
29332954
)
29342955

29352956
@base.optional_sync()
@@ -2944,6 +2965,7 @@ def _run(
29442965
sync: bool = True,
29452966
create_request_timeout: Optional[float] = None,
29462967
disable_retries: bool = False,
2968+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
29472969
) -> None:
29482970
"""Helper method to ensure network synchronization and to run the configured CustomJob.
29492971
@@ -2990,20 +3012,28 @@ def _run(
29903012
Indicates if the job should retry for internal errors after the
29913013
job starts running. If True, overrides
29923014
`restart_job_on_worker_restart` to False.
3015+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
3016+
Optional. Indicates the job scheduling strategy.
29933017
"""
29943018
if service_account:
29953019
self._gca_resource.trial_job_spec.service_account = service_account
29963020

29973021
if network:
29983022
self._gca_resource.trial_job_spec.network = network
29993023

3000-
if timeout or restart_job_on_worker_restart or disable_retries:
3024+
if (
3025+
timeout
3026+
or restart_job_on_worker_restart
3027+
or disable_retries
3028+
or scheduling_strategy
3029+
):
30013030
duration = duration_pb2.Duration(seconds=timeout) if timeout else None
30023031
self._gca_resource.trial_job_spec.scheduling = (
30033032
gca_custom_job_compat.Scheduling(
30043033
timeout=duration,
30053034
restart_job_on_worker_restart=restart_job_on_worker_restart,
30063035
disable_retries=disable_retries,
3036+
strategy=scheduling_strategy,
30073037
)
30083038
)
30093039

google/cloud/aiplatform/training_jobs.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from google.cloud.aiplatform.compat.types import (
4545
training_pipeline as gca_training_pipeline,
4646
study as gca_study_compat,
47+
custom_job as gca_custom_job_compat,
4748
)
4849

4950
from google.cloud.aiplatform.utils import _timestamped_gcs_dir
@@ -1525,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir(
15251526
tensorboard: Optional[str] = None,
15261527
disable_retries: bool = False,
15271528
persistent_resource_id: Optional[str] = None,
1529+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
15281530
) -> Tuple[Dict, str]:
15291531
"""Prepares training task inputs and output directory for custom job.
15301532
@@ -1582,6 +1584,8 @@ def _prepare_training_task_inputs_and_output_dir(
15821584
on-demand short-live machines. The network, CMEK, and node pool
15831585
configs on the job should be consistent with those on the
15841586
PersistentResource, otherwise, the job will be rejected.
1587+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
1588+
Optional. Indicates the job scheduling strategy.
15851589
15861590
Returns:
15871591
Training task inputs and Output directory for custom job.
@@ -1612,12 +1616,18 @@ def _prepare_training_task_inputs_and_output_dir(
16121616
if persistent_resource_id:
16131617
training_task_inputs["persistent_resource_id"] = persistent_resource_id
16141618

1615-
if timeout or restart_job_on_worker_restart or disable_retries:
1619+
if (
1620+
timeout
1621+
or restart_job_on_worker_restart
1622+
or disable_retries
1623+
or scheduling_strategy
1624+
):
16161625
timeout = f"{timeout}s" if timeout else None
16171626
scheduling = {
16181627
"timeout": timeout,
16191628
"restart_job_on_worker_restart": restart_job_on_worker_restart,
16201629
"disable_retries": disable_retries,
1630+
"strategy": scheduling_strategy,
16211631
}
16221632
training_task_inputs["scheduling"] = scheduling
16231633

@@ -3005,6 +3015,7 @@ def run(
30053015
disable_retries: bool = False,
30063016
persistent_resource_id: Optional[str] = None,
30073017
tpu_topology: Optional[str] = None,
3018+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
30083019
) -> Optional[models.Model]:
30093020
"""Runs the custom training job.
30103021
@@ -3360,6 +3371,8 @@ def run(
33603371
details on the TPU topology, refer to
33613372
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
33623373
be a supported value for the TPU machine type.
3374+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
3375+
Optional. Indicates the job scheduling strategy.
33633376
33643377
Returns:
33653378
The trained Vertex AI model resource or None if the training
@@ -3424,6 +3437,7 @@ def run(
34243437
create_request_timeout=create_request_timeout,
34253438
disable_retries=disable_retries,
34263439
persistent_resource_id=persistent_resource_id,
3440+
scheduling_strategy=scheduling_strategy,
34273441
)
34283442

34293443
def submit(
@@ -3477,6 +3491,7 @@ def submit(
34773491
disable_retries: bool = False,
34783492
persistent_resource_id: Optional[str] = None,
34793493
tpu_topology: Optional[str] = None,
3494+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
34803495
) -> Optional[models.Model]:
34813496
"""Submits the custom training job without blocking until completion.
34823497
@@ -3777,6 +3792,8 @@ def submit(
37773792
details on the TPU topology, refer to
37783793
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
37793794
be a supported value for the TPU machine type.
3795+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
3796+
Optional. Indicates the job scheduling strategy.
37803797
37813798
Returns:
37823799
model: The trained Vertex AI Model resource or None if training did not
@@ -3841,6 +3858,7 @@ def submit(
38413858
block=False,
38423859
disable_retries=disable_retries,
38433860
persistent_resource_id=persistent_resource_id,
3861+
scheduling_strategy=scheduling_strategy,
38443862
)
38453863

38463864
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -3888,6 +3906,7 @@ def _run(
38883906
block: Optional[bool] = True,
38893907
disable_retries: bool = False,
38903908
persistent_resource_id: Optional[str] = None,
3909+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
38913910
) -> Optional[models.Model]:
38923911
"""Packages local script and launches training_job.
38933912
@@ -4084,6 +4103,8 @@ def _run(
40844103
on-demand short-live machines. The network, CMEK, and node pool
40854104
configs on the job should be consistent with those on the
40864105
PersistentResource, otherwise, the job will be rejected.
4106+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
4107+
Optional. Indicates the job scheduling strategy.
40874108
40884109
Returns:
40894110
model: The trained Vertex AI Model resource or None if training did not
@@ -4138,6 +4159,7 @@ def _run(
41384159
tensorboard=tensorboard,
41394160
disable_retries=disable_retries,
41404161
persistent_resource_id=persistent_resource_id,
4162+
scheduling_strategy=scheduling_strategy,
41414163
)
41424164

41434165
model = self._run_job(
@@ -4462,6 +4484,7 @@ def run(
44624484
disable_retries: bool = False,
44634485
persistent_resource_id: Optional[str] = None,
44644486
tpu_topology: Optional[str] = None,
4487+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
44654488
) -> Optional[models.Model]:
44664489
"""Runs the custom training job.
44674490
@@ -4755,6 +4778,8 @@ def run(
47554778
details on the TPU topology, refer to
47564779
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology
47574780
must be a supported value for the TPU machine type.
4781+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
4782+
Optional. Indicates the job scheduling strategy.
47584783
47594784
Returns:
47604785
model: The trained Vertex AI Model resource or None if training did not
@@ -4818,6 +4843,7 @@ def run(
48184843
create_request_timeout=create_request_timeout,
48194844
disable_retries=disable_retries,
48204845
persistent_resource_id=persistent_resource_id,
4846+
scheduling_strategy=scheduling_strategy,
48214847
)
48224848

48234849
def submit(
@@ -4871,6 +4897,7 @@ def submit(
48714897
disable_retries: bool = False,
48724898
persistent_resource_id: Optional[str] = None,
48734899
tpu_topology: Optional[str] = None,
4900+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
48744901
) -> Optional[models.Model]:
48754902
"""Submits the custom training job without blocking until completion.
48764903
@@ -5164,6 +5191,8 @@ def submit(
51645191
details on the TPU topology, refer to
51655192
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology
51665193
must be a supported value for the TPU machine type.
5194+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
5195+
Optional. Indicates the job scheduling strategy.
51675196
51685197
Returns:
51695198
model: The trained Vertex AI Model resource or None if training did not
@@ -5227,6 +5256,7 @@ def submit(
52275256
block=False,
52285257
disable_retries=disable_retries,
52295258
persistent_resource_id=persistent_resource_id,
5259+
scheduling_strategy=scheduling_strategy,
52305260
)
52315261

52325262
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -5273,6 +5303,7 @@ def _run(
52735303
block: Optional[bool] = True,
52745304
disable_retries: bool = False,
52755305
persistent_resource_id: Optional[str] = None,
5306+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
52765307
) -> Optional[models.Model]:
52775308
"""Packages local script and launches training_job.
52785309
Args:
@@ -5465,6 +5496,8 @@ def _run(
54655496
on-demand short-live machines. The network, CMEK, and node pool
54665497
configs on the job should be consistent with those on the
54675498
PersistentResource, otherwise, the job will be rejected.
5499+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
5500+
Optional. Indicates the job scheduling strategy.
54685501
54695502
Returns:
54705503
model: The trained Vertex AI Model resource or None if training did not
@@ -5513,6 +5546,7 @@ def _run(
55135546
tensorboard=tensorboard,
55145547
disable_retries=disable_retries,
55155548
persistent_resource_id=persistent_resource_id,
5549+
scheduling_strategy=scheduling_strategy,
55165550
)
55175551

55185552
model = self._run_job(
@@ -7537,6 +7571,7 @@ def run(
75377571
disable_retries: bool = False,
75387572
persistent_resource_id: Optional[str] = None,
75397573
tpu_topology: Optional[str] = None,
7574+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
75407575
) -> Optional[models.Model]:
75417576
"""Runs the custom training job.
75427577
@@ -7831,6 +7866,8 @@ def run(
78317866
details on the TPU topology, refer to
78327867
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology
78337868
must be a supported value for the TPU machine type.
7869+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
7870+
Optional. Indicates the job scheduling strategy.
78347871
78357872
Returns:
78367873
model: The trained Vertex AI Model resource or None if training did not
@@ -7889,6 +7926,7 @@ def run(
78897926
create_request_timeout=create_request_timeout,
78907927
disable_retries=disable_retries,
78917928
persistent_resource_id=persistent_resource_id,
7929+
scheduling_strategy=scheduling_strategy,
78927930
)
78937931

78947932
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -7934,6 +7972,7 @@ def _run(
79347972
create_request_timeout: Optional[float] = None,
79357973
disable_retries: bool = False,
79367974
persistent_resource_id: Optional[str] = None,
7975+
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
79377976
) -> Optional[models.Model]:
79387977
"""Packages local script and launches training_job.
79397978
@@ -8111,6 +8150,8 @@ def _run(
81118150
on-demand short-live machines. The network, CMEK, and node pool
81128151
configs on the job should be consistent with those on the
81138152
PersistentResource, otherwise, the job will be rejected.
8153+
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
8154+
Optional. Indicates the job scheduling strategy.
81148155
81158156
Returns:
81168157
model: The trained Vertex AI Model resource or None if training did not
@@ -8159,6 +8200,7 @@ def _run(
81598200
tensorboard=tensorboard,
81608201
disable_retries=disable_retries,
81618202
persistent_resource_id=persistent_resource_id,
8203+
scheduling_strategy=scheduling_strategy,
81628204
)
81638205

81648206
model = self._run_job(

0 commit comments

Comments
 (0)