Skip to content

Commit c5a3535

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Add preflight validations to PipelineJob submit and run methods.
PiperOrigin-RevId: 651504412
1 parent 42af742 commit c5a3535

File tree

2 files changed

+154
-6
lines changed

2 files changed

+154
-6
lines changed

google/cloud/aiplatform/pipeline_jobs.py

+52-6
Original file line number | Diff line number | Diff line change
@@ -305,6 +305,7 @@ def run(
305305
reserved_ip_ranges: Optional[List[str]] = None,
306306
sync: Optional[bool] = True,
307307
create_request_timeout: Optional[float] = None,
308+
enable_preflight_validations: Optional[bool] = False,
308309
) -> None:
309310
"""Run this configured PipelineJob and monitor the job until completion.
310311
@@ -325,6 +326,8 @@ def run(
325326
Optional. Whether to execute this method synchronously. If False, this method will unblock and it will be executed in a concurrent Future.
326327
create_request_timeout (float):
327328
Optional. The timeout for the create request in seconds.
329+
enable_preflight_validations (bool):
330+
Optional. Whether to enable preflight validations for the PipelineJob.
328331
"""
329332
network = network or initializer.global_config.network
330333

@@ -334,6 +337,7 @@ def run(
334337
reserved_ip_ranges=reserved_ip_ranges,
335338
sync=sync,
336339
create_request_timeout=create_request_timeout,
340+
enable_preflight_validations=enable_preflight_validations,
337341
)
338342

339343
@base.optional_sync()
@@ -344,6 +348,7 @@ def _run(
344348
reserved_ip_ranges: Optional[List[str]] = None,
345349
sync: Optional[bool] = True,
346350
create_request_timeout: Optional[float] = None,
351+
enable_preflight_validations: Optional[bool] = False,
347352
) -> None:
348353
"""Helper method to ensure network synchronization and to run
349354
the configured PipelineJob and monitor the job until completion.
@@ -363,12 +368,15 @@ def _run(
363368
Optional. Whether to execute this method synchronously. If False, this method will unblock and it will be executed in a concurrent Future.
364369
create_request_timeout (float):
365370
Optional. The timeout for the create request in seconds.
371+
enable_preflight_validations (bool):
372+
Optional. Whether to enable preflight validations for the PipelineJob.
366373
"""
367374
self.submit(
368375
service_account=service_account,
369376
network=network,
370377
reserved_ip_ranges=reserved_ip_ranges,
371378
create_request_timeout=create_request_timeout,
379+
enable_preflight_validations=enable_preflight_validations,
372380
)
373381

