Skip to content

Commit 2cf9fe6

Browse files
rui5ijaycee-li
andauthored
feat: add input artifact when creating a pipeline (#1593)
* Add input artifact * Add input artifact * Add unit tests * Add the example on docstring * update the unit tests * fix the key * update unit test * update the docstring to be more accurate * update the docstring * fix unit test * fix unit test * fix lint Co-authored-by: Jaycee Li <[email protected]>
1 parent 653b759 commit 2cf9fe6

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

google/cloud/aiplatform/pipeline_jobs.py

+15
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(
110110
job_id: Optional[str] = None,
111111
pipeline_root: Optional[str] = None,
112112
parameter_values: Optional[Dict[str, Any]] = None,
113+
input_artifacts: Optional[Dict[str, str]] = None,
113114
enable_caching: Optional[bool] = None,
114115
encryption_spec_key_name: Optional[str] = None,
115116
labels: Optional[Dict[str, str]] = None,
@@ -139,6 +140,9 @@ def __init__(
139140
parameter_values (Dict[str, Any]):
140141
Optional. The mapping from runtime parameter names to its values that
141142
control the pipeline run.
143+
input_artifacts (Dict[str, str]):
144+
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
145+
For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
142146
enable_caching (bool):
143147
Optional. Whether to turn on caching for the run.
144148
@@ -235,6 +239,8 @@ def __init__(
235239
)
236240
builder.update_pipeline_root(pipeline_root)
237241
builder.update_runtime_parameters(parameter_values)
242+
builder.update_input_artifacts(input_artifacts)
243+
238244
builder.update_failure_policy(failure_policy)
239245
runtime_config_dict = builder.build()
240246

@@ -662,6 +668,7 @@ def clone(
662668
job_id: Optional[str] = None,
663669
pipeline_root: Optional[str] = None,
664670
parameter_values: Optional[Dict[str, Any]] = None,
671+
input_artifacts: Optional[Dict[str, str]] = None,
665672
enable_caching: Optional[bool] = None,
666673
encryption_spec_key_name: Optional[str] = None,
667674
labels: Optional[Dict[str, str]] = None,
@@ -685,6 +692,9 @@ def clone(
685692
Optional. The mapping from runtime parameter names to its values that
686693
control the pipeline run. Defaults to be the same values as original
687694
PipelineJob.
695+
input_artifacts (Dict[str, str]):
696+
Optional. The mapping from the runtime parameter name for this artifact to its resource id. Defaults to be the same values as original
697+
PipelineJob. For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
688698
enable_caching (bool):
689699
Optional. Whether to turn on caching for the run.
690700
If this is not set, defaults to be the same as original pipeline.
@@ -785,6 +795,7 @@ def clone(
785795
)
786796
builder.update_pipeline_root(pipeline_root)
787797
builder.update_runtime_parameters(parameter_values)
798+
builder.update_input_artifacts(input_artifacts)
788799
runtime_config_dict = builder.build()
789800
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
790801
json_format.ParseDict(runtime_config_dict, runtime_config)
@@ -805,6 +816,7 @@ def from_pipeline_func(
805816
# Parameters for the PipelineJob constructor
806817
pipeline_func: Callable,
807818
parameter_values: Optional[Dict[str, Any]] = None,
819+
input_artifacts: Optional[Dict[str, str]] = None,
808820
output_artifacts_gcs_dir: Optional[str] = None,
809821
enable_caching: Optional[bool] = None,
810822
context_name: Optional[str] = "pipeline",
@@ -827,6 +839,8 @@ def from_pipeline_func(
827839
parameter_values (Dict[str, Any]):
828840
Optional. The mapping from runtime parameter names to its values that
829841
control the pipeline run.
842+
input_artifacts (Dict[str, str]):
843+
Optional. The mapping from the runtime parameter name for this artifact to its resource id. For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
830844
output_artifacts_gcs_dir (str):
831845
Optional. The GCS location of the pipeline outputs.
832846
A GCS bucket for artifacts will be created if not specified.
@@ -907,6 +921,7 @@ def from_pipeline_func(
907921
pipeline_job = PipelineJob(
908922
template_path=pipeline_file,
909923
parameter_values=parameter_values,
924+
input_artifacts=input_artifacts,
910925
pipeline_root=output_artifacts_gcs_dir,
911926
enable_caching=enable_caching,
912927
display_name=display_name,

google/cloud/aiplatform/utils/pipeline_utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
schema_version: str,
3434
parameter_types: Mapping[str, str],
3535
parameter_values: Optional[Dict[str, Any]] = None,
36+
input_artifacts: Optional[Dict[str, str]] = None,
3637
failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
3738
):
3839
"""Creates a PipelineRuntimeConfigBuilder object.
@@ -46,6 +47,8 @@ def __init__(
4647
Required. The mapping from pipeline parameter name to its type.
4748
parameter_values (Dict[str, Any]):
4849
Optional. The mapping from runtime parameter name to its value.
50+
input_artifacts (Dict[str, str]):
51+
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
4952
failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
5053
Optional. Represents the failure policy of a pipeline. Currently, the
5154
default of a pipeline is that the pipeline will continue to
@@ -59,6 +62,7 @@ def __init__(
5962
self._schema_version = schema_version
6063
self._parameter_types = parameter_types
6164
self._parameter_values = copy.deepcopy(parameter_values or {})
65+
self._input_artifacts = copy.deepcopy(input_artifacts or {})
6266
self._failure_policy = failure_policy
6367

6468
@classmethod
@@ -129,6 +133,18 @@ def update_runtime_parameters(
129133
parameters[k] = json.dumps(v)
130134
self._parameter_values.update(parameters)
131135

136+
def update_input_artifacts(
137+
self, input_artifacts: Optional[Mapping[str, str]]
138+
) -> None:
139+
"""Merges runtime input artifacts.
140+
141+
Args:
142+
input_artifacts (Mapping[str, str]):
143+
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
144+
"""
145+
if input_artifacts:
146+
self._input_artifacts.update(input_artifacts)
147+
132148
def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
133149
"""Merges runtime failure policy.
134150
@@ -172,6 +188,9 @@ def build(self) -> Dict[str, Any]:
172188
for k, v in self._parameter_values.items()
173189
if v is not None
174190
},
191+
"inputArtifacts": {
192+
k: {"artifactId": v} for k, v in self._input_artifacts.items()
193+
},
175194
}
176195

177196
if self._failure_policy:

tests/unit/aiplatform/test_pipeline_jobs.py

+8
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@
7474
"struct_param": {"key1": 12345, "key2": 67890},
7575
}
7676

77+
_TEST_PIPELINE_INPUT_ARTIFACTS = {
78+
"vertex_model": "456",
79+
}
80+
7781
_TEST_PIPELINE_SPEC_LEGACY_JSON = json.dumps(
7882
{
7983
"pipelineInfo": {"name": "my-pipeline"},
@@ -469,6 +473,7 @@ def test_run_call_pipeline_service_create(
469473
template_path=_TEST_TEMPLATE_PATH,
470474
job_id=_TEST_PIPELINE_JOB_ID,
471475
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
476+
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
472477
enable_caching=True,
473478
)
474479

@@ -485,6 +490,7 @@ def test_run_call_pipeline_service_create(
485490
expected_runtime_config_dict = {
486491
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
487492
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
493+
"inputArtifacts": {"vertex_model": {"artifactId": "456"}},
488494
}
489495
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
490496
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
@@ -1475,6 +1481,7 @@ def test_clone_pipeline_job_with_all_args(
14751481
job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
14761482
pipeline_root=f"cloned-{_TEST_GCS_BUCKET_NAME}",
14771483
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
1484+
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
14781485
enable_caching=True,
14791486
credentials=_TEST_CREDENTIALS,
14801487
project=_TEST_PROJECT,
@@ -1490,6 +1497,7 @@ def test_clone_pipeline_job_with_all_args(
14901497
expected_runtime_config_dict = {
14911498
"gcsOutputDirectory": f"cloned-{_TEST_GCS_BUCKET_NAME}",
14921499
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
1500+
"inputArtifacts": {"vertex_model": {"artifactId": "456"}},
14931501
}
14941502
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
14951503
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

tests/unit/aiplatform/test_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ class TestPipelineUtils:
452452
"int_param": {"intValue": 42},
453453
"float_param": {"doubleValue": 3.14},
454454
},
455+
"inputArtifacts": {},
455456
},
456457
}
457458

@@ -539,6 +540,7 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(
539540
"list_param": {"stringValue": "[1, 2, 3]"},
540541
"bool_param": {"stringValue": "true"},
541542
},
543+
"inputArtifacts": {},
542544
"failurePolicy": failure_policy[1],
543545
}
544546
assert expected_runtime_config == actual_runtime_config

0 commit comments

Comments
 (0)