Skip to content

Commit 69c5f60

Browse files
vertex-sdk-bot authored and
copybara-github committed
feat: Add PipelineJobSchedule update method and unit tests.
PiperOrigin-RevId: 539259661
1 parent 50646be commit 69c5f60

File tree

2 files changed

+216
-1
lines changed

2 files changed

+216
-1
lines changed

google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py

+85-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
)
4242
from google.protobuf import field_mask_pb2 as field_mask
4343

44-
4544
_LOGGER = base.Logger(__name__)
4645

4746
# Pattern for valid names used as a Vertex resource name.
@@ -53,6 +52,8 @@
5352
# Pattern for any JSON or YAML file over HTTPS.
5453
_VALID_HTTPS_URL = schedule_constants._VALID_HTTPS_URL
5554

55+
_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
56+
5657
_READ_MASK_FIELDS = schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS
5758

5859

@@ -385,3 +386,86 @@ def list_jobs(
385386
location=location,
386387
credentials=credentials,
387388
)
389+
390+
def update(
    self,
    display_name: Optional[str] = None,
    cron_expression: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    allow_queueing: Optional[bool] = None,
    max_run_count: Optional[int] = None,
    max_concurrent_run_count: Optional[int] = None,
) -> None:
    """Update an existing PipelineJobSchedule.

    Example usage:

    pipeline_job_schedule.update(
        display_name='updated-display-name',
        cron_expression='1 2 3 4 5',
    )

    Args:
        display_name (str):
            Optional. The user-defined name of this PipelineJobSchedule.
        cron_expression (str):
            Optional. Time specification (cron schedule expression) to launch scheduled runs.
            To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
            The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
            For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
        start_time (str):
            Optional. Timestamp after which the first run can be scheduled.
            If unspecified, it defaults to the schedule creation timestamp.
        end_time (str):
            Optional. Timestamp after which no more runs will be scheduled.
            If unspecified, then runs will be scheduled indefinitely.
        allow_queueing (bool):
            Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
        max_run_count (int):
            Optional. Maximum run count of the schedule.
            If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
        max_concurrent_run_count (int):
            Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.

    Raises:
        RuntimeError: User tried to call update() before create().
    """
    pipeline_job_schedule = self._gca_resource
    if pipeline_job_schedule.state in _SCHEDULE_ERROR_STATES:
        raise RuntimeError(
            "Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
        )

    # Map each Schedule proto field name to its (possibly None) replacement
    # value. Note the public parameter `cron_expression` maps to the proto
    # field "cron". Dict insertion order fixes the update-mask path order.
    field_updates = {
        "display_name": display_name,
        "cron": cron_expression,
        "start_time": start_time,
        "end_time": end_time,
        "allow_queueing": allow_queueing,
        "max_run_count": max_run_count,
        "max_concurrent_run_count": max_concurrent_run_count,
    }

    # Apply only the fields the caller explicitly supplied; None means
    # "leave unchanged", so False/0 are still valid new values.
    updated_fields = []
    for field_name, new_value in field_updates.items():
        if new_value is not None:
            updated_fields.append(field_name)
            setattr(pipeline_job_schedule, field_name, new_value)

    # Restrict the server-side update to the touched fields.
    update_mask = field_mask.FieldMask(paths=updated_fields)
    self.api_client.update_schedule(
        schedule=pipeline_job_schedule,
        update_mask=update_mask,
    )

tests/unit/aiplatform/test_pipeline_job_schedules.py

+131
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@
6969
_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT = 1
7070
_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 2
7171

72+
_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION = "1 1 1 1 1"
73+
_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 5
74+
7275
_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
7376
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
7477
_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
@@ -371,6 +374,23 @@ def mock_pipeline_service_list():
371374
yield mock_list_pipeline_jobs
372375

373376

377+
@pytest.fixture
def mock_schedule_service_update():
    """Patches ScheduleServiceClient.update_schedule to return a canned updated Schedule."""
    # NOTE(review): `name` is populated from the display-name constant here —
    # verify that is intentional rather than _TEST_PIPELINE_JOB_SCHEDULE_NAME.
    updated_schedule = gca_schedule.Schedule(
        name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
        state=gca_schedule.Schedule.State.COMPLETED,
        create_time=_TEST_PIPELINE_CREATE_TIME,
        cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
        max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
        max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
        create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
    )
    with mock.patch.object(
        schedule_service_client.ScheduleServiceClient, "update_schedule"
    ) as update_schedule_mock:
        update_schedule_mock.return_value = updated_schedule
        yield update_schedule_mock
392+
393+
374394
@pytest.fixture
375395
def mock_load_yaml_and_json(job_spec):
376396
with patch.object(storage.Blob, "download_as_bytes") as mock_load_yaml_and_json:
@@ -1304,3 +1324,114 @@ def test_resume_pipeline_job_schedule_without_created(
13041324
pipeline_job_schedule.resume()
13051325

13061326
assert e.match(regexp=r"Schedule resource has not been created")
1327+
1328+
@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_call_schedule_service_update(
    self,
    mock_schedule_service_create,
    mock_schedule_service_update,
    mock_schedule_service_get,
    mock_schedule_bucket_exists,
    job_spec,
    mock_load_yaml_and_json,
):
    """Verifies update() on a created PipelineJobSchedule.

    Changes cron_expression and max_run_count, then checks the local
    resource reflects the updated Schedule.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    # The Schedule expected back after the update — only cron and
    # max_run_count differ from the created schedule.
    expected_schedule = gca_schedule.Schedule(
        name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
        state=gca_schedule.Schedule.State.COMPLETED,
        create_time=_TEST_PIPELINE_CREATE_TIME,
        cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
        max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
        max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
        create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
        enable_caching=True,
    )

    pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
        pipeline_job=pipeline_job,
        display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
    )

    pipeline_job_schedule.create(
        cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
        max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
        max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        create_request_timeout=None,
    )

    pipeline_job_schedule.update(
        cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
        max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
    )

    assert pipeline_job_schedule._gca_resource == expected_schedule
1391+
1392+
@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_call_schedule_service_update_before_create(
    self,
    mock_schedule_service_create,
    mock_schedule_service_update,
    mock_schedule_service_get,
    mock_schedule_bucket_exists,
    job_spec,
    mock_load_yaml_and_json,
):
    """Verifies update() fails when the PipelineJobSchedule was never created.

    A RuntimeError is expected because create() must precede update().
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
        enable_caching=True,
    )

    pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
        pipeline_job=pipeline_job,
        display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
    )

    # update() before create() must be rejected with the schedule-state error.
    with pytest.raises(
        RuntimeError,
        match=r"Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed.",
    ):
        pipeline_job_schedule.update(
            cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
        )

0 commit comments

Comments
 (0)