Skip to content

Commit d0968ea

Browse files
authored
feat: add support for failure_policy in PipelineJob (#1452)
* implement failure_policy * add tests * raise ValueError if failure_policy is invalid
1 parent d778dee commit d0968ea

File tree

6 files changed

+192
-3
lines changed

6 files changed

+192
-3
lines changed

google/cloud/aiplatform/compat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
types.model_evaluation_slice = types.model_evaluation_slice_v1beta1
9797
types.model_service = types.model_service_v1beta1
9898
types.operation = types.operation_v1beta1
99+
types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
99100
types.pipeline_job = types.pipeline_job_v1beta1
100101
types.pipeline_service = types.pipeline_service_v1beta1
101102
types.pipeline_state = types.pipeline_state_v1beta1
@@ -180,6 +181,7 @@
180181
types.model_evaluation_slice = types.model_evaluation_slice_v1
181182
types.model_service = types.model_service_v1
182183
types.operation = types.operation_v1
184+
types.pipeline_failure_policy = types.pipeline_failure_policy_v1
183185
types.pipeline_job = types.pipeline_job_v1
184186
types.pipeline_service = types.pipeline_service_v1
185187
types.pipeline_state = types.pipeline_state_v1

google/cloud/aiplatform/compat/types/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
model_evaluation_slice as model_evaluation_slice_v1beta1,
6464
model_service as model_service_v1beta1,
6565
operation as operation_v1beta1,
66+
pipeline_failure_policy as pipeline_failure_policy_v1beta1,
6667
pipeline_job as pipeline_job_v1beta1,
6768
pipeline_service as pipeline_service_v1beta1,
6869
pipeline_state as pipeline_state_v1beta1,
@@ -126,6 +127,7 @@
126127
model_evaluation_slice as model_evaluation_slice_v1,
127128
model_service as model_service_v1,
128129
operation as operation_v1,
130+
pipeline_failure_policy as pipeline_failure_policy_v1,
129131
pipeline_job as pipeline_job_v1,
130132
pipeline_service as pipeline_service_v1,
131133
pipeline_state as pipeline_state_v1,
@@ -191,6 +193,7 @@
191193
model_evaluation_slice_v1,
192194
model_service_v1,
193195
operation_v1,
196+
pipeline_failure_policy_v1beta1,
194197
pipeline_job_v1,
195198
pipeline_service_v1,
196199
pipeline_state_v1,
@@ -253,6 +256,7 @@
253256
model_evaluation_slice_v1beta1,
254257
model_service_v1beta1,
255258
operation_v1beta1,
259+
pipeline_failure_policy_v1beta1,
256260
pipeline_job_v1beta1,
257261
pipeline_service_v1beta1,
258262
pipeline_state_v1beta1,

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
credentials: Optional[auth_credentials.Credentials] = None,
120120
project: Optional[str] = None,
121121
location: Optional[str] = None,
122+
failure_policy: Optional[str] = None,
122123
):
123124
"""Retrieves a PipelineJob resource and instantiates its
124125
representation.
@@ -173,6 +174,15 @@ def __init__(
173174
location (str):
174175
Optional. Location to create PipelineJob. If not set,
175176
location set in aiplatform.init will be used.
177+
failure_policy (str):
178+
Optional. The failure policy - "slow" or "fast".
179+
Currently, the default of a pipeline is that the pipeline will continue to
180+
run until no more tasks can be executed, also known as
181+
PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow").
182+
However, if a pipeline is set to
183+
PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"),
184+
it will stop scheduling any new tasks when a task has failed. Any
185+
scheduled tasks will continue to completion.
176186
177187
Raises:
178188
ValueError: If job_id or labels have incorrect format.
@@ -219,6 +229,7 @@ def __init__(
219229
)
220230
builder.update_pipeline_root(pipeline_root)
221231
builder.update_runtime_parameters(parameter_values)
232+
builder.update_failure_policy(failure_policy)
222233
runtime_config_dict = builder.build()
223234

224235
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb

google/cloud/aiplatform/utils/pipeline_utils.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import json
1919
from typing import Any, Dict, Mapping, Optional, Union
20+
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
2021
import packaging.version
2122

2223

@@ -32,6 +33,7 @@ def __init__(
3233
schema_version: str,
3334
parameter_types: Mapping[str, str],
3435
parameter_values: Optional[Dict[str, Any]] = None,
36+
failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
3537
):
3638
"""Creates a PipelineRuntimeConfigBuilder object.
3739
@@ -44,11 +46,20 @@ def __init__(
4446
Required. The mapping from pipeline parameter name to its type.
4547
parameter_values (Dict[str, Any]):
4648
Optional. The mapping from runtime parameter name to its value.
49+
failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
50+
Optional. Represents the failure policy of a pipeline. Currently, the
51+
default of a pipeline is that the pipeline will continue to
52+
run until no more tasks can be executed, also known as
53+
PIPELINE_FAILURE_POLICY_FAIL_SLOW. However, if a pipeline is
54+
set to PIPELINE_FAILURE_POLICY_FAIL_FAST, it will stop
55+
scheduling any new tasks when a task has failed. Any
56+
scheduled tasks will continue to completion.
4757
"""
4858
self._pipeline_root = pipeline_root
4959
self._schema_version = schema_version
5060
self._parameter_types = parameter_types
5161
self._parameter_values = copy.deepcopy(parameter_values or {})
62+
self._failure_policy = failure_policy
5263

5364
@classmethod
5465
def from_job_spec_json(
@@ -80,7 +91,14 @@ def from_job_spec_json(
8091

8192
pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
8293
parameter_values = _parse_runtime_parameters(runtime_config_spec)
83-
return cls(pipeline_root, schema_version, parameter_types, parameter_values)
94+
failure_policy = runtime_config_spec.get("failurePolicy")
95+
return cls(
96+
pipeline_root,
97+
schema_version,
98+
parameter_types,
99+
parameter_values,
100+
failure_policy,
101+
)
84102

85103
def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
86104
"""Updates pipeline_root value.
@@ -111,6 +129,24 @@ def update_runtime_parameters(
111129
parameters[k] = json.dumps(v)
112130
self._parameter_values.update(parameters)
113131

132+
def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
133+
"""Merges runtime failure policy.
134+
135+
Args:
136+
failure_policy (str):
137+
Optional. The failure policy - "slow" or "fast".
138+
139+
Raises:
140+
ValueError: if failure_policy is not valid.
141+
"""
142+
if failure_policy:
143+
if failure_policy in _FAILURE_POLICY_TO_ENUM_VALUE:
144+
self._failure_policy = _FAILURE_POLICY_TO_ENUM_VALUE[failure_policy]
145+
else:
146+
raise ValueError(
147+
f'failure_policy should be either "slow" or "fast", but got: "{failure_policy}".'
148+
)
149+
114150
def build(self) -> Dict[str, Any]:
115151
"""Build a RuntimeConfig proto.
116152
@@ -128,7 +164,8 @@ def build(self) -> Dict[str, Any]:
128164
parameter_values_key = "parameterValues"
129165
else:
130166
parameter_values_key = "parameters"
131-
return {
167+
168+
runtime_config = {
132169
"gcsOutputDirectory": self._pipeline_root,
133170
parameter_values_key: {
134171
k: self._get_vertex_value(k, v)
@@ -137,6 +174,11 @@ def build(self) -> Dict[str, Any]:
137174
},
138175
}
139176

177+
if self._failure_policy:
178+
runtime_config["failurePolicy"] = self._failure_policy
179+
180+
return runtime_config
181+
140182
def _get_vertex_value(
141183
self, name: str, value: Union[int, float, str, bool, list, dict]
142184
) -> Union[int, float, str, bool, list, dict]:
@@ -205,3 +247,10 @@ def _parse_runtime_parameters(
205247
else:
206248
raise TypeError("Got unknown type of value: {}".format(value))
207249
return result
250+
251+
252+
_FAILURE_POLICY_TO_ENUM_VALUE = {
253+
"slow": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
254+
"fast": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
255+
None: pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_UNSPECIFIED,
256+
}

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.cloud.aiplatform import base
3131
from google.cloud.aiplatform import initializer
3232
from google.cloud.aiplatform import pipeline_jobs
33+
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
3334
from google.cloud import storage
3435
from google.protobuf import json_format
3536

@@ -621,6 +622,99 @@ def test_run_call_pipeline_service_create_with_timeout_not_explicitly_set(
621622
timeout=None,
622623
)
623624

625+
@pytest.mark.parametrize(
626+
"job_spec",
627+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
628+
)
629+
@pytest.mark.parametrize(
630+
"failure_policy",
631+
[
632+
(
633+
"slow",
634+
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
635+
),
636+
(
637+
"fast",
638+
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
639+
),
640+
],
641+
)
642+
@pytest.mark.parametrize("sync", [True, False])
643+
def test_run_call_pipeline_service_create_with_failure_policy(
644+
self,
645+
mock_pipeline_service_create,
646+
mock_pipeline_service_get,
647+
job_spec,
648+
mock_load_yaml_and_json,
649+
failure_policy,
650+
sync,
651+
):
652+
aiplatform.init(
653+
project=_TEST_PROJECT,
654+
staging_bucket=_TEST_GCS_BUCKET_NAME,
655+
location=_TEST_LOCATION,
656+
credentials=_TEST_CREDENTIALS,
657+
)
658+
659+
job = pipeline_jobs.PipelineJob(
660+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
661+
template_path=_TEST_TEMPLATE_PATH,
662+
job_id=_TEST_PIPELINE_JOB_ID,
663+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
664+
enable_caching=True,
665+
failure_policy=failure_policy[0],
666+
)
667+
668+
job.run(
669+
service_account=_TEST_SERVICE_ACCOUNT,
670+
network=_TEST_NETWORK,
671+
sync=sync,
672+
create_request_timeout=None,
673+
)
674+
675+
if not sync:
676+
job.wait()
677+
678+
expected_runtime_config_dict = {
679+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
680+
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
681+
"failurePolicy": failure_policy[1],
682+
}
683+
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
684+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
685+
686+
job_spec = yaml.safe_load(job_spec)
687+
pipeline_spec = job_spec.get("pipelineSpec") or job_spec
688+
689+
# Construct expected request
690+
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
691+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
692+
pipeline_spec={
693+
"components": {},
694+
"pipelineInfo": pipeline_spec["pipelineInfo"],
695+
"root": pipeline_spec["root"],
696+
"schemaVersion": "2.1.0",
697+
},
698+
runtime_config=runtime_config,
699+
service_account=_TEST_SERVICE_ACCOUNT,
700+
network=_TEST_NETWORK,
701+
)
702+
703+
mock_pipeline_service_create.assert_called_once_with(
704+
parent=_TEST_PARENT,
705+
pipeline_job=expected_gapic_pipeline_job,
706+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
707+
timeout=None,
708+
)
709+
710+
mock_pipeline_service_get.assert_called_with(
711+
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
712+
)
713+
714+
assert job._gca_resource == make_pipeline_job(
715+
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
716+
)
717+
624718
@pytest.mark.parametrize(
625719
"job_spec",
626720
[

tests/unit/aiplatform/test_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.api_core import client_options, gapic_v1
2929
from google.cloud import aiplatform
3030
from google.cloud.aiplatform import compat, utils
31+
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
3132
from google.cloud.aiplatform.utils import pipeline_utils, tensorboard_utils, yaml_utils
3233
from google.cloud.aiplatform_v1.services.model_service import (
3334
client as model_service_client_v1,
@@ -454,7 +455,22 @@ def test_pipeline_utils_runtime_config_builder_with_no_op_updates(self):
454455
expected_runtime_config = self.SAMPLE_JOB_SPEC["runtimeConfig"]
455456
assert expected_runtime_config == actual_runtime_config
456457

457-
def test_pipeline_utils_runtime_config_builder_with_merge_updates(self):
458+
@pytest.mark.parametrize(
459+
"failure_policy",
460+
[
461+
(
462+
"slow",
463+
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
464+
),
465+
(
466+
"fast",
467+
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
468+
),
469+
],
470+
)
471+
def test_pipeline_utils_runtime_config_builder_with_merge_updates(
472+
self, failure_policy
473+
):
458474
my_builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
459475
self.SAMPLE_JOB_SPEC
460476
)
@@ -468,6 +484,7 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(self):
468484
"bool_param": True,
469485
}
470486
)
487+
my_builder.update_failure_policy(failure_policy[0])
471488
actual_runtime_config = my_builder.build()
472489

473490
expected_runtime_config = {
@@ -481,9 +498,21 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(self):
481498
"list_param": {"stringValue": "[1, 2, 3]"},
482499
"bool_param": {"stringValue": "true"},
483500
},
501+
"failurePolicy": failure_policy[1],
484502
}
485503
assert expected_runtime_config == actual_runtime_config
486504

505+
def test_pipeline_utils_runtime_config_builder_invalid_failure_policy(self):
506+
my_builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
507+
self.SAMPLE_JOB_SPEC
508+
)
509+
with pytest.raises(ValueError) as e:
510+
my_builder.update_failure_policy("slo")
511+
512+
assert e.match(
513+
regexp=r'failure_policy should be either "slow" or "fast", but got: "slo".'
514+
)
515+
487516
def test_pipeline_utils_runtime_config_builder_parameter_not_found(self):
488517
my_builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
489518
self.SAMPLE_JOB_SPEC

0 commit comments

Comments (0)