Skip to content

Commit b43851c

Browse files
authored
fix: PipelineJob should only pass bearer tokens for AR URIs (#1717)
When downloading compiled KFP pipelines over HTTPS, we only need to pass a bearer token when we need to authenticate for services like Artifact Registry. We may get unexpected behavior passing this token in all HTTPS requests, which is the current behavior. Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) Fixes b/251143831 🦕
1 parent dde9ba1 commit b43851c

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

google/cloud/aiplatform/utils/yaml_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ def load_yaml(
5252
if path.startswith("gs://"):
5353
return _load_yaml_from_gs_uri(path, project, credentials)
5454
elif path.startswith("http://") or path.startswith("https://"):
55-
if _VALID_AR_URL.match(path) or _VALID_HTTPS_URL.match(path):
55+
if _VALID_AR_URL.match(path):
5656
return _load_yaml_from_https_uri(path, credentials)
57+
elif _VALID_HTTPS_URL.match(path):
58+
return _load_yaml_from_https_uri(path)
5759
else:
5860
raise ValueError(
5961
"Invalid HTTPS URI. If not using Artifact Registry, please "

tests/unit/aiplatform/test_utils.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@
2121
import json
2222
import os
2323
import textwrap
24-
from typing import Callable, Dict, Optional
24+
from typing import Callable, Dict, Optional, Tuple
2525
from unittest import mock
2626
from unittest.mock import patch
2727
from urllib import request as urllib_request
2828

2929
import pytest
3030
import yaml
3131
from google.api_core import client_options, gapic_v1
32+
from google.auth import credentials
3233
from google.cloud import aiplatform
3334
from google.cloud import storage
3435
from google.cloud.aiplatform import compat, utils
@@ -775,15 +776,15 @@ def json_file(tmp_path):
775776

776777

777778
@pytest.fixture(scope="function")
778-
def mock_request_urlopen(request: str) -> str:
779+
def mock_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]:
779780
data = {"key": "val", "list": ["1", 2, 3.0]}
780781
with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
781782
mock_read_response = mock.MagicMock()
782783
mock_decode_response = mock.MagicMock()
783784
mock_decode_response.return_value = json.dumps(data)
784785
mock_read_response.return_value.decode = mock_decode_response
785786
mock_urlopen.return_value.read = mock_read_response
786-
yield request.param
787+
yield request.param, mock_urlopen
787788

788789

789790
class TestYamlUtils:
@@ -802,10 +803,17 @@ def test_load_yaml_from_local_file__with_json(self, json_file):
802803
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
803804
indirect=True,
804805
)
805-
def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
806-
actual = yaml_utils.load_yaml(mock_request_urlopen)
806+
def test_load_yaml_from_ar_uri_passes_creds(self, mock_request_urlopen):
807+
url, mock_urlopen = mock_request_urlopen
808+
mock_credentials = mock.create_autospec(credentials.Credentials, instance=True)
809+
mock_credentials.valid = True
810+
mock_credentials.token = "some_token"
811+
actual = yaml_utils.load_yaml(url, credentials=mock_credentials)
807812
expected = {"key": "val", "list": ["1", 2, 3.0]}
808813
assert actual == expected
814+
assert mock_urlopen.call_args[0][0].headers == {
815+
"Authorization": "Bearer some_token"
816+
}
809817

810818
@pytest.mark.parametrize(
811819
"mock_request_urlopen",
@@ -816,10 +824,15 @@ def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
816824
],
817825
indirect=True,
818826
)
819-
def test_load_yaml_from_https_uri(self, mock_request_urlopen):
820-
actual = yaml_utils.load_yaml(mock_request_urlopen)
827+
def test_load_yaml_from_https_uri_ignores_creds(self, mock_request_urlopen):
828+
url, mock_urlopen = mock_request_urlopen
829+
mock_credentials = mock.create_autospec(credentials.Credentials, instance=True)
830+
mock_credentials.valid = True
831+
mock_credentials.token = "some_token"
832+
actual = yaml_utils.load_yaml(url, credentials=mock_credentials)
821833
expected = {"key": "val", "list": ["1", 2, 3.0]}
822834
assert actual == expected
835+
assert mock_urlopen.call_args[0][0].headers == {}
823836

824837
@pytest.mark.parametrize(
825838
"uri",

0 commit comments

Comments
 (0)