@@ -638,6 +638,90 @@ def test_call_schedule_service_create_uses_pipeline_job_project_location(
638
638
assert pipeline_job_schedule .project == "managed-pipeline-test"
639
639
assert pipeline_job_schedule .location == "europe-west4"
640
640
641
+ @pytest .mark .parametrize (
642
+ "job_spec" ,
643
+ [_TEST_PIPELINE_SPEC_JSON , _TEST_PIPELINE_SPEC_YAML , _TEST_PIPELINE_JOB ],
644
+ )
645
+ def test_call_schedule_service_create_uses_pipeline_job_labels (
646
+ self ,
647
+ mock_schedule_service_create ,
648
+ mock_pipeline_service_list ,
649
+ mock_schedule_service_get ,
650
+ mock_schedule_bucket_exists ,
651
+ job_spec ,
652
+ mock_load_yaml_and_json ,
653
+ ):
654
+ """Creates a PipelineJobSchedule.
655
+
656
+ Tests that PipelineJobs created through PipelineJobSchedule inherit the labels of the init PipelineJob.
657
+ """
658
+ TEST_PIPELINE_JOB_LABELS = {"name" : "test_xx" }
659
+
660
+ aiplatform .init (
661
+ project = _TEST_PROJECT ,
662
+ staging_bucket = _TEST_GCS_BUCKET_NAME ,
663
+ location = _TEST_LOCATION ,
664
+ credentials = _TEST_CREDENTIALS ,
665
+ )
666
+
667
+ job = pipeline_jobs .PipelineJob (
668
+ display_name = _TEST_PIPELINE_JOB_DISPLAY_NAME ,
669
+ template_path = _TEST_TEMPLATE_PATH ,
670
+ parameter_values = _TEST_PIPELINE_PARAMETER_VALUES ,
671
+ input_artifacts = _TEST_PIPELINE_INPUT_ARTIFACTS ,
672
+ enable_caching = True ,
673
+ labels = TEST_PIPELINE_JOB_LABELS ,
674
+ )
675
+
676
+ pipeline_job_schedule = pipeline_job_schedules .PipelineJobSchedule (
677
+ pipeline_job = job ,
678
+ display_name = _TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME ,
679
+ )
680
+
681
+ pipeline_job_schedule .create (
682
+ cron = _TEST_PIPELINE_JOB_SCHEDULE_CRON ,
683
+ max_concurrent_run_count = _TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT ,
684
+ max_run_count = _TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT ,
685
+ service_account = _TEST_SERVICE_ACCOUNT ,
686
+ network = _TEST_NETWORK ,
687
+ create_request_timeout = None ,
688
+ )
689
+
690
+ expected_runtime_config_dict = {
691
+ "gcsOutputDirectory" : _TEST_GCS_BUCKET_NAME ,
692
+ "parameterValues" : _TEST_PIPELINE_PARAMETER_VALUES ,
693
+ "inputArtifacts" : {"vertex_model" : {"artifactId" : "456" }},
694
+ }
695
+ runtime_config = gca_pipeline_job .PipelineJob .RuntimeConfig ()._pb
696
+ json_format .ParseDict (expected_runtime_config_dict , runtime_config )
697
+
698
+ job_spec = yaml .safe_load (job_spec )
699
+ pipeline_spec = job_spec .get ("pipelineSpec" ) or job_spec
700
+
701
+ # Construct expected request
702
+ expected_gapic_pipeline_job_schedule = gca_schedule .Schedule (
703
+ display_name = _TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME ,
704
+ cron = _TEST_PIPELINE_JOB_SCHEDULE_CRON ,
705
+ max_concurrent_run_count = _TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT ,
706
+ max_run_count = _TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT ,
707
+ create_pipeline_job_request = {
708
+ "parent" : _TEST_PARENT ,
709
+ "pipeline_job" : {
710
+ "runtime_config" : runtime_config ,
711
+ "pipeline_spec" : dict_to_struct (pipeline_spec ),
712
+ "labels" : TEST_PIPELINE_JOB_LABELS ,
713
+ "service_account" : _TEST_SERVICE_ACCOUNT ,
714
+ "network" : _TEST_NETWORK ,
715
+ },
716
+ },
717
+ )
718
+
719
+ mock_schedule_service_create .assert_called_once_with (
720
+ parent = _TEST_PARENT ,
721
+ schedule = expected_gapic_pipeline_job_schedule ,
722
+ timeout = None ,
723
+ )
724
+
641
725
@pytest .mark .parametrize (
642
726
"job_spec" ,
643
727
[_TEST_PIPELINE_SPEC_JSON , _TEST_PIPELINE_SPEC_YAML , _TEST_PIPELINE_JOB ],
0 commit comments