Skip to content

Commit 89078e0

Browse files
authored
feat: Added scheduling to CustomTrainingJob, CustomPythonPackageTrainingJob, CustomContainerTrainingJob (#970)
* Added scheduling to CustomTrainingJob * Added unit tests Fixed tests Fixed test fix: Broken test * Added integration test * Removed comment * Updated e2e tabular test * Fixed lint issue * Simplified tests * Added more assertions
1 parent c10923b commit 89078e0

File tree

3 files changed

+317
-1
lines changed

3 files changed

+317
-1
lines changed

google/cloud/aiplatform/training_jobs.py

+83
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,8 @@ def _prepare_training_task_inputs_and_output_dir(
13791379
base_output_dir: Optional[str] = None,
13801380
service_account: Optional[str] = None,
13811381
network: Optional[str] = None,
1382+
timeout: Optional[int] = None,
1383+
restart_job_on_worker_restart: bool = False,
13821384
enable_web_access: bool = False,
13831385
tensorboard: Optional[str] = None,
13841386
) -> Tuple[Dict, str]:
@@ -1398,6 +1400,13 @@ def _prepare_training_task_inputs_and_output_dir(
13981400
should be peered. For example, projects/12345/global/networks/myVPC.
13991401
Private services access must already be configured for the network.
14001402
If left unspecified, the job is not peered with any network.
1403+
timeout (int):
1404+
The maximum job running time in seconds. The default is 7 days.
1405+
restart_job_on_worker_restart (bool):
1406+
Restarts the entire CustomJob if a worker
1407+
gets restarted. This feature can be used by
1408+
distributed training jobs that are not resilient
1409+
to workers leaving and joining a job.
14011410
enable_web_access (bool):
14021411
Whether you want Vertex AI to enable interactive shell access
14031412
to training containers.
@@ -1442,6 +1451,14 @@ def _prepare_training_task_inputs_and_output_dir(
14421451
if enable_web_access:
14431452
training_task_inputs["enable_web_access"] = enable_web_access
14441453

1454+
if timeout or restart_job_on_worker_restart:
1455+
timeout = f"{timeout}s" if timeout else None
1456+
scheduling = {
1457+
"timeout": timeout,
1458+
"restart_job_on_worker_restart": restart_job_on_worker_restart,
1459+
}
1460+
training_task_inputs["scheduling"] = scheduling
1461+
14451462
return training_task_inputs, base_output_dir
14461463

14471464
@property
@@ -1794,6 +1811,8 @@ def run(
17941811
test_filter_split: Optional[str] = None,
17951812
predefined_split_column_name: Optional[str] = None,
17961813
timestamp_split_column_name: Optional[str] = None,
1814+
timeout: Optional[int] = None,
1815+
restart_job_on_worker_restart: bool = False,
17971816
enable_web_access: bool = False,
17981817
tensorboard: Optional[str] = None,
17991818
sync=True,
@@ -2014,6 +2033,13 @@ def run(
20142033
that piece is ignored by the pipeline.
20152034
20162035
Supported only for tabular and time series Datasets.
2036+
timeout (int):
2037+
The maximum job running time in seconds. The default is 7 days.
2038+
restart_job_on_worker_restart (bool):
2039+
Restarts the entire CustomJob if a worker
2040+
gets restarted. This feature can be used by
2041+
distributed training jobs that are not resilient
2042+
to workers leaving and joining a job.
20172043
enable_web_access (bool):
20182044
Whether you want Vertex AI to enable interactive shell access
20192045
to training containers.
@@ -2080,6 +2106,8 @@ def run(
20802106
test_filter_split=test_filter_split,
20812107
predefined_split_column_name=predefined_split_column_name,
20822108
timestamp_split_column_name=timestamp_split_column_name,
2109+
timeout=timeout,
2110+
restart_job_on_worker_restart=restart_job_on_worker_restart,
20832111
enable_web_access=enable_web_access,
20842112
tensorboard=tensorboard,
20852113
reduction_server_container_uri=reduction_server_container_uri
@@ -2117,6 +2145,8 @@ def _run(
21172145
test_filter_split: Optional[str] = None,
21182146
predefined_split_column_name: Optional[str] = None,
21192147
timestamp_split_column_name: Optional[str] = None,
2148+
timeout: Optional[int] = None,
2149+
restart_job_on_worker_restart: bool = False,
21202150
enable_web_access: bool = False,
21212151
tensorboard: Optional[str] = None,
21222152
reduction_server_container_uri: Optional[str] = None,
@@ -2237,6 +2267,13 @@ def _run(
22372267
that piece is ignored by the pipeline.
22382268
22392269
Supported only for tabular and time series Datasets.
2270+
timeout (int):
2271+
The maximum job running time in seconds. The default is 7 days.
2272+
restart_job_on_worker_restart (bool):
2273+
Restarts the entire CustomJob if a worker
2274+
gets restarted. This feature can be used by
2275+
distributed training jobs that are not resilient
2276+
to workers leaving and joining a job.
22402277
enable_web_access (bool):
22412278
Whether you want Vertex AI to enable interactive shell access
22422279
to training containers.
@@ -2309,6 +2346,8 @@ def _run(
23092346
base_output_dir=base_output_dir,
23102347
service_account=service_account,
23112348
network=network,
2349+
timeout=timeout,
2350+
restart_job_on_worker_restart=restart_job_on_worker_restart,
23122351
enable_web_access=enable_web_access,
23132352
tensorboard=tensorboard,
23142353
)
@@ -2598,6 +2637,8 @@ def run(
25982637
test_filter_split: Optional[str] = None,
25992638
predefined_split_column_name: Optional[str] = None,
26002639
timestamp_split_column_name: Optional[str] = None,
2640+
timeout: Optional[int] = None,
2641+
restart_job_on_worker_restart: bool = False,
26012642
enable_web_access: bool = False,
26022643
tensorboard: Optional[str] = None,
26032644
sync=True,
@@ -2811,6 +2852,13 @@ def run(
28112852
that piece is ignored by the pipeline.
28122853
28132854
Supported only for tabular and time series Datasets.
2855+
timeout (int):
2856+
The maximum job running time in seconds. The default is 7 days.
2857+
restart_job_on_worker_restart (bool):
2858+
Restarts the entire CustomJob if a worker
2859+
gets restarted. This feature can be used by
2860+
distributed training jobs that are not resilient
2861+
to workers leaving and joining a job.
28142862
enable_web_access (bool):
28152863
Whether you want Vertex AI to enable interactive shell access
28162864
to training containers.
@@ -2876,6 +2924,8 @@ def run(
28762924
test_filter_split=test_filter_split,
28772925
predefined_split_column_name=predefined_split_column_name,
28782926
timestamp_split_column_name=timestamp_split_column_name,
2927+
timeout=timeout,
2928+
restart_job_on_worker_restart=restart_job_on_worker_restart,
28792929
enable_web_access=enable_web_access,
28802930
tensorboard=tensorboard,
28812931
reduction_server_container_uri=reduction_server_container_uri
@@ -2912,6 +2962,8 @@ def _run(
29122962
test_filter_split: Optional[str] = None,
29132963
predefined_split_column_name: Optional[str] = None,
29142964
timestamp_split_column_name: Optional[str] = None,
2965+
timeout: Optional[int] = None,
2966+
restart_job_on_worker_restart: bool = False,
29152967
enable_web_access: bool = False,
29162968
tensorboard: Optional[str] = None,
29172969
reduction_server_container_uri: Optional[str] = None,
@@ -2965,6 +3017,13 @@ def _run(
29653017
should be peered. For example, projects/12345/global/networks/myVPC.
29663018
Private services access must already be configured for the network.
29673019
If left unspecified, the job is not peered with any network.
3020+
timeout (int):
3021+
The maximum job running time in seconds. The default is 7 days.
3022+
restart_job_on_worker_restart (bool):
3023+
Restarts the entire CustomJob if a worker
3024+
gets restarted. This feature can be used by
3025+
distributed training jobs that are not resilient
3026+
to workers leaving and joining a job.
29683027
bigquery_destination (str):
29693028
The BigQuery project location where the training data is to
29703029
be written to. In the given project a new dataset is created
@@ -3094,6 +3153,8 @@ def _run(
30943153
base_output_dir=base_output_dir,
30953154
service_account=service_account,
30963155
network=network,
3156+
timeout=timeout,
3157+
restart_job_on_worker_restart=restart_job_on_worker_restart,
30973158
enable_web_access=enable_web_access,
30983159
tensorboard=tensorboard,
30993160
)
@@ -5373,6 +5434,8 @@ def run(
53735434
test_filter_split: Optional[str] = None,
53745435
predefined_split_column_name: Optional[str] = None,
53755436
timestamp_split_column_name: Optional[str] = None,
5437+
timeout: Optional[int] = None,
5438+
restart_job_on_worker_restart: bool = False,
53765439
enable_web_access: bool = False,
53775440
tensorboard: Optional[str] = None,
53785441
sync=True,
@@ -5586,6 +5649,13 @@ def run(
55865649
that piece is ignored by the pipeline.
55875650
55885651
Supported only for tabular and time series Datasets.
5652+
timeout (int):
5653+
The maximum job running time in seconds. The default is 7 days.
5654+
restart_job_on_worker_restart (bool):
5655+
Restarts the entire CustomJob if a worker
5656+
gets restarted. This feature can be used by
5657+
distributed training jobs that are not resilient
5658+
to workers leaving and joining a job.
55895659
enable_web_access (bool):
55905660
Whether you want Vertex AI to enable interactive shell access
55915661
to training containers.
@@ -5646,6 +5716,8 @@ def run(
56465716
predefined_split_column_name=predefined_split_column_name,
56475717
timestamp_split_column_name=timestamp_split_column_name,
56485718
bigquery_destination=bigquery_destination,
5719+
timeout=timeout,
5720+
restart_job_on_worker_restart=restart_job_on_worker_restart,
56495721
enable_web_access=enable_web_access,
56505722
tensorboard=tensorboard,
56515723
reduction_server_container_uri=reduction_server_container_uri
@@ -5682,6 +5754,8 @@ def _run(
56825754
predefined_split_column_name: Optional[str] = None,
56835755
timestamp_split_column_name: Optional[str] = None,
56845756
bigquery_destination: Optional[str] = None,
5757+
timeout: Optional[int] = None,
5758+
restart_job_on_worker_restart: bool = False,
56855759
enable_web_access: bool = False,
56865760
tensorboard: Optional[str] = None,
56875761
reduction_server_container_uri: Optional[str] = None,
@@ -5785,6 +5859,13 @@ def _run(
57855859
that piece is ignored by the pipeline.
57865860
57875861
Supported only for tabular and time series Datasets.
5862+
timeout (int):
5863+
The maximum job running time in seconds. The default is 7 days.
5864+
restart_job_on_worker_restart (bool):
5865+
Restarts the entire CustomJob if a worker
5866+
gets restarted. This feature can be used by
5867+
distributed training jobs that are not resilient
5868+
to workers leaving and joining a job.
57885869
enable_web_access (bool):
57895870
Whether you want Vertex AI to enable interactive shell access
57905871
to training containers.
@@ -5851,6 +5932,8 @@ def _run(
58515932
base_output_dir=base_output_dir,
58525933
service_account=service_account,
58535934
network=network,
5935+
timeout=timeout,
5936+
restart_job_on_worker_restart=restart_job_on_worker_restart,
58545937
enable_web_access=enable_web_access,
58555938
tensorboard=tensorboard,
58565939
)

tests/system/aiplatform/test_e2e_tabular.py

+15
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def test_end_to_end_tabular(self, shared_state):
109109
ds,
110110
replica_count=1,
111111
model_display_name=self._make_display_name("custom-housing-model"),
112+
timeout=1234,
113+
restart_job_on_worker_restart=True,
112114
enable_web_access=True,
113115
sync=False,
114116
)
@@ -147,6 +149,19 @@ def test_end_to_end_tabular(self, shared_state):
147149
# Send online prediction with same instance to both deployed models
148150
# This sample is taken from an observation where median_house_value = 94600
149151
custom_endpoint.wait()
152+
153+
# Check scheduling is correctly set
154+
assert (
155+
custom_job._gca_resource.training_task_inputs["scheduling"]["timeout"]
156+
== "1234s"
157+
)
158+
assert (
159+
custom_job._gca_resource.training_task_inputs["scheduling"][
160+
"restartJobOnWorkerRestart"
161+
]
162+
is True
163+
)
164+
150165
custom_prediction = custom_endpoint.predict([_INSTANCE])
151166

152167
custom_batch_prediction_job.wait()

0 commit comments

Comments
 (0)