Skip to content

Commit a34533f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Make PipelineJobSchedule propagate labels to created PipelineJobs
PiperOrigin-RevId: 585812646
1 parent 3f56ae7 commit a34533f

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

google/cloud/aiplatform/pipeline_job_schedules.py

+4
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def __init__(
102102
create_pipeline_job_request["pipeline_job"][
103103
"template_uri"
104104
] = pipeline_job._gca_resource.template_uri
105+
if "labels" in pipeline_job._gca_resource:
106+
create_pipeline_job_request["pipeline_job"][
107+
"labels"
108+
] = pipeline_job._gca_resource.labels
105109
pipeline_job_schedule_args = {
106110
"display_name": display_name,
107111
"create_pipeline_job_request": create_pipeline_job_request,

tests/unit/aiplatform/test_pipeline_job_schedules.py

+84
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,90 @@ def test_call_schedule_service_create_uses_pipeline_job_project_location(
638638
assert pipeline_job_schedule.project == "managed-pipeline-test"
639639
assert pipeline_job_schedule.location == "europe-west4"
640640

641+
@pytest.mark.parametrize(
642+
"job_spec",
643+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
644+
)
645+
def test_call_schedule_service_create_uses_pipeline_job_labels(
646+
self,
647+
mock_schedule_service_create,
648+
mock_pipeline_service_list,
649+
mock_schedule_service_get,
650+
mock_schedule_bucket_exists,
651+
job_spec,
652+
mock_load_yaml_and_json,
653+
):
654+
"""Creates a PipelineJobSchedule.
655+
656+
Tests that PipelineJobs created through PipelineJobSchedule inherit the labels of the init PipelineJob.
657+
"""
658+
TEST_PIPELINE_JOB_LABELS = {"name": "test_xx"}
659+
660+
aiplatform.init(
661+
project=_TEST_PROJECT,
662+
staging_bucket=_TEST_GCS_BUCKET_NAME,
663+
location=_TEST_LOCATION,
664+
credentials=_TEST_CREDENTIALS,
665+
)
666+
667+
job = pipeline_jobs.PipelineJob(
668+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
669+
template_path=_TEST_TEMPLATE_PATH,
670+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
671+
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
672+
enable_caching=True,
673+
labels=TEST_PIPELINE_JOB_LABELS,
674+
)
675+
676+
pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
677+
pipeline_job=job,
678+
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
679+
)
680+
681+
pipeline_job_schedule.create(
682+
cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
683+
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
684+
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
685+
service_account=_TEST_SERVICE_ACCOUNT,
686+
network=_TEST_NETWORK,
687+
create_request_timeout=None,
688+
)
689+
690+
expected_runtime_config_dict = {
691+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
692+
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
693+
"inputArtifacts": {"vertex_model": {"artifactId": "456"}},
694+
}
695+
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
696+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
697+
698+
job_spec = yaml.safe_load(job_spec)
699+
pipeline_spec = job_spec.get("pipelineSpec") or job_spec
700+
701+
# Construct expected request
702+
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
703+
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
704+
cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
705+
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
706+
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
707+
create_pipeline_job_request={
708+
"parent": _TEST_PARENT,
709+
"pipeline_job": {
710+
"runtime_config": runtime_config,
711+
"pipeline_spec": dict_to_struct(pipeline_spec),
712+
"labels": TEST_PIPELINE_JOB_LABELS,
713+
"service_account": _TEST_SERVICE_ACCOUNT,
714+
"network": _TEST_NETWORK,
715+
},
716+
},
717+
)
718+
719+
mock_schedule_service_create.assert_called_once_with(
720+
parent=_TEST_PARENT,
721+
schedule=expected_gapic_pipeline_job_schedule,
722+
timeout=None,
723+
)
724+
641725
@pytest.mark.parametrize(
642726
"job_spec",
643727
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],

0 commit comments

Comments
 (0)