diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py index 89c145dcbe..fccd0b18d6 100644 --- a/google/cloud/aiplatform/metadata/resource.py +++ b/google/cloud/aiplatform/metadata/resource.py @@ -189,6 +189,50 @@ def get_or_create( ) return resource + @classmethod + def get( + cls, + resource_id: str, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "_Resource": + """Retrieves a Metadata resource. + + Args: + resource_id (str): + Required. The portion of the resource name with the format: + projects/123/locations/us-central1/metadataStores///. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to retrieve or create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to retrieve or create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to retrieve or create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (_Resource): + Instantiated representation of the managed Metadata resource or None if no resouce was found. + + """ + resource = cls._get( + resource_name=resource_id, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + return resource + def sync_resource(self): """Syncs local resource with the resource in metadata store.""" self._gca_resource = getattr(self.api_client, self._getter_method)( diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 3e6e2a1c84..f23d47cacb 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -576,6 +576,13 @@ def mock_get_execution(mock_execution): yield mock_get_execution +@pytest.fixture +def mock_execution_get(mock_execution): + with patch.object(aiplatform.Execution, "get") as mock_execution_get: + mock_execution_get.return_value = mock_execution + yield mock_execution_get + + @pytest.fixture def mock_create_execution(mock_execution): with patch.object(aiplatform.Execution, "create") as mock_create_execution: @@ -590,6 +597,13 @@ def mock_get_artifact(mock_artifact): yield mock_get_artifact +@pytest.fixture +def mock_artifact_get(mock_artifact): + with patch.object(aiplatform.Artifact, "get") as mock_artifact_get: + mock_artifact_get.return_value = mock_artifact + yield mock_artifact_get + + @pytest.fixture def mock_pipeline_job_create(mock_pipeline_job): with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create: diff --git a/samples/model-builder/experiment_tracking/get_artifact_sample.py b/samples/model-builder/experiment_tracking/get_artifact_sample.py index e0b9fc5500..93ea031b2c 100644 --- a/samples/model-builder/experiment_tracking/get_artifact_sample.py +++ b/samples/model-builder/experiment_tracking/get_artifact_sample.py @@ -17,12 +17,12 @@ # [START aiplatform_sdk_get_artifact_sample] def get_artifact_sample( - uri: str, + artifact_id: str, project: str, location: str, ): - artifact = aiplatform.Artifact.get_with_uri( - uri=uri, project=project, location=location + artifact = aiplatform.Artifact.get( + resource_id=artifact_id, project=project, location=location ) return artifact diff --git a/samples/model-builder/experiment_tracking/get_artifact_sample_test.py b/samples/model-builder/experiment_tracking/get_artifact_sample_test.py index 7387927ef3..21047e4e7d 100644 --- a/samples/model-builder/experiment_tracking/get_artifact_sample_test.py +++ b/samples/model-builder/experiment_tracking/get_artifact_sample_test.py @@ -17,15 +17,15 @@ import test_constants -def test_get_artifact_sample(mock_artifact, mock_get_with_uri): +def test_get_artifact_sample(mock_artifact, mock_artifact_get): artifact = get_artifact_sample.get_artifact_sample( - uri=test_constants.MODEL_ARTIFACT_URI, + artifact_id=test_constants.RESOURCE_ID, project=test_constants.PROJECT, location=test_constants.LOCATION, ) - mock_get_with_uri.assert_called_with( - uri=test_constants.MODEL_ARTIFACT_URI, + mock_artifact_get.assert_called_with( + resource_id=test_constants.RESOURCE_ID, project=test_constants.PROJECT, location=test_constants.LOCATION, ) diff --git a/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py b/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py new file mode 100644 index 0000000000..6c3f11c9ca --- /dev/null +++ b/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py @@ -0,0 +1,31 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_get_artifact_with_uri_sample] +def get_artifact_with_uri_sample( + uri: str, + project: str, + location: str, +): + artifact = aiplatform.Artifact.get_with_uri( + uri=uri, project=project, location=location + ) + + return artifact + + +# [END aiplatform_sdk_get_artifact_with_uri_sample] diff --git a/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample_test.py b/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample_test.py new file mode 100644 index 0000000000..23be3f9359 --- /dev/null +++ b/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import get_artifact_with_uri_sample + +import test_constants + + +def test_get_artifact_with_uri_sample(mock_artifact, mock_get_with_uri): + artifact = get_artifact_with_uri_sample.get_artifact_with_uri_sample( + uri=test_constants.MODEL_ARTIFACT_URI, + project=test_constants.PROJECT, + location=test_constants.LOCATION, + ) + + mock_get_with_uri.assert_called_with( + uri=test_constants.MODEL_ARTIFACT_URI, + project=test_constants.PROJECT, + location=test_constants.LOCATION, + ) + + assert artifact is mock_artifact diff --git a/samples/model-builder/experiment_tracking/get_execution_sample.py b/samples/model-builder/experiment_tracking/get_execution_sample.py new file mode 100644 index 0000000000..c724c9973a --- /dev/null +++ b/samples/model-builder/experiment_tracking/get_execution_sample.py @@ -0,0 +1,31 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_get_execution_sample] +def get_execution_sample( + execution_id: str, + project: str, + location: str, +): + execution = aiplatform.Execution.get( + resource_id=execution_id, project=project, location=location + ) + + return execution + + +# [END aiplatform_sdk_get_execution_sample] diff --git a/samples/model-builder/experiment_tracking/get_execution_sample_test.py b/samples/model-builder/experiment_tracking/get_execution_sample_test.py new file mode 100644 index 0000000000..21047e4e7d --- /dev/null +++ b/samples/model-builder/experiment_tracking/get_execution_sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import get_artifact_sample + +import test_constants + + +def test_get_artifact_sample(mock_artifact, mock_artifact_get): + artifact = get_artifact_sample.get_artifact_sample( + artifact_id=test_constants.RESOURCE_ID, + project=test_constants.PROJECT, + location=test_constants.LOCATION, + ) + + mock_artifact_get.assert_called_with( + resource_id=test_constants.RESOURCE_ID, + project=test_constants.PROJECT, + location=test_constants.LOCATION, + ) + + assert artifact is mock_artifact diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py index 1844399e7f..247657e08f 100644 --- a/tests/unit/aiplatform/test_metadata_resources.py +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -419,6 +419,28 @@ def test_get_or_create_context( expected_context.name = _TEST_CONTEXT_NAME assert my_context._gca_resource == expected_context + def test_get_context(self, get_context_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context.Context.get( + resource_id=_TEST_CONTEXT_ID, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_context = GapicContext( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + get_context_mock.assert_called_once_with( + name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY + ) + + expected_context.name = _TEST_CONTEXT_NAME + assert my_context._gca_resource == expected_context + @pytest.mark.usefixtures("get_context_mock") @pytest.mark.usefixtures("create_context_mock") def test_update_context(self, update_context_mock): @@ -633,6 +655,28 @@ def test_get_or_create_execution( expected_execution.name = _TEST_EXECUTION_NAME assert my_execution._gca_resource == expected_execution + def test_get_execution(self, get_execution_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_execution = execution.Execution.get( + resource_id=_TEST_EXECUTION_ID, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_execution = GapicExecution( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + get_execution_mock.assert_called_once_with( + name=_TEST_EXECUTION_NAME, retry=base._DEFAULT_RETRY + ) + + expected_execution.name = _TEST_EXECUTION_NAME + assert my_execution._gca_resource == expected_execution + @pytest.mark.usefixtures("get_execution_mock") @pytest.mark.usefixtures("create_execution_mock") def test_update_execution(self, update_execution_mock): @@ -883,6 +927,28 @@ def test_get_or_create_artifact( expected_artifact.name = _TEST_ARTIFACT_NAME assert my_artifact._gca_resource == expected_artifact + def test_get_artifact(self, get_artifact_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_artifact = artifact.Artifact.get( + resource_id=_TEST_ARTIFACT_ID, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_artifact = GapicArtifact( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + get_artifact_mock.assert_called_once_with( + name=_TEST_ARTIFACT_NAME, retry=base._DEFAULT_RETRY + ) + + expected_artifact.name = _TEST_ARTIFACT_NAME + assert my_artifact._gca_resource == expected_artifact + @pytest.mark.usefixtures("get_artifact_mock") @pytest.mark.usefixtures("create_artifact_mock") def test_update_artifact(self, update_artifact_mock):