Skip to content

Commit 9e1d796

Browse files
authored
fix: project/location parsing for nested resources (#1700)
* testing parsing * adding util function * removing print statements * adding changes * using regex and dict * lint check * adding fs test for passing in location and project * comment fix * adding docstring changes * fixing featurestore unit tests * lint
1 parent ed0492e commit 9e1d796

File tree

4 files changed

+94
-2
lines changed

4 files changed

+94
-2
lines changed

google/cloud/aiplatform/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,13 @@ def _list(
10771077
Returns:
10781078
List[VertexAiResourceNoun] - A list of SDK resource objects
10791079
"""
1080+
if parent:
1081+
parent_resources = utils.extract_project_and_location_from_parent(parent)
1082+
if parent_resources:
1083+
project, location = (
1084+
parent_resources["project"],
1085+
parent_resources["location"],
1086+
)
10801087

10811088
resource = cls._empty_constructor(
10821089
project=project, location=location, credentials=credentials

google/cloud/aiplatform/utils/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,34 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona
325325
return (gcs_bucket, gcs_blob_prefix)
326326

327327

328+
def extract_project_and_location_from_parent(
329+
parent: str,
330+
) -> Dict[str, str]:
331+
"""Given a complete parent resource name, return the project and location as a dict.
332+
333+
Example Usage:
334+
335+
parent_resources = extract_project_and_location_from_parent(
336+
"projects/123/locations/us-central1/datasets/456"
337+
)
338+
339+
parent_resources["project"] = "123"
340+
parent_resources["location"] = "us-central1"
341+
342+
Args:
343+
parent (str):
344+
Required. A complete parent resource name.
345+
346+
Returns:
347+
Dict[str, str]
348+
A project, location dict from provided parent resource name.
349+
"""
350+
parent_resources = re.match(
351+
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)(/|$)", parent
352+
)
353+
return parent_resources.groupdict() if parent_resources else {}
354+
355+
328356
class ClientWithOverride:
329357
class WrappedClient:
330358
"""Wrapper class for client that creates client at API invocation

tests/unit/aiplatform/test_featurestores.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,23 @@ def test_list_entity_types(self, list_entity_types_mock):
10231023
aiplatform.init(project=_TEST_PROJECT)
10241024

10251025
my_featurestore = aiplatform.Featurestore(
1026-
featurestore_name=_TEST_FEATURESTORE_ID
1026+
featurestore_name=_TEST_FEATURESTORE_ID,
1027+
)
1028+
my_entity_type_list = my_featurestore.list_entity_types()
1029+
1030+
list_entity_types_mock.assert_called_once_with(
1031+
request={"parent": _TEST_FEATURESTORE_NAME}
1032+
)
1033+
assert len(my_entity_type_list) == len(_TEST_ENTITY_TYPE_LIST)
1034+
for my_entity_type in my_entity_type_list:
1035+
assert type(my_entity_type) == aiplatform.EntityType
1036+
1037+
@pytest.mark.usefixtures("get_featurestore_mock")
1038+
def test_list_entity_types_with_no_init(self, list_entity_types_mock):
1039+
my_featurestore = aiplatform.Featurestore(
1040+
featurestore_name=_TEST_FEATURESTORE_ID,
1041+
project=_TEST_PROJECT,
1042+
location=_TEST_LOCATION,
10271043
)
10281044
my_entity_type_list = my_featurestore.list_entity_types()
10291045

@@ -1762,7 +1778,7 @@ def test_update_entity_type(self, update_entity_type_mock):
17621778
@pytest.mark.parametrize(
17631779
"featurestore_name", [_TEST_FEATURESTORE_NAME, _TEST_FEATURESTORE_ID]
17641780
)
1765-
def test_list_entity_types(self, featurestore_name, list_entity_types_mock):
1781+
def test_list_entity_type(self, featurestore_name, list_entity_types_mock):
17661782
aiplatform.init(project=_TEST_PROJECT)
17671783

17681784
my_entity_type_list = aiplatform.EntityType.list(
@@ -1790,6 +1806,23 @@ def test_list_features(self, list_features_mock):
17901806
for my_feature in my_feature_list:
17911807
assert type(my_feature) == aiplatform.Feature
17921808

1809+
@pytest.mark.usefixtures("get_entity_type_mock")
1810+
def test_list_features_with_no_init(self, list_features_mock):
1811+
my_entity_type = aiplatform.EntityType(
1812+
entity_type_name=_TEST_ENTITY_TYPE_ID,
1813+
featurestore_id=_TEST_FEATURESTORE_ID,
1814+
project=_TEST_PROJECT,
1815+
location=_TEST_LOCATION,
1816+
)
1817+
my_feature_list = my_entity_type.list_features()
1818+
1819+
list_features_mock.assert_called_once_with(
1820+
request={"parent": _TEST_ENTITY_TYPE_NAME}
1821+
)
1822+
assert len(my_feature_list) == len(_TEST_FEATURE_LIST)
1823+
for my_feature in my_feature_list:
1824+
assert type(my_feature) == aiplatform.Feature
1825+
17931826
@pytest.mark.parametrize("sync", [True, False])
17941827
@pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock")
17951828
def test_delete_features(self, delete_feature_mock, sync):

tests/unit/aiplatform/test_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,30 @@ def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple)
320320
assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path)
321321

322322

323+
@pytest.mark.parametrize(
324+
"parent, expected",
325+
[
326+
(
327+
"projects/123/locations/us-central1/datasets/456",
328+
{"project": "123", "location": "us-central1"},
329+
),
330+
(
331+
"projects/123/locations/us-central1/",
332+
{"project": "123", "location": "us-central1"},
333+
),
334+
(
335+
"projects/123/locations/us-central1",
336+
{"project": "123", "location": "us-central1"},
337+
),
338+
("projects/123/locations/", {}),
339+
("projects/123", {}),
340+
],
341+
)
342+
def test_extract_project_and_location_from_parent(parent: str, expected: tuple):
343+
# Given a parent resource name, ensure correct project and location are extracted
344+
assert expected == utils.extract_project_and_location_from_parent(parent)
345+
346+
323347
@pytest.mark.usefixtures("google_auth_mock")
324348
def test_wrapped_client():
325349
test_client_info = gapic_v1.client_info.ClientInfo()

0 commit comments

Comments
 (0)