
Commit 635ae9c

vertex-sdk-bot authored and copybara-github committed
feat: Add pipelineJob create_schedule() method and unit test.
PiperOrigin-RevId: 537427068
1 parent 8ba9e78 commit 635ae9c

File tree

2 files changed: +217 -0 lines changed

google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py

+142
@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

from google.cloud.aiplatform import base
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
from google.cloud.aiplatform.metadata import constants as metadata_constants
from google.cloud.aiplatform.metadata import experiment_resources

_LOGGER = base.Logger(__name__)

_PIPELINE_COMPLETE_STATES = pipeline_constants._PIPELINE_COMPLETE_STATES

_PIPELINE_ERROR_STATES = pipeline_constants._PIPELINE_ERROR_STATES

# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL

_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS


class _PipelineJob(
    pipeline_jobs.PipelineJob,
    experiment_loggable_schemas=(
        experiment_resources._ExperimentLoggableSchema(
            title=metadata_constants.SYSTEM_PIPELINE_RUN
        ),
    ),
):
    """Preview PipelineJob resource for Vertex AI."""

    def create_schedule(
        self,
        cron_expression: str,
        display_name: str,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        allow_queueing: bool = False,
        max_run_count: Optional[int] = None,
        max_concurrent_run_count: int = 1,
        service_account: Optional[str] = None,
        network: Optional[str] = None,
        create_request_timeout: Optional[float] = None,
    ) -> "pipeline_job_schedules.PipelineJobSchedule":  # noqa: F821
        """Creates a PipelineJobSchedule directly from a PipelineJob.

        Example Usage:

            pipeline_job = aiplatform.PipelineJob(
                display_name='job_display_name',
                template_path='your_pipeline.yaml',
            )
            pipeline_job.run()
            pipeline_job_schedule = pipeline_job.create_schedule(
                cron_expression='* * * * *',
                display_name='schedule_display_name',
            )

        Args:
            cron_expression (str):
                Required. Time specification (cron schedule expression) to launch scheduled runs.
                To explicitly set a timezone for the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
                The ${IANA_TIME_ZONE} may only be a valid string from the IANA time zone database.
                For example, "CRON_TZ=America/New_York 1 * * * *" or "TZ=America/New_York 1 * * * *".
            display_name (str):
                Required. The user-defined name of this PipelineJobSchedule.
            start_time (str):
                Optional. Timestamp after which the first run can be scheduled.
                If unspecified, it defaults to the schedule creation timestamp.
            end_time (str):
                Optional. Timestamp after which no more runs will be scheduled.
                If unspecified, then runs will be scheduled indefinitely.
            allow_queueing (bool):
                Optional. Whether new scheduled runs can be queued when the max_concurrent_runs limit is reached.
            max_run_count (int):
                Optional. Maximum run count of the schedule.
                If specified, the schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
            max_concurrent_run_count (int):
                Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
            service_account (str):
                Optional. Specifies the service account for the workload run-as account.
                Users submitting jobs must have act-as permission on this run-as account.
            network (str):
                Optional. The full name of the Compute Engine network to which the job
                should be peered. For example, projects/12345/global/networks/myVPC.
                Private services access must already be configured for the network.
                If left unspecified, the network set in aiplatform.init will be used.
                Otherwise, the job is not peered with any network.
            create_request_timeout (float):
                Optional. The timeout for the create request in seconds.

        Returns:
            A Vertex AI PipelineJobSchedule.
        """
        from google.cloud.aiplatform.preview.pipelinejobschedule import (
            pipeline_job_schedules,
        )

        if not display_name:
            display_name = self._generate_display_name(prefix="PipelineJobSchedule")
        utils.validate_display_name(display_name)

        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
            pipeline_job=self,
            display_name=display_name,
        )

        pipeline_job_schedule.create(
            cron_expression=cron_expression,
            start_time=start_time,
            end_time=end_time,
            allow_queueing=allow_queueing,
            max_run_count=max_run_count,
            max_concurrent_run_count=max_concurrent_run_count,
            service_account=service_account,
            network=network,
            create_request_timeout=create_request_timeout,
        )
        return pipeline_job_schedule

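For orientation, here is a minimal sketch of calling the new method, mirroring the docstring's Example Usage above. The project, location, bucket, template path, and display names are hypothetical placeholders, not values from this commit:

from google.cloud import aiplatform

# Hypothetical placeholders; substitute a real project, location, and template.
aiplatform.init(project="my-project", location="us-central1")

pipeline_job = aiplatform.PipelineJob(
    display_name="job_display_name",
    template_path="your_pipeline.yaml",
)

# Schedule daily 09:00 runs pinned to New York time using the
# "TZ=${IANA_TIME_ZONE}" prefix form described in the cron_expression docs.
pipeline_job_schedule = pipeline_job.create_schedule(
    cron_expression="TZ=America/New_York 0 9 * * *",
    display_name="schedule_display_name",
    max_run_count=30,
    max_concurrent_run_count=1,
)

Per the docstring, the schedule stops once started_run_count reaches max_run_count or end_time passes, whichever comes first.
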
tests/unit/aiplatform/test_pipeline_job_schedules.py

+75
@@ -38,6 +38,9 @@
 from google.cloud.aiplatform.preview.constants import (
     schedules as schedule_constants,
 )
+from google.cloud.aiplatform.preview.pipelinejob import (
+    pipeline_jobs as preview_pipeline_jobs,
+)
 from google.cloud.aiplatform import pipeline_jobs
 from google.cloud.aiplatform.preview.pipelinejobschedule import (
     pipeline_job_schedules,
@@ -821,6 +824,78 @@ def test_call_schedule_service_create_with_timeout_not_explicitly_set(
             timeout=None,
         )
 
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_call_pipeline_job_create_schedule(
+        self,
+        mock_schedule_service_create,
+        mock_schedule_service_get,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        """Creates a PipelineJobSchedule via PipelineJob.create_schedule()."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = preview_pipeline_jobs._PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+            enable_caching=True,
+        )
+
+        pipeline_job_schedule = job.create_schedule(
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+        )
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+            "inputArtifacts": {"vertex_model": {"artifactId": "456"}},
+        }
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        job_spec = yaml.safe_load(job_spec)
+        pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+        expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+            cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            create_pipeline_job_request={
+                "parent": _TEST_PARENT,
+                "pipeline_job": {
+                    "runtime_config": runtime_config,
+                    "pipeline_spec": {"fields": pipeline_spec},
+                    "service_account": _TEST_SERVICE_ACCOUNT,
+                    "network": _TEST_NETWORK,
+                },
+            },
+        )
+
+        mock_schedule_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            schedule=expected_gapic_pipeline_job_schedule,
+            timeout=None,
+        )
+
+        assert pipeline_job_schedule._gca_resource == make_schedule(
+            gca_schedule.Schedule.State.COMPLETED
+        )
+
     @pytest.mark.usefixtures("mock_schedule_service_get")
     def test_get_schedule(self, mock_schedule_service_get):
         aiplatform.init(project=_TEST_PROJECT)

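One detail worth highlighting: the test normalizes job_spec with job_spec.get("pipelineSpec") or job_spec, so the same assertions work whether the parametrized fixture is a bare pipeline spec or a full PipelineJob payload that wraps one. A self-contained sketch of that normalization, using hypothetical inline specs instead of the module's test fixtures:

import yaml

# Hypothetical specs for illustration: a bare pipeline spec, and a
# PipelineJob payload that nests the same spec under "pipelineSpec".
bare_spec = """
pipelineInfo:
  name: my-pipeline
"""
wrapped_spec = """
pipelineSpec:
  pipelineInfo:
    name: my-pipeline
runtimeConfig: {}
"""

for raw in (bare_spec, wrapped_spec):
    job_spec = yaml.safe_load(raw)
    # Same normalization the test uses: unwrap "pipelineSpec" when present,
    # otherwise treat the document itself as the pipeline spec.
    pipeline_spec = job_spec.get("pipelineSpec") or job_spec
    assert pipeline_spec["pipelineInfo"]["name"] == "my-pipeline"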