Skip to content

Commit e138cfd

Browse files
feat: add support for accepting an Artifact Registry URL in pipeline_job (#1405)
* Add support for Artifact Registry in template_path * fix typo * update tests * fix AR path * remove unused project * add code for refreshing credentials * add import for google.auth.transport * fix AR path * fix AR path * fix runtime_config * test removing v1beta1 * try using v1 directly instead * update to use v1beta1 * use select_version * add back template_uri * try adding back v1beta1 * use select_version * differentiate when to use select_version * test removing v1beta1 for pipeline_complete_states * add tests for creating pipelines using v1beta1 * fix merge * fix typo * fix lint using blacken * fix regex * update to use v1 instead of v1beta1 * add test for invalid url * update error type * implement failure_policy * use urllib.request instead of requests * Revert "implement failure_policy" This reverts commit 72cdd9e. * fix lint Co-authored-by: Anthonios Partheniou <[email protected]>
1 parent 82f678e commit e138cfd

File tree

4 files changed

+177
-11
lines changed

4 files changed

+177
-11
lines changed

google/cloud/aiplatform/pipeline_jobs.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656
# Pattern for valid names used as a Vertex resource name.
5757
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
5858

59+
# Pattern for an Artifact Registry URL.
60+
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
61+
5962

6063
def _get_current_time() -> datetime.datetime:
6164
"""Gets the current timestamp."""
@@ -125,8 +128,9 @@ def __init__(
125128
Required. The user-defined name of this Pipeline.
126129
template_path (str):
127130
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
128-
can be a local path or a Google Cloud Storage URI.
129-
Example: "gs://project.name"
131+
can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
132+
or an Artifact Registry URI (e.g.
133+
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
130134
job_id (str):
131135
Optional. The unique ID of the job run.
132136
If not specified, pipeline name + timestamp will be used.
@@ -237,15 +241,20 @@ def __init__(
237241
if enable_caching is not None:
238242
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
239243

240-
self._gca_resource = gca_pipeline_job.PipelineJob(
241-
display_name=display_name,
242-
pipeline_spec=pipeline_job["pipelineSpec"],
243-
labels=labels,
244-
runtime_config=runtime_config,
245-
encryption_spec=initializer.global_config.get_encryption_spec(
244+
pipeline_job_args = {
245+
"display_name": display_name,
246+
"pipeline_spec": pipeline_job["pipelineSpec"],
247+
"labels": labels,
248+
"runtime_config": runtime_config,
249+
"encryption_spec": initializer.global_config.get_encryption_spec(
246250
encryption_spec_key_name=encryption_spec_key_name
247251
),
248-
)
252+
}
253+
254+
if _VALID_AR_URL.match(template_path):
255+
pipeline_job_args["template_uri"] = template_path
256+
257+
self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)
249258

250259
@base.optional_sync()
251260
def run(

google/cloud/aiplatform/utils/yaml_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,17 @@
1515
# limitations under the License.
1616
#
1717

18+
import re
1819
from typing import Any, Dict, Optional
20+
from urllib import request
1921

2022
from google.auth import credentials as auth_credentials
23+
from google.auth import transport
2124
from google.cloud import storage
2225

26+
# Pattern for an Artifact Registry URL.
27+
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
28+
2329

2430
def load_yaml(
2531
path: str,
@@ -42,6 +48,8 @@ def load_yaml(
4248
"""
4349
if path.startswith("gs://"):
4450
return _load_yaml_from_gs_uri(path, project, credentials)
51+
elif _VALID_AR_URL.match(path):
52+
return _load_yaml_from_ar_uri(path, credentials)
4553
else:
4654
return _load_yaml_from_local_file(path)
4755

@@ -95,3 +103,37 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
95103
)
96104
with open(file_path) as f:
97105
return yaml.safe_load(f)
106+
107+
108+
def _load_yaml_from_ar_uri(
109+
uri: str,
110+
credentials: Optional[auth_credentials.Credentials] = None,
111+
) -> Dict[str, Any]:
112+
"""Loads data from a YAML document referenced by an Artifact Registry URI.
113+
114+
Args:
115+
uri (str):
116+
Required. Artifact Registry URI for YAML document.
117+
credentials (auth_credentials.Credentials):
118+
Optional. Credentials to use with Artifact Registry.
119+
120+
Returns:
121+
A Dict object representing the YAML document.
122+
"""
123+
try:
124+
import yaml
125+
except ImportError:
126+
raise ImportError(
127+
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
128+
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
129+
)
130+
req = request.Request(uri)
131+
132+
if credentials:
133+
if not credentials.valid:
134+
credentials.refresh(transport.requests.Request())
135+
if credentials.token:
136+
req.add_header("Authorization", "Bearer " + credentials.token)
137+
response = request.urlopen(req)
138+
139+
return yaml.safe_load(response.read().decode("utf-8"))

tests/unit/aiplatform/test_pipeline_jobs.py

+92
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from unittest import mock
2323
from importlib import reload
2424
from unittest.mock import patch
25+
from urllib import request
2526
from datetime import datetime
2627

2728
from google.auth import credentials as auth_credentials
@@ -50,6 +51,7 @@
5051
_TEST_SERVICE_ACCOUNT = "[email protected]"
5152

5253
_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
54+
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
5355
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
5456
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"
5557

@@ -289,6 +291,17 @@ def mock_load_yaml_and_json(job_spec):
289291
yield mock_load_yaml_and_json
290292

291293

294+
@pytest.fixture
295+
def mock_request_urlopen(job_spec):
296+
with patch.object(request, "urlopen") as mock_urlopen:
297+
mock_read_response = mock.MagicMock()
298+
mock_decode_response = mock.MagicMock()
299+
mock_decode_response.return_value = job_spec.encode()
300+
mock_read_response.return_value.decode = mock_decode_response
301+
mock_urlopen.return_value.read = mock_read_response
302+
yield mock_urlopen
303+
304+
292305
@pytest.mark.usefixtures("google_auth_mock")
293306
class TestPipelineJob:
294307
def setup_method(self):
@@ -376,6 +389,85 @@ def test_run_call_pipeline_service_create(
376389
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
377390
)
378391

392+
@pytest.mark.parametrize(
393+
"job_spec",
394+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
395+
)
396+
@pytest.mark.parametrize("sync", [True, False])
397+
def test_run_call_pipeline_service_create_artifact_registry(
398+
self,
399+
mock_pipeline_service_create,
400+
mock_pipeline_service_get,
401+
mock_request_urlopen,
402+
job_spec,
403+
mock_load_yaml_and_json,
404+
sync,
405+
):
406+
aiplatform.init(
407+
project=_TEST_PROJECT,
408+
staging_bucket=_TEST_GCS_BUCKET_NAME,
409+
location=_TEST_LOCATION,
410+
credentials=_TEST_CREDENTIALS,
411+
)
412+
413+
job = pipeline_jobs.PipelineJob(
414+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
415+
template_path=_TEST_AR_TEMPLATE_PATH,
416+
job_id=_TEST_PIPELINE_JOB_ID,
417+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
418+
enable_caching=True,
419+
)
420+
421+
job.run(
422+
service_account=_TEST_SERVICE_ACCOUNT,
423+
network=_TEST_NETWORK,
424+
sync=sync,
425+
create_request_timeout=None,
426+
)
427+
428+
if not sync:
429+
job.wait()
430+
431+
expected_runtime_config_dict = {
432+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
433+
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
434+
}
435+
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
436+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
437+
438+
job_spec = yaml.safe_load(job_spec)
439+
pipeline_spec = job_spec.get("pipelineSpec") or job_spec
440+
441+
# Construct expected request
442+
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
443+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
444+
pipeline_spec={
445+
"components": {},
446+
"pipelineInfo": pipeline_spec["pipelineInfo"],
447+
"root": pipeline_spec["root"],
448+
"schemaVersion": "2.1.0",
449+
},
450+
runtime_config=runtime_config,
451+
service_account=_TEST_SERVICE_ACCOUNT,
452+
network=_TEST_NETWORK,
453+
template_uri=_TEST_AR_TEMPLATE_PATH,
454+
)
455+
456+
mock_pipeline_service_create.assert_called_once_with(
457+
parent=_TEST_PARENT,
458+
pipeline_job=expected_gapic_pipeline_job,
459+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
460+
timeout=None,
461+
)
462+
463+
mock_pipeline_service_get.assert_called_with(
464+
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
465+
)
466+
467+
assert job._gca_resource == make_pipeline_job(
468+
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
469+
)
470+
379471
@pytest.mark.parametrize(
380472
"job_spec",
381473
[

tests/unit/aiplatform/test_utils.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import json
2121
import os
2222
from typing import Callable, Dict, Optional
23+
from unittest import mock
24+
from urllib import request
2325

2426
import pytest
2527
import yaml
@@ -564,13 +566,34 @@ def json_file(tmp_path):
564566
yield json_file_path
565567

566568

569+
@pytest.fixture(scope="function")
570+
def mock_request_urlopen():
571+
data = {"key": "val", "list": ["1", 2, 3.0]}
572+
with mock.patch.object(request, "urlopen") as mock_urlopen:
573+
mock_read_response = mock.MagicMock()
574+
mock_decode_response = mock.MagicMock()
575+
mock_decode_response.return_value = json.dumps(data)
576+
mock_read_response.return_value.decode = mock_decode_response
577+
mock_urlopen.return_value.read = mock_read_response
578+
yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
579+
580+
567581
class TestYamlUtils:
568-
def test_load_yaml_from_local_file__with_json(self, yaml_file):
582+
def test_load_yaml_from_local_file__with_yaml(self, yaml_file):
569583
actual = yaml_utils.load_yaml(yaml_file)
570584
expected = {"key": "val", "list": ["1", 2, 3.0]}
571585
assert actual == expected
572586

573-
def test_load_yaml_from_local_file__with_yaml(self, json_file):
587+
def test_load_yaml_from_local_file__with_json(self, json_file):
574588
actual = yaml_utils.load_yaml(json_file)
575589
expected = {"key": "val", "list": ["1", 2, 3.0]}
576590
assert actual == expected
591+
592+
def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
593+
actual = yaml_utils.load_yaml(mock_request_urlopen)
594+
expected = {"key": "val", "list": ["1", 2, 3.0]}
595+
assert actual == expected
596+
597+
def test_load_yaml_from_invalid_uri(self):
598+
with pytest.raises(FileNotFoundError):
599+
yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")

0 commit comments

Comments
 (0)