374382
self._block_until_complete()
@@ -402,6 +410,7 @@ def submit(
402410
create_request_timeout: Optional[float] = None,
403411
*,
404412
experiment: Optional[Union[str, experiment_resources.Experiment]] = None,
413+
enable_preflight_validations: Optional[bool] = False,
405414
) -> None:
406415
"""Run this configured PipelineJob.
407416
@@ -432,6 +441,8 @@ def submit(
432441
433442
Pipeline parameters will be associated as parameters to the
434443
current Experiment Run.
444+
enable_preflight_validations (bool):
445+
Optional. Whether to enable preflight validations for the PipelineJob.
435446
"""
436447
network = network or initializer.global_config.network
437448
service_account = service_account or initializer.global_config.service_account
@@ -471,12 +482,47 @@ def submit(
471482

472483
_LOGGER.log_create_with_lro(self.__class__)
473484

474-
self._gca_resource = self.api_client.create_pipeline_job(
475-
parent=self._parent,
476-
pipeline_job=self._gca_resource,
477-
pipeline_job_id=self.job_id,
478-
timeout=create_request_timeout,
479-
)
485+
if enable_preflight_validations:
486+
self._gca_resource.preflight_validations = True
487+
488+
def extract_error_messages(error_string: str) -> List[str]:
    """Extract preflight validation messages from a raw error string.

    Looks for every place where the text ``CreatePipelineJobApiErrorDetail"``
    is immediately followed by a newline and, on the next line, by
    ``message=<text>, cause=null``. Each captured ``<text>`` is returned as a
    numbered entry.

    Args:
        error_string: The string containing the error data (the caller
            passes ``str(e)`` of the exception raised by the create request).

    Returns:
        A list of formatted error messages ("1. ...", "2. ..."), in order of
        appearance. The list is empty when the input is empty/None or when
        no structured error detail is found.
    """
    # Nothing to scan: behave like "no matches" instead of letting
    # re.findall raise on a None input.
    if not error_string:
        return []

    # Pattern is kept verbatim: "." does not cross newlines, so the message
    # must sit on the line directly after the error-detail type marker.
    message_pattern = (
        r"CreatePipelineJobApiErrorDetail\"\n.*message=(.*),\ cause=null"
    )

    return [
        f"{index}. {message}"
        for index, message in enumerate(
            re.findall(message_pattern, error_string), start=1
        )
    ]
510+
511+
try:
512+
self._gca_resource = self.api_client.create_pipeline_job(
513+
parent=self._parent,
514+
pipeline_job=self._gca_resource,
515+
pipeline_job_id=self.job_id,
516+
timeout=create_request_timeout,
517+
)
518+
except Exception as e:
519+
preflight_validations_error_messages = extract_error_messages(str(e))
520+
if preflight_validations_error_messages:
521+
raise Exception(
522+
"PipelineJob Preflight validations failed with the following errors:\n"
523+
+ "\n".join(preflight_validations_error_messages)
524+
) from e
525+
raise
480526

481527
_LOGGER.log_create_complete_with_getter(
482528
self.__class__, self._gca_resource, "pipeline_job"

tests/unit/aiplatform/test_pipeline_jobs.py

+102
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,22 @@
242242
)
243243

244244

245+
@pytest.fixture
def mock_pipeline_service_create_with_preflight_validations():
    """Patch ``PipelineServiceClient.create_pipeline_job`` for preflight tests.

    Yields the patched mock. Calling it returns a ``PipelineJob`` in the
    ``PIPELINE_STATE_SUCCEEDED`` state populated with the module's test
    name, create time, service account, network and reserved IP ranges.
    """
    # Canned response handed back by the patched create call.
    succeeded_job = gca_pipeline_job.PipelineJob(
        name=_TEST_PIPELINE_JOB_NAME,
        state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
        create_time=_TEST_PIPELINE_CREATE_TIME,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        reserved_ip_ranges=_TEST_RESERVED_IP_RANGES,
    )

    patcher = mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "create_pipeline_job"
    )
    with patcher as create_mock:
        create_mock.return_value = succeeded_job
        yield create_mock
259+
260+
245261
@pytest.fixture
246262
def mock_pipeline_service_create():
247263
with mock.patch.object(
@@ -2267,3 +2283,89 @@ def test_submit_v1beta1_pipeline_job_returns_response(
22672283
job.submit()
22682284

22692285
assert mock_pipeline_v1beta1_service_create.call_count == 1
2286+
2287+
@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_run_with_preflight_validations(
    self,
    mock_pipeline_service_create_with_preflight_validations,
    mock_pipeline_service_get,
    mock_pipeline_bucket_exists,
    job_spec,
    mock_load_yaml_and_json,
    sync,
):
    """``run(enable_preflight_validations=True)`` sets the flag on the request.

    Runs a PipelineJob (sync and async, for each job spec flavour) with
    preflight validations enabled and checks that the create call receives
    a PipelineJob whose ``preflight_validations`` field is True.
    """
    import yaml

    create_mock = mock_pipeline_service_create_with_preflight_validations

    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    )

    job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
        enable_caching=True,
    )

    job.run(
        reserved_ip_ranges=_TEST_RESERVED_IP_RANGES,
        sync=sync,
        create_request_timeout=None,
        enable_preflight_validations=True,
    )

    # In the async case, block until the background run has finished
    # before inspecting the mocks.
    if not sync:
        job.wait()

    # Expected runtime config, converted into the underlying protobuf.
    expected_runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(
        {
            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
            "inputArtifacts": {"vertex_model": {"artifactId": "456"}},
        },
        expected_runtime_config,
    )

    # The parametrized spec is either a full job (with a nested
    # "pipelineSpec") or a bare pipeline spec.
    job_spec = yaml.safe_load(job_spec)
    spec = job_spec.get("pipelineSpec") or job_spec

    # Construct expected request
    expected_pipeline_spec = {
        "components": {},
        "pipelineInfo": spec["pipelineInfo"],
        "root": spec["root"],
        "schemaVersion": "2.1.0",
    }
    expected_request_job = gca_pipeline_job.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        pipeline_spec=expected_pipeline_spec,
        runtime_config=expected_runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        reserved_ip_ranges=_TEST_RESERVED_IP_RANGES,
        preflight_validations=True,
    )

    create_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_request_job,
        pipeline_job_id=_TEST_PIPELINE_JOB_ID,
        timeout=None,
    )

    mock_pipeline_service_get.assert_called_with(
        name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
    )

    expected_final_job = make_pipeline_job(
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    )
    assert job._gca_resource == expected_final_job

0 commit comments

Comments
 (0)