Skip to content

Commit d913e1d

Browse files
feat: Add samples for Metadata context list, get, and create (#1525)
* feat: Add samples for context list,get and create * fix lint issues. * Change import path to aiplatform.Context * Fix create mock. * remove duplicate mock method * Update samples/model-builder/experiment_tracking/get_context_sample_test.py Co-authored-by: Dan Lee <[email protected]> * Update samples/model-builder/experiment_tracking/list_context_sample_test.py Co-authored-by: Dan Lee <[email protected]> * Update samples/model-builder/experiment_tracking/list_context_sample_test.py Co-authored-by: Dan Lee <[email protected]> * Update samples/model-builder/experiment_tracking/get_context_sample_test.py Co-authored-by: Dan Lee <[email protected]> * update formatting and comments based on review feedback Co-authored-by: Dan Lee <[email protected]>
1 parent b53e2b5 commit d913e1d

20 files changed

+267
-51
lines changed

samples/model-builder/conftest.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -612,20 +612,38 @@ def mock_get_artifact(mock_artifact):
612612
yield mock_get_artifact
613613

614614

615-
@pytest.fixture
616-
def mock_artifact_get(mock_artifact):
617-
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
618-
mock_artifact_get.return_value = mock_artifact
619-
yield mock_artifact_get
620-
621-
622615
@pytest.fixture
623616
def mock_context_get(mock_context):
624617
with patch.object(aiplatform.Context, "get") as mock_context_get:
625618
mock_context_get.return_value = mock_context
626619
yield mock_context_get
627620

628621

622+
@pytest.fixture
623+
def mock_context_list(mock_context):
624+
with patch.object(aiplatform.Context, "list") as mock_context_list:
625+
# Returning list of 2 contexts to avoid confusion with get method
626+
# which returns one unique context.
627+
mock_context_list.return_value = [mock_context, mock_context]
628+
yield mock_context_list
629+
630+
631+
@pytest.fixture
632+
def mock_create_schema_base_context(mock_context):
633+
with patch.object(
634+
aiplatform.metadata.schema.base_context.BaseContextSchema, "create"
635+
) as mock_create_schema_base_context:
636+
mock_create_schema_base_context.return_value = mock_context
637+
yield mock_create_schema_base_context
638+
639+
640+
@pytest.fixture
641+
def mock_artifact_get(mock_artifact):
642+
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
643+
mock_artifact_get.return_value = mock_artifact
644+
yield mock_artifact_get
645+
646+
629647
@pytest.fixture
630648
def mock_pipeline_job_create(mock_pipeline_job):
631649
with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create:

samples/model-builder/experiment_tracking/create_artifact_sample.py

-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ def create_artifact_sample(
4040
project=project,
4141
location=location,
4242
)
43-
4443
return artifact
4544

46-
4745
# [END aiplatform_sdk_create_artifact_sample]

samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py

-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ def create_artifact_sample(
3636
description=description,
3737
metadata=metadata,
3838
)
39-
4039
return system_artifact_schema.create(project=project, location=location,)
4140

42-
4341
# [END aiplatform_sdk_create_artifact_with_sdk_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Optional
16+
17+
from google.cloud import aiplatform
18+
from google.cloud.aiplatform.metadata.schema.system import context_schema
19+
20+
21+
# [START aiplatform_sdk_create_context_with_sdk_sample]
22+
def create_context_sample(
23+
display_name: str,
24+
project: str,
25+
location: str,
26+
context_id: Optional[str] = None,
27+
metadata: Optional[Dict[str, Any]] = None,
28+
schema_version: Optional[str] = None,
29+
description: Optional[str] = None,
30+
):
31+
aiplatform.init(project=project, location=location)
32+
33+
return context_schema.Experiment(
34+
display_name=display_name,
35+
context_id=context_id,
36+
metadata=metadata,
37+
schema_version=schema_version,
38+
description=description,
39+
).create()
40+
41+
# [END aiplatform_sdk_create_context_with_sdk_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import create_context_with_sdk_sample
16+
17+
import test_constants as constants
18+
19+
20+
def test_create_context_sample(
21+
mock_sdk_init, mock_create_schema_base_context, mock_context,
22+
):
23+
exc = create_context_with_sdk_sample.create_context_sample(
24+
display_name=constants.DISPLAY_NAME,
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
context_id=constants.RESOURCE_ID,
28+
metadata=constants.METADATA,
29+
schema_version=constants.SCHEMA_VERSION,
30+
description=constants.DESCRIPTION,
31+
)
32+
33+
mock_sdk_init.assert_called_with(
34+
project=constants.PROJECT, location=constants.LOCATION,
35+
)
36+
37+
mock_create_schema_base_context.assert_called_with()
38+
assert exc is mock_context

samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py

-1
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,4 @@ def create_execution_sample(
4343
execution.assign_output_artifacts(output_artifacts)
4444
return execution
4545

46-
4746
# [END aiplatform_sdk_create_execution_with_sdk_sample]

samples/model-builder/experiment_tracking/delete_artifact_sample_test.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515
import delete_artifact_sample
1616

17-
import test_constants
17+
import test_constants as constants
1818

1919

2020
def test_delete_artifact_sample(mock_artifact, mock_artifact_get):
2121
delete_artifact_sample.delete_artifact_sample(
22-
artifact_id=test_constants.RESOURCE_ID,
23-
project=test_constants.PROJECT,
24-
location=test_constants.LOCATION,
22+
artifact_id=constants.RESOURCE_ID,
23+
project=constants.PROJECT,
24+
location=constants.LOCATION,
2525
)
2626

2727
mock_artifact_get.assert_called_with(
28-
resource_id=test_constants.RESOURCE_ID,
29-
project=test_constants.PROJECT,
30-
location=test_constants.LOCATION,
28+
resource_id=constants.RESOURCE_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
3131
)

samples/model-builder/experiment_tracking/delete_context_sample_test.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515
import delete_context_sample
1616

17-
import test_constants
17+
import test_constants as constants
1818

1919

2020
def test_delete_context_sample(mock_context_get):
2121
delete_context_sample.delete_context_sample(
22-
context_id=test_constants.RESOURCE_ID,
23-
project=test_constants.PROJECT,
24-
location=test_constants.LOCATION,
22+
context_id=constants.RESOURCE_ID,
23+
project=constants.PROJECT,
24+
location=constants.LOCATION,
2525
)
2626

2727
mock_context_get.assert_called_with(
28-
resource_id=test_constants.RESOURCE_ID,
29-
project=test_constants.PROJECT,
30-
location=test_constants.LOCATION,
28+
resource_id=constants.RESOURCE_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
3131
)

samples/model-builder/experiment_tracking/delete_execution_sample_test.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515
import delete_execution_sample
1616

17-
import test_constants
17+
import test_constants as constants
1818

1919

2020
def test_delete_execution_sample(mock_execution, mock_execution_get):
2121
delete_execution_sample.delete_execution_sample(
22-
execution_id=test_constants.RESOURCE_ID,
23-
project=test_constants.PROJECT,
24-
location=test_constants.LOCATION,
22+
execution_id=constants.RESOURCE_ID,
23+
project=constants.PROJECT,
24+
location=constants.LOCATION,
2525
)
2626

2727
mock_execution_get.assert_called_with(
28-
resource_id=test_constants.RESOURCE_ID,
29-
project=test_constants.PROJECT,
30-
location=test_constants.LOCATION,
28+
resource_id=constants.RESOURCE_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
3131
)

samples/model-builder/experiment_tracking/get_artifact_sample.py

-1
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,4 @@ def get_artifact_sample(
2727

2828
return artifact
2929

30-
3130
# [END aiplatform_sdk_get_artifact_sample]

samples/model-builder/experiment_tracking/get_artifact_sample_test.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@
1414

1515
import get_artifact_sample
1616

17-
import test_constants
17+
import test_constants as constants
1818

1919

2020
def test_get_artifact_sample(mock_artifact, mock_artifact_get):
2121
artifact = get_artifact_sample.get_artifact_sample(
22-
artifact_id=test_constants.RESOURCE_ID,
23-
project=test_constants.PROJECT,
24-
location=test_constants.LOCATION,
22+
artifact_id=constants.RESOURCE_ID,
23+
project=constants.PROJECT,
24+
location=constants.LOCATION,
2525
)
2626

2727
mock_artifact_get.assert_called_with(
28-
resource_id=test_constants.RESOURCE_ID,
29-
project=test_constants.PROJECT,
30-
location=test_constants.LOCATION,
28+
resource_id=constants.RESOURCE_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
3131
)
3232

3333
assert artifact is mock_artifact

samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py

-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ def get_artifact_with_uri_sample(
2424
artifact = aiplatform.Artifact.get_with_uri(
2525
uri=uri, project=project, location=location
2626
)
27-
2827
return artifact
2928

30-
3129
# [END aiplatform_sdk_get_artifact_with_uri_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_get_context_sample]
19+
def get_context_sample(
20+
context_id: str,
21+
project: str,
22+
location: str,
23+
):
24+
context = aiplatform.Context.get(
25+
resource_id=context_id, project=project, location=location)
26+
return context
27+
28+
# [END aiplatform_sdk_get_context_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import get_context_sample
16+
17+
import test_constants as constants
18+
19+
20+
def test_get_context_sample(mock_context, mock_context_get):
21+
context = get_context_sample.get_context_sample(
22+
context_id=constants.RESOURCE_ID,
23+
project=constants.PROJECT,
24+
location=constants.LOCATION,
25+
)
26+
27+
mock_context_get.assert_called_with(
28+
resource_id=constants.RESOURCE_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
31+
)
32+
33+
assert context is mock_context

samples/model-builder/experiment_tracking/get_execution_sample_test.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import get_artifact_sample
15+
import get_execution_sample
1616

1717
import test_constants
1818

1919

20-
def test_get_artifact_sample(mock_artifact, mock_artifact_get):
21-
artifact = get_artifact_sample.get_artifact_sample(
22-
artifact_id=test_constants.RESOURCE_ID,
20+
def test_get_execution_sample(mock_execution, mock_execution_get):
21+
execution = get_execution_sample.get_execution_sample(
22+
execution_id=test_constants.RESOURCE_ID,
2323
project=test_constants.PROJECT,
2424
location=test_constants.LOCATION,
2525
)
2626

27-
mock_artifact_get.assert_called_with(
27+
mock_execution_get.assert_called_with(
2828
resource_id=test_constants.RESOURCE_ID,
2929
project=test_constants.PROJECT,
3030
location=test_constants.LOCATION,
3131
)
3232

33-
assert artifact is mock_artifact
33+
assert execution is mock_execution

samples/model-builder/experiment_tracking/list_artifact_sample.py

-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ def list_artifact_sample(
2929
location=location)
3030

3131
combined_filters = f"{display_name_fitler} AND {create_date_filter}"
32-
3332
return aiplatform.Artifact.list(filter=combined_filters)
3433

35-
3634
# [END aiplatform_sdk_create_artifact_with_sdk_sample]

samples/model-builder/experiment_tracking/list_artifact_sample_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,7 @@ def test_list_artifact_with_sdk_sample(mock_artifact, mock_list_artifact):
2929
filter=f"{constants.DISPLAY_NAME} AND {constants.CREATE_DATE}"
3030
)
3131
assert len(artifacts) == 2
32+
# Returning list of 2 context to avoid confusion with get method
33+
# which returns one unique context.
3234
assert artifacts[0] is mock_artifact
3335
assert artifacts[1] is mock_artifact

0 commit comments

Comments
 (0)