
Commit 926d0b6

feat: add support for HTTPS URI pipeline templates (#1683)
First-party pipelines are not yet available in AR, meaning that, other than using a local file, the only way to access a first-party pipeline is through its GitHub URI. Since support was already added for AR URIs, it is not much more effort to support general HTTPS URIs.

- [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)

Fixes b/247878583 🦕
1 parent 2a906c8 commit 926d0b6

5 files changed: +166 −42 lines


google/cloud/aiplatform/constants/pipeline.py

+5 −2
@@ -33,10 +33,13 @@
 _PIPELINE_ERROR_STATES = set([gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED])

 # Pattern for valid names used as a Vertex resource name.
-_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
+_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$", re.IGNORECASE)

 # Pattern for an Artifact Registry URL.
-_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
+_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*", re.IGNORECASE)
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*(json|yaml|yml)$")

 # Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
 _READ_MASK_FIELDS = [
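For readers skimming the diff, the behavior of the new `_VALID_HTTPS_URL` pattern can be checked in isolation. The example URLs below mirror the ones exercised by the unit tests further down; the snippet is illustrative only:

```python
import re

# Copied from constants/pipeline.py above: any HTTPS URL whose path ends
# with json, yaml, or yml.
_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*(json|yaml|yml)$")

print(bool(_VALID_HTTPS_URL.match("https://raw.githubusercontent.com/repo/pipeline.yaml")))  # True
print(bool(_VALID_HTTPS_URL.match("https://example.com/pipeline.exe")))  # False: unsupported extension
print(bool(_VALID_HTTPS_URL.match("http://example.com/pipeline.yaml")))  # False: plain HTTP is rejected
```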

google/cloud/aiplatform/pipeline_jobs.py

+6 −3
@@ -57,6 +57,9 @@
 # Pattern for an Artifact Registry URL.
 _VALID_AR_URL = pipeline_constants._VALID_AR_URL

+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
+
 _READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS


@@ -131,8 +134,8 @@ def __init__(
             template_path (str):
                 Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
                 can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
-                or an Artifact Registry URI (e.g.
-                "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
+                an Artifact Registry URI (e.g.
+                "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
             job_id (str):
                 Optional. The unique ID of the job run.
                 If not specified, pipeline name + timestamp will be used.
@@ -277,7 +280,7 @@ def __init__(
             ),
         }

-        if _VALID_AR_URL.match(template_path):
+        if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
             pipeline_job_args["template_uri"] = template_path

         self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)
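Taken together, the two hunks above mean a template hosted at any HTTPS URL ending in .json, .yaml, or .yml can now be passed straight to `PipelineJob` and forwarded as `template_uri`. A minimal usage sketch; the project, bucket, and template URL are placeholders, not values from this commit:

```python
from google.cloud import aiplatform

# Placeholder project, region, and staging bucket for illustration only.
aiplatform.init(
    project="my-project",
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",
)

job = aiplatform.PipelineJob(
    display_name="hello-world-pipeline",
    # Any HTTPS URI ending in .json, .yaml, or .yml is now accepted, alongside
    # local paths, GCS URIs, and Artifact Registry URIs.
    template_path="https://raw.githubusercontent.com/my-org/my-repo/main/pipeline.yaml",
    parameter_values={"text": "Hello, world!"},
)

job.run()
```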

google/cloud/aiplatform/utils/yaml_utils.py

+34 −30
@@ -13,18 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#

-import re
+from types import ModuleType
 from typing import Any, Dict, Optional
 from urllib import request

 from google.auth import credentials as auth_credentials
 from google.auth import transport
 from google.cloud import storage
+from google.cloud.aiplatform.constants import pipeline as pipeline_constants

 # Pattern for an Artifact Registry URL.
-_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
+_VALID_AR_URL = pipeline_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL


 def load_yaml(
@@ -36,8 +39,8 @@ def load_yaml(

     Args:
         path (str):
-            Required. The path of the YAML document in Google Cloud Storage or
-            local.
+            Required. The path of the YAML document. It can be a local path, a
+            Google Cloud Storage URI, an Artifact Registry URI, or an HTTPS URI.
         project (str):
             Optional. Project to initiate the Storage client with.
         credentials (auth_credentials.Credentials):
@@ -48,12 +51,31 @@ def load_yaml(
     """
     if path.startswith("gs://"):
         return _load_yaml_from_gs_uri(path, project, credentials)
-    elif _VALID_AR_URL.match(path):
-        return _load_yaml_from_ar_uri(path, credentials)
+    elif path.startswith("http://") or path.startswith("https://"):
+        if _VALID_AR_URL.match(path) or _VALID_HTTPS_URL.match(path):
+            return _load_yaml_from_https_uri(path, credentials)
+        else:
+            raise ValueError(
+                "Invalid HTTPS URI. If not using Artifact Registry, please "
+                "ensure the URI ends with .json, .yaml, or .yml."
+            )
     else:
        return _load_yaml_from_local_file(path)


+def _maybe_import_yaml() -> ModuleType:
+    """Tries to import the PyYAML module."""
+    try:
+        import yaml
+    except ImportError:
+        raise ImportError(
+            "PyYAML is not installed and is required to parse PipelineJob or "
+            'PipelineSpec files. Please install the SDK using "pip install '
+            'google-cloud-aiplatform[pipelines]"'
+        )
+    return yaml
+
+
 def _load_yaml_from_gs_uri(
     uri: str,
     project: Optional[str] = None,
@@ -72,13 +94,7 @@ def _load_yaml_from_gs_uri(
     Returns:
         A Dict object representing the YAML document.
     """
-    try:
-        import yaml
-    except ImportError:
-        raise ImportError(
-            "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
-            'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
-        )
+    yaml = _maybe_import_yaml()
     storage_client = storage.Client(project=project, credentials=credentials)
     blob = storage.Blob.from_string(uri, storage_client)
     return yaml.safe_load(blob.download_as_bytes())
@@ -94,39 +110,27 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
     Returns:
         A Dict object representing the YAML document.
     """
-    try:
-        import yaml
-    except ImportError:
-        raise ImportError(
-            "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
-            'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
-        )
+    yaml = _maybe_import_yaml()
     with open(file_path) as f:
         return yaml.safe_load(f)


-def _load_yaml_from_ar_uri(
+def _load_yaml_from_https_uri(
     uri: str,
     credentials: Optional[auth_credentials.Credentials] = None,
 ) -> Dict[str, Any]:
     """Loads data from a YAML document referenced by a Artifact Registry URI.

     Args:
-        path (str):
+        uri (str):
             Required. Artifact Registry URI for YAML document.
         credentials (auth_credentials.Credentials):
             Optional. Credentials to use with Artifact Registry.

     Returns:
         A Dict object representing the YAML document.
     """
-    try:
-        import yaml
-    except ImportError:
-        raise ImportError(
-            "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
-            'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
-        )
+    yaml = _maybe_import_yaml()
     req = request.Request(uri)

     if credentials:
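The hunk above is cut off right after `req = request.Request(uri)`, so the remainder of the renamed `_load_yaml_from_https_uri` is not visible here. Judging from the `urlopen(...).read().decode(...)` mock in test_utils.py, a plausible sketch of such a loader is shown below; treat it as an illustration under those assumptions, not the exact upstream body:

```python
from typing import Any, Dict, Optional
from urllib import request

import yaml  # PyYAML; the real code obtains this via _maybe_import_yaml()
from google.auth import credentials as auth_credentials
from google.auth.transport.requests import Request as GoogleAuthRequest


def load_yaml_from_https_uri_sketch(
    uri: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Illustrative sketch: fetch a JSON/YAML document over HTTPS and parse it."""
    req = request.Request(uri)
    if credentials:
        # Assumed behavior: refresh the token if needed and attach it as a
        # bearer header so that Artifact Registry URIs are authorized.
        if not credentials.valid:
            credentials.refresh(GoogleAuthRequest())
        if credentials.token:
            req.add_header("Authorization", f"Bearer {credentials.token}")
    response = request.urlopen(req)
    return yaml.safe_load(response.read().decode("utf-8"))
```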

tests/unit/aiplatform/test_pipeline_jobs.py

+83
@@ -60,6 +60,7 @@

 _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
 _TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
+_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
 _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
 _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

@@ -627,6 +628,88 @@ def test_run_call_pipeline_service_create_artifact_registry(
             gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
         )

+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create_https(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_pipeline_bucket_exists,
+        mock_request_urlopen,
+        job_spec,
+        mock_load_yaml_and_json,
+        sync,
+    ):
+        import yaml
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_HTTPS_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        job.run(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            sync=sync,
+            create_request_timeout=None,
+        )
+
+        if not sync:
+            job.wait()
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+        }
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        job_spec = yaml.safe_load(job_spec)
+        pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+
+        # Construct expected request
+        expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            pipeline_spec={
+                "components": {},
+                "pipelineInfo": pipeline_spec["pipelineInfo"],
+                "root": pipeline_spec["root"],
+                "schemaVersion": "2.1.0",
+            },
+            runtime_config=runtime_config,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            template_uri=_TEST_HTTPS_TEMPLATE_PATH,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            pipeline_job=expected_gapic_pipeline_job,
+            pipeline_job_id=_TEST_PIPELINE_JOB_ID,
+            timeout=None,
+        )
+
+        mock_pipeline_service_get.assert_called_with(
+            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        assert job._gca_resource == make_pipeline_job(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+
     @pytest.mark.parametrize(
         "job_spec",
         [
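One detail worth noting in this test: the expected `RuntimeConfig` is built from a plain dict via `json_format.ParseDict`, which is why the keys are camelCase (the proto JSON mapping). A tiny standalone illustration of that conversion, with placeholder bucket and parameter values:

```python
from google.cloud.aiplatform_v1.types import PipelineJob
from google.protobuf import json_format

# Build a RuntimeConfig proto from a camelCase dict, mirroring what the test does.
runtime_config = PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(
    {
        "gcsOutputDirectory": "gs://my-staging-bucket",  # placeholder bucket
        "parameterValues": {"text": "Hello, world!"},  # placeholder parameters
    },
    runtime_config,
)
print(runtime_config)
```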

tests/unit/aiplatform/test_utils.py

+38 −7
@@ -24,7 +24,7 @@
 from typing import Callable, Dict, Optional
 from unittest import mock
 from unittest.mock import patch
-from urllib import request
+from urllib import request as urllib_request

 import pytest
 import yaml
@@ -751,15 +751,15 @@ def json_file(tmp_path):


 @pytest.fixture(scope="function")
-def mock_request_urlopen():
+def mock_request_urlopen(request: str) -> str:
     data = {"key": "val", "list": ["1", 2, 3.0]}
-    with mock.patch.object(request, "urlopen") as mock_urlopen:
+    with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
         mock_read_response = mock.MagicMock()
         mock_decode_response = mock.MagicMock()
         mock_decode_response.return_value = json.dumps(data)
         mock_read_response.return_value.decode = mock_decode_response
         mock_urlopen.return_value.read = mock_read_response
-        yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
+        yield request.param


 class TestYamlUtils:
@@ -773,11 +773,42 @@ def test_load_yaml_from_local_file__with_json(self, json_file):
         expected = {"key": "val", "list": ["1", 2, 3.0]}
         assert actual == expected

+    @pytest.mark.parametrize(
+        "mock_request_urlopen",
+        ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
+        indirect=True,
+    )
     def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
         actual = yaml_utils.load_yaml(mock_request_urlopen)
         expected = {"key": "val", "list": ["1", 2, 3.0]}
         assert actual == expected

-    def test_load_yaml_from_invalid_uri(self):
-        with pytest.raises(FileNotFoundError):
-            yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")
+    @pytest.mark.parametrize(
+        "mock_request_urlopen",
+        [
+            "https://raw.githubusercontent.com/repo/pipeline.json",
+            "https://raw.githubusercontent.com/repo/pipeline.yaml",
+            "https://raw.githubusercontent.com/repo/pipeline.yml",
+        ],
+        indirect=True,
+    )
+    def test_load_yaml_from_https_uri(self, mock_request_urlopen):
+        actual = yaml_utils.load_yaml(mock_request_urlopen)
+        expected = {"key": "val", "list": ["1", 2, 3.0]}
+        assert actual == expected
+
+    @pytest.mark.parametrize(
+        "uri",
+        [
+            "https://us-docker.pkg.dev/v2/proj/repo/img/tags/list",
+            "https://example.com/pipeline.exe",
+            "http://example.com/pipeline.yaml",
+        ],
+    )
+    def test_load_yaml_from_invalid_uri(self, uri: str):
+        message = (
+            "Invalid HTTPS URI. If not using Artifact Registry, please "
+            "ensure the URI ends with .json, .yaml, or .yml."
+        )
+        with pytest.raises(ValueError, match=message):
+            yaml_utils.load_yaml(uri)
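The reworked fixture relies on pytest's indirect parametrization: each URL listed in `@pytest.mark.parametrize(..., indirect=True)` reaches the fixture as `request.param`, and the fixture yields it back to the test while the `urlopen` patch is active. A minimal standalone illustration of the pattern, with invented names:

```python
import pytest


@pytest.fixture
def template_url(request):
    # With indirect=True, the parametrized value arrives here as request.param
    # instead of being passed directly to the test function.
    yield request.param


@pytest.mark.parametrize(
    "template_url",
    ["https://example.com/pipeline.yaml", "https://example.com/pipeline.json"],
    indirect=True,
)
def test_template_url_has_supported_extension(template_url):
    assert template_url.rsplit(".", 1)[-1] in ("json", "yaml", "yml")
```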
