Skip to content

Commit d442248

Browse files
authored
feat: Add Metadata SDK support and samples for get method (#1516)
* feat: Add get() method to Metadata Resource base class. * Add unit tests for artifact, execution, and context * Add samples for get execution and get artifact * fix lint issues * Fix unit tests.
1 parent f93d19c commit d442248

File tree

9 files changed

+259
-7
lines changed

9 files changed

+259
-7
lines changed

google/cloud/aiplatform/metadata/resource.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,50 @@ def get_or_create(
189189
)
190190
return resource
191191

192+
@classmethod
193+
def get(
194+
cls,
195+
resource_id: str,
196+
metadata_store_id: str = "default",
197+
project: Optional[str] = None,
198+
location: Optional[str] = None,
199+
credentials: Optional[auth_credentials.Credentials] = None,
200+
) -> "_Resource":
201+
"""Retrieves a Metadata resource.
202+
203+
Args:
204+
resource_id (str):
205+
Required. The <resource_id> portion of the resource name with the format:
206+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
207+
metadata_store_id (str):
208+
The <metadata_store_id> portion of the resource name with
209+
the format:
210+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
211+
If not provided, the MetadataStore's ID will be set to "default".
212+
project (str):
213+
Project used to retrieve or create this resource. Overrides project set in
214+
aiplatform.init.
215+
location (str):
216+
Location used to retrieve or create this resource. Overrides location set in
217+
aiplatform.init.
218+
credentials (auth_credentials.Credentials):
219+
Custom credentials used to retrieve or create this resource. Overrides
220+
credentials set in aiplatform.init.
221+
222+
Returns:
223+
resource (_Resource):
224+
Instantiated representation of the managed Metadata resource or None if no resouce was found.
225+
226+
"""
227+
resource = cls._get(
228+
resource_name=resource_id,
229+
metadata_store_id=metadata_store_id,
230+
project=project,
231+
location=location,
232+
credentials=credentials,
233+
)
234+
return resource
235+
192236
def sync_resource(self):
193237
"""Syncs local resource with the resource in metadata store."""
194238
self._gca_resource = getattr(self.api_client, self._getter_method)(

samples/model-builder/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,13 @@ def mock_get_execution(mock_execution):
576576
yield mock_get_execution
577577

578578

579+
@pytest.fixture
580+
def mock_execution_get(mock_execution):
581+
with patch.object(aiplatform.Execution, "get") as mock_execution_get:
582+
mock_execution_get.return_value = mock_execution
583+
yield mock_execution_get
584+
585+
579586
@pytest.fixture
580587
def mock_create_execution(mock_execution):
581588
with patch.object(aiplatform.Execution, "create") as mock_create_execution:
@@ -590,6 +597,13 @@ def mock_get_artifact(mock_artifact):
590597
yield mock_get_artifact
591598

592599

600+
@pytest.fixture
601+
def mock_artifact_get(mock_artifact):
602+
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
603+
mock_artifact_get.return_value = mock_artifact
604+
yield mock_artifact_get
605+
606+
593607
@pytest.fixture
594608
def mock_pipeline_job_create(mock_pipeline_job):
595609
with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create:

samples/model-builder/experiment_tracking/get_artifact_sample.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
# [START aiplatform_sdk_get_artifact_sample]
1919
def get_artifact_sample(
20-
uri: str,
20+
artifact_id: str,
2121
project: str,
2222
location: str,
2323
):
24-
artifact = aiplatform.Artifact.get_with_uri(
25-
uri=uri, project=project, location=location
24+
artifact = aiplatform.Artifact.get(
25+
resource_id=artifact_id, project=project, location=location
2626
)
2727

2828
return artifact

samples/model-builder/experiment_tracking/get_artifact_sample_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
import test_constants
1818

1919

20-
def test_get_artifact_sample(mock_artifact, mock_get_with_uri):
20+
def test_get_artifact_sample(mock_artifact, mock_artifact_get):
2121
artifact = get_artifact_sample.get_artifact_sample(
22-
uri=test_constants.MODEL_ARTIFACT_URI,
22+
artifact_id=test_constants.RESOURCE_ID,
2323
project=test_constants.PROJECT,
2424
location=test_constants.LOCATION,
2525
)
2626

27-
mock_get_with_uri.assert_called_with(
28-
uri=test_constants.MODEL_ARTIFACT_URI,
27+
mock_artifact_get.assert_called_with(
28+
resource_id=test_constants.RESOURCE_ID,
2929
project=test_constants.PROJECT,
3030
location=test_constants.LOCATION,
3131
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_get_artifact_with_uri_sample]
19+
def get_artifact_with_uri_sample(
20+
uri: str,
21+
project: str,
22+
location: str,
23+
):
24+
artifact = aiplatform.Artifact.get_with_uri(
25+
uri=uri, project=project, location=location
26+
)
27+
28+
return artifact
29+
30+
31+
# [END aiplatform_sdk_get_artifact_with_uri_sample]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import get_artifact_with_uri_sample
16+
17+
import test_constants
18+
19+
20+
def test_get_artifact_with_uri_sample(mock_artifact, mock_get_with_uri):
21+
artifact = get_artifact_with_uri_sample.get_artifact_with_uri_sample(
22+
uri=test_constants.MODEL_ARTIFACT_URI,
23+
project=test_constants.PROJECT,
24+
location=test_constants.LOCATION,
25+
)
26+
27+
mock_get_with_uri.assert_called_with(
28+
uri=test_constants.MODEL_ARTIFACT_URI,
29+
project=test_constants.PROJECT,
30+
location=test_constants.LOCATION,
31+
)
32+
33+
assert artifact is mock_artifact
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_get_execution_sample]
19+
def get_execution_sample(
20+
execution_id: str,
21+
project: str,
22+
location: str,
23+
):
24+
execution = aiplatform.Execution.get(
25+
resource_id=execution_id, project=project, location=location
26+
)
27+
28+
return execution
29+
30+
31+
# [END aiplatform_sdk_get_execution_sample]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import get_artifact_sample
16+
17+
import test_constants
18+
19+
20+
def test_get_artifact_sample(mock_artifact, mock_artifact_get):
21+
artifact = get_artifact_sample.get_artifact_sample(
22+
artifact_id=test_constants.RESOURCE_ID,
23+
project=test_constants.PROJECT,
24+
location=test_constants.LOCATION,
25+
)
26+
27+
mock_artifact_get.assert_called_with(
28+
resource_id=test_constants.RESOURCE_ID,
29+
project=test_constants.PROJECT,
30+
location=test_constants.LOCATION,
31+
)
32+
33+
assert artifact is mock_artifact

tests/unit/aiplatform/test_metadata_resources.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,28 @@ def test_get_or_create_context(
419419
expected_context.name = _TEST_CONTEXT_NAME
420420
assert my_context._gca_resource == expected_context
421421

422+
def test_get_context(self, get_context_mock):
423+
aiplatform.init(project=_TEST_PROJECT)
424+
425+
my_context = context.Context.get(
426+
resource_id=_TEST_CONTEXT_ID,
427+
metadata_store_id=_TEST_METADATA_STORE,
428+
)
429+
430+
expected_context = GapicContext(
431+
schema_title=_TEST_SCHEMA_TITLE,
432+
schema_version=_TEST_SCHEMA_VERSION,
433+
display_name=_TEST_DISPLAY_NAME,
434+
description=_TEST_DESCRIPTION,
435+
metadata=_TEST_METADATA,
436+
)
437+
get_context_mock.assert_called_once_with(
438+
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
439+
)
440+
441+
expected_context.name = _TEST_CONTEXT_NAME
442+
assert my_context._gca_resource == expected_context
443+
422444
@pytest.mark.usefixtures("get_context_mock")
423445
@pytest.mark.usefixtures("create_context_mock")
424446
def test_update_context(self, update_context_mock):
@@ -633,6 +655,28 @@ def test_get_or_create_execution(
633655
expected_execution.name = _TEST_EXECUTION_NAME
634656
assert my_execution._gca_resource == expected_execution
635657

658+
def test_get_execution(self, get_execution_mock):
659+
aiplatform.init(project=_TEST_PROJECT)
660+
661+
my_execution = execution.Execution.get(
662+
resource_id=_TEST_EXECUTION_ID,
663+
metadata_store_id=_TEST_METADATA_STORE,
664+
)
665+
666+
expected_execution = GapicExecution(
667+
schema_title=_TEST_SCHEMA_TITLE,
668+
schema_version=_TEST_SCHEMA_VERSION,
669+
display_name=_TEST_DISPLAY_NAME,
670+
description=_TEST_DESCRIPTION,
671+
metadata=_TEST_METADATA,
672+
)
673+
get_execution_mock.assert_called_once_with(
674+
name=_TEST_EXECUTION_NAME, retry=base._DEFAULT_RETRY
675+
)
676+
677+
expected_execution.name = _TEST_EXECUTION_NAME
678+
assert my_execution._gca_resource == expected_execution
679+
636680
@pytest.mark.usefixtures("get_execution_mock")
637681
@pytest.mark.usefixtures("create_execution_mock")
638682
def test_update_execution(self, update_execution_mock):
@@ -883,6 +927,28 @@ def test_get_or_create_artifact(
883927
expected_artifact.name = _TEST_ARTIFACT_NAME
884928
assert my_artifact._gca_resource == expected_artifact
885929

930+
def test_get_artifact(self, get_artifact_mock):
931+
aiplatform.init(project=_TEST_PROJECT)
932+
933+
my_artifact = artifact.Artifact.get(
934+
resource_id=_TEST_ARTIFACT_ID,
935+
metadata_store_id=_TEST_METADATA_STORE,
936+
)
937+
938+
expected_artifact = GapicArtifact(
939+
schema_title=_TEST_SCHEMA_TITLE,
940+
schema_version=_TEST_SCHEMA_VERSION,
941+
display_name=_TEST_DISPLAY_NAME,
942+
description=_TEST_DESCRIPTION,
943+
metadata=_TEST_METADATA,
944+
)
945+
get_artifact_mock.assert_called_once_with(
946+
name=_TEST_ARTIFACT_NAME, retry=base._DEFAULT_RETRY
947+
)
948+
949+
expected_artifact.name = _TEST_ARTIFACT_NAME
950+
assert my_artifact._gca_resource == expected_artifact
951+
886952
@pytest.mark.usefixtures("get_artifact_mock")
887953
@pytest.mark.usefixtures("create_artifact_mock")
888954
def test_update_artifact(self, update_artifact_mock):

0 commit comments

Comments
 (0)