Skip to content

Commit 72fd36d

Browse files
inardiniandrewferlitschnayaknishant
authored
docs(samples): add Model Registry samples to Vertex AI Python SDK (#1602)
* docs(samples): add samples for vertex ai model registry python sdk * linter test fix * fix list model versions mock test * nox passed Co-authored-by: Andrew Ferlitsch <[email protected]> Co-authored-by: nayaknishant <[email protected]>
1 parent 0b48b50 commit 72fd36d

28 files changed

+1315
-11
lines changed

samples/model-builder/conftest.py

+79-11
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def mock_create_image_dataset(mock_image_dataset):
114114
@pytest.fixture
115115
def mock_create_tabular_dataset(mock_tabular_dataset):
116116
with patch.object(
117-
aiplatform.TabularDataset, "create"
117+
aiplatform.TabularDataset, "create"
118118
) as mock_create_tabular_dataset:
119119
mock_create_tabular_dataset.return_value = mock_tabular_dataset
120120
yield mock_create_tabular_dataset
@@ -123,7 +123,7 @@ def mock_create_tabular_dataset(mock_tabular_dataset):
123123
@pytest.fixture
124124
def mock_create_time_series_dataset(mock_time_series_dataset):
125125
with patch.object(
126-
aiplatform.TimeSeriesDataset, "create"
126+
aiplatform.TimeSeriesDataset, "create"
127127
) as mock_create_time_series_dataset:
128128
mock_create_time_series_dataset.return_value = mock_time_series_dataset
129129
yield mock_create_time_series_dataset
@@ -462,7 +462,7 @@ def mock_get_entity_type(mock_entity_type):
462462
@pytest.fixture
463463
def mock_create_featurestore(mock_featurestore):
464464
with patch.object(
465-
aiplatform.featurestore.Featurestore, "create"
465+
aiplatform.featurestore.Featurestore, "create"
466466
) as mock_create_featurestore:
467467
mock_create_featurestore.return_value = mock_featurestore
468468
yield mock_create_featurestore
@@ -471,7 +471,7 @@ def mock_create_featurestore(mock_featurestore):
471471
@pytest.fixture
472472
def mock_create_entity_type(mock_entity_type):
473473
with patch.object(
474-
aiplatform.featurestore.EntityType, "create"
474+
aiplatform.featurestore.EntityType, "create"
475475
) as mock_create_entity_type:
476476
mock_create_entity_type.return_value = mock_entity_type
477477
yield mock_create_entity_type
@@ -499,7 +499,7 @@ def mock_batch_serve_to_bq(mock_featurestore):
499499
@pytest.fixture
500500
def mock_batch_create_features(mock_entity_type):
501501
with patch.object(
502-
mock_entity_type, "batch_create_features"
502+
mock_entity_type, "batch_create_features"
503503
) as mock_batch_create_features:
504504
yield mock_batch_create_features
505505

@@ -513,7 +513,7 @@ def mock_read_feature_values(mock_entity_type):
513513
@pytest.fixture
514514
def mock_import_feature_values(mock_entity_type):
515515
with patch.object(
516-
mock_entity_type, "ingest_from_gcs"
516+
mock_entity_type, "ingest_from_gcs"
517517
) as mock_import_feature_values:
518518
yield mock_import_feature_values
519519

@@ -644,7 +644,7 @@ def mock_context_list(mock_context):
644644
@pytest.fixture
645645
def mock_create_schema_base_context(mock_context):
646646
with patch.object(
647-
aiplatform.metadata.schema.base_context.BaseContextSchema, "create"
647+
aiplatform.metadata.schema.base_context.BaseContextSchema, "create"
648648
) as mock_create_schema_base_context:
649649
mock_create_schema_base_context.return_value = mock_context
650650
yield mock_create_schema_base_context
@@ -702,7 +702,7 @@ def mock_create_artifact(mock_artifact):
702702
@pytest.fixture
703703
def mock_create_schema_base_artifact(mock_artifact):
704704
with patch.object(
705-
aiplatform.metadata.schema.base_artifact.BaseArtifactSchema, "create"
705+
aiplatform.metadata.schema.base_artifact.BaseArtifactSchema, "create"
706706
) as mock_create_schema_base_artifact:
707707
mock_create_schema_base_artifact.return_value = mock_artifact
708708
yield mock_create_schema_base_artifact
@@ -711,7 +711,7 @@ def mock_create_schema_base_artifact(mock_artifact):
711711
@pytest.fixture
712712
def mock_create_schema_base_execution(mock_execution):
713713
with patch.object(
714-
aiplatform.metadata.schema.base_execution.BaseExecutionSchema, "create"
714+
aiplatform.metadata.schema.base_execution.BaseExecutionSchema, "create"
715715
) as mock_create_schema_base_execution:
716716
mock_create_schema_base_execution.return_value = mock_execution
717717
yield mock_create_schema_base_execution
@@ -757,7 +757,7 @@ def mock_log_metrics():
757757
@pytest.fixture
758758
def mock_log_time_series_metrics():
759759
with patch.object(
760-
aiplatform, "log_time_series_metrics"
760+
aiplatform, "log_time_series_metrics"
761761
) as mock_log_time_series_metrics:
762762
mock_log_time_series_metrics.return_value = None
763763
yield mock_log_time_series_metrics
@@ -822,7 +822,75 @@ def mock_get_params(mock_params, mock_experiment_run):
822822
@pytest.fixture
823823
def mock_get_time_series_metrics(mock_time_series_metrics, mock_experiment_run):
824824
with patch.object(
825-
mock_experiment_run, "get_time_series_data_frame"
825+
mock_experiment_run, "get_time_series_data_frame"
826826
) as mock_get_time_series_metrics:
827827
mock_get_time_series_metrics.return_value = mock_time_series_metrics
828828
yield mock_get_time_series_metrics
829+
830+
831+
"""
832+
----------------------------------------------------------------------------
833+
Model Versioning Fixtures
834+
----------------------------------------------------------------------------
835+
"""
836+
837+
838+
@pytest.fixture
839+
def mock_model_registry():
840+
mock = MagicMock(aiplatform.models.ModelRegistry)
841+
yield mock
842+
843+
844+
@pytest.fixture
845+
def mock_version_info():
846+
mock = MagicMock(aiplatform.models.VersionInfo)
847+
yield mock
848+
849+
850+
@pytest.fixture
851+
def mock_init_model_registry(mock_model_registry):
852+
with patch.object(aiplatform.models, "ModelRegistry") as mock:
853+
mock.return_value = mock_model_registry
854+
yield mock
855+
856+
857+
@pytest.fixture
858+
def mock_get_model(mock_model_registry):
859+
with patch.object(mock_model_registry, "get_model") as mock_get_model:
860+
mock_get_model.return_value = mock_model
861+
yield mock_get_model
862+
863+
864+
@pytest.fixture
865+
def mock_get_model_version_info(mock_model_registry):
866+
with patch.object(mock_model_registry, "get_version_info") as mock_get_model_version_info:
867+
mock_get_model_version_info.return_value = mock_version_info
868+
yield mock_get_model_version_info
869+
870+
871+
@pytest.fixture
872+
def mock_list_versions(mock_model_registry, mock_version_info):
873+
with patch.object(mock_model_registry, "list_versions") as mock_list_versions:
874+
mock_list_versions.return_value = [mock_version_info, mock_version_info]
875+
yield mock_list_versions
876+
877+
878+
@pytest.fixture
879+
def mock_delete_version(mock_model_registry):
880+
with patch.object(mock_model_registry, "delete_version") as mock_delete_version:
881+
mock_delete_version.return_value = None
882+
yield mock_delete_version
883+
884+
885+
@pytest.fixture
886+
def mock_add_version_aliases(mock_model_registry):
887+
with patch.object(mock_model_registry, "add_version_aliases") as mock_add_version_aliases:
888+
mock_add_version_aliases.return_value = None
889+
yield mock_add_version_aliases
890+
891+
892+
@pytest.fixture
893+
def mock_remove_version_aliases(mock_model_registry):
894+
with patch.object(mock_model_registry, "remove_version_aliases") as mock_remove_version_aliases:
895+
mock_remove_version_aliases.return_value = None
896+
yield mock_remove_version_aliases
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_model_registry_assign_aliases_model_version_sample]
16+
17+
from typing import List
18+
19+
from google.cloud import aiplatform
20+
21+
22+
def assign_aliases_model_version_sample(
23+
model_id: str,
24+
version_aliases: List[str],
25+
version_id: str,
26+
project: str,
27+
location: str,
28+
):
29+
"""
30+
Assign aliases to a model version.
31+
Args:
32+
model_id: The ID of the model.
33+
version_aliases: The version aliases to assign.
34+
version_id: The version ID of the model to assign the aliases to.
35+
project: The project name.
36+
location: The location name.
37+
Returns
38+
None.
39+
"""
40+
# Initialize the client.
41+
aiplatform.init(project=project, location=location)
42+
43+
# Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also
44+
# 'projects/<your-project-id>/locations/<your-region>/models/<your-model-id>'
45+
model_registry = aiplatform.models.ModelRegistry(model=model_id)
46+
47+
# Assign the version aliases to the model with the version 'version_id'.
48+
model_registry.add_version_aliases(new_aliases=version_aliases, version=version_id)
49+
50+
51+
# [END aiplatform_model_registry_assign_aliases_model_version_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import assign_aliases_model_version_sample
17+
18+
import test_constants as constants
19+
20+
21+
def test_assign_aliases_model_version_sample(
22+
mock_sdk_init, mock_init_model_registry, mock_add_version_aliases, mock_model
23+
):
24+
25+
# Assign aliases to a model version.
26+
assign_aliases_model_version_sample.assign_aliases_model_version_sample(
27+
model_id=constants.MODEL_NAME,
28+
version_id=constants.VERSION_ID,
29+
version_aliases=constants.VERSION_ALIASES,
30+
project=constants.PROJECT,
31+
location=constants.LOCATION,
32+
)
33+
34+
# Check client initialization.
35+
mock_sdk_init.assert_called_with(
36+
project=constants.PROJECT, location=constants.LOCATION
37+
)
38+
39+
# Check model registry initialization.
40+
mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME)
41+
42+
# Check that the model version was assigned the aliases.
43+
mock_add_version_aliases.assert_called_with(
44+
new_aliases=constants.VERSION_ALIASES,
45+
version=constants.VERSION_ID,
46+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_model_registry_create_aliased_model_sample]
16+
17+
from google.cloud import aiplatform
18+
19+
20+
def create_aliased_model_sample(
21+
model_id: str, version_id: str, project: str, location: str
22+
):
23+
"""
24+
Initialize a Model resource to represent an existing model version with custom alias.
25+
Args:
26+
model_id: The ID of the model to initialize. Parent resource name of the model is also accepted.
27+
version_id: The version ID or version alias of the model to initialize.
28+
project: The project.
29+
location: The location.
30+
Returns:
31+
Model resource.
32+
"""
33+
# Initialize the client.
34+
aiplatform.init(project=project, location=location)
35+
36+
# Initialize the Model resource with the ID 'model_id'. The version can be also provided using @ annotation in
37+
# the parent resource name:
38+
# 'projects/<your-project-id>/locations/<your-region>/models/<your-model-id>@<your-version-id>'.
39+
40+
aliased_model = aiplatform.Model(model_name=model_id, version=version_id)
41+
42+
return aliased_model
43+
44+
45+
# [END aiplatform_model_registry_create_aliased_model_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_aliased_model_sample
17+
18+
import test_constants as constants
19+
20+
21+
def test_create_aliased_model_sample(mock_sdk_init, mock_init_model):
22+
# Create a model with alias 'default'.
23+
create_aliased_model_sample.create_aliased_model_sample(
24+
model_id=constants.MODEL_NAME,
25+
version_id=constants.VERSION_ID,
26+
project=constants.PROJECT,
27+
location=constants.LOCATION,
28+
)
29+
30+
# Check client initialization.
31+
mock_sdk_init.assert_called_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
35+
# Check that the model was created.
36+
mock_init_model.assert_called_with(
37+
model_name=constants.MODEL_NAME, version=constants.VERSION_ID
38+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_model_registry_create_default_model_sample]
16+
17+
from google.cloud import aiplatform
18+
19+
20+
def create_default_model_sample(model_id: str, project: str, location: str):
21+
"""
22+
Initialize a Model resource to represent an existing model version with alias 'default'.
23+
Args:
24+
model_id: The ID of the model to initialize. Parent resource name of the model is also accepted.
25+
project: The project.
26+
location: The location.
27+
Returns:
28+
Model resource.
29+
"""
30+
# Initialize the client.
31+
aiplatform.init(project=project, location=location)
32+
33+
# Initialize the Model resource with the ID 'model_id'. The parent_name of create method can be also
34+
# 'projects/<your-project-id>/locations/<your-region>/models/<your-model-id>'
35+
default_model = aiplatform.Model(model_name=model_id)
36+
37+
return default_model
38+
39+
40+
# [END aiplatform_model_registry_create_default_model_sample]

0 commit comments

Comments
 (0)