Skip to content

Commit 44e208a

Browse files
ivanmkcsasha-gitg
andauthored
feat: Added aiplatform.Model.update method (#952)
* Initial commit for updating models * Added update functionality * Added test * Fixed validation * Fixed docstrings and linting * Fixed whitespace * Mutate copy of proto instead of the original proto * Added return type * Added model.update integration test * Update google/cloud/aiplatform/models.py Co-authored-by: sasha-gitg <[email protected]> * Ran linter Co-authored-by: sasha-gitg <[email protected]>
1 parent 02a92f6 commit 44e208a

File tree

3 files changed

+127
-3
lines changed

3 files changed

+127
-3
lines changed

google/cloud/aiplatform/models.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@
4848
env_var as gca_env_var_compat,
4949
)
5050

51-
from google.protobuf import json_format
52-
51+
from google.protobuf import field_mask_pb2, json_format
5352

5453
_LOGGER = base.Logger(__name__)
5554

@@ -1502,6 +1501,73 @@ def __init__(
15021501
)
15031502
self._gca_resource = self._get_gca_resource(resource_name=model_name)
15041503

1504+
def update(
1505+
self,
1506+
display_name: Optional[str] = None,
1507+
description: Optional[str] = None,
1508+
labels: Optional[Dict[str, str]] = None,
1509+
) -> "Model":
1510+
"""Updates a model.
1511+
1512+
Example usage:
1513+
1514+
my_model = my_model.update(
1515+
display_name='my-model',
1516+
description='my description',
1517+
labels={'key': 'value'},
1518+
)
1519+
1520+
Args:
1521+
display_name (str):
1522+
The display name of the Model. The name can be up to 128
1523+
characters long and can be consist of any UTF-8 characters.
1524+
description (str):
1525+
The description of the model.
1526+
labels (Dict[str, str]):
1527+
Optional. The labels with user-defined metadata to
1528+
organize your Models.
1529+
Label keys and values can be no longer than 64
1530+
characters (Unicode codepoints), can only
1531+
contain lowercase letters, numeric characters,
1532+
underscores and dashes. International characters
1533+
are allowed.
1534+
See https://goo.gl/xmQnxf for more information
1535+
and examples of labels.
1536+
Returns:
1537+
model: Updated model resource.
1538+
Raises:
1539+
ValueError: If `labels` is not the correct format.
1540+
"""
1541+
1542+
current_model_proto = self.gca_resource
1543+
copied_model_proto = current_model_proto.__class__(current_model_proto)
1544+
1545+
update_mask: List[str] = []
1546+
1547+
if display_name:
1548+
utils.validate_display_name(display_name)
1549+
1550+
copied_model_proto.display_name = display_name
1551+
update_mask.append("display_name")
1552+
1553+
if description:
1554+
copied_model_proto.description = description
1555+
update_mask.append("description")
1556+
1557+
if labels:
1558+
utils.validate_labels(labels)
1559+
1560+
copied_model_proto.labels = labels
1561+
update_mask.append("labels")
1562+
1563+
update_mask = field_mask_pb2.FieldMask(paths=update_mask)
1564+
1565+
self.api_client.update_model(model=copied_model_proto, update_mask=update_mask)
1566+
1567+
self._sync_gca_resource()
1568+
1569+
return self
1570+
15051571
# TODO(b/170979552) Add support for predict schemata
15061572
# TODO(b/170979926) Add support for metadata and metadata schema
15071573
@classmethod

tests/system/aiplatform/test_model_upload.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class TestModel(e2e_base.TestEndToEnd):
3737
_temp_prefix = f"{_TEST_PROJECT}-vertex-staging-{_TEST_LOCATION}"
3838

3939
def test_upload_and_deploy_xgboost_model(self, shared_state):
40-
"""Upload XGBoost model from local file and deploy it for prediction."""
40+
"""Upload XGBoost model from local file and deploy it for prediction. Additionally, update model name, description and labels"""
4141

4242
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
4343

@@ -65,3 +65,12 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
6565
shared_state["resources"].append(endpoint)
6666
predict_response = endpoint.predict(instances=[[0, 0, 0]])
6767
assert len(predict_response.predictions) == 1
68+
69+
model = model.update(
70+
display_name="new_name",
71+
description="new_description",
72+
labels={"my_label": "updated"},
73+
)
74+
assert model.display_name == "new_name"
75+
assert model.display_name == "new_description"
76+
assert model.labels == {"my_label": "updated"}

tests/unit/aiplatform/test_models.py

+49
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
encryption_spec as gca_encryption_spec,
5555
)
5656

57+
from google.protobuf import field_mask_pb2
5758

5859
from test_endpoints import create_endpoint_mock # noqa: F401
5960

@@ -177,6 +178,27 @@
177178
_TEST_CONTAINER_REGISTRY_DESTINATION
178179

179180

181+
@pytest.fixture
182+
def mock_model():
183+
model = mock.MagicMock(models.Model)
184+
model.name = _TEST_ID
185+
model._latest_future = None
186+
model._exception = None
187+
model._gca_resource = gca_model.Model(
188+
display_name=_TEST_MODEL_NAME,
189+
description=_TEST_DESCRIPTION,
190+
labels=_TEST_LABEL,
191+
)
192+
yield model
193+
194+
195+
@pytest.fixture
196+
def update_model_mock(mock_model):
197+
with patch.object(model_service_client.ModelServiceClient, "update_model") as mock:
198+
mock.return_value = mock_model
199+
yield mock
200+
201+
180202
@pytest.fixture
181203
def get_endpoint_mock():
182204
with mock.patch.object(
@@ -199,6 +221,7 @@ def get_model_mock():
199221
get_model_mock.return_value = gca_model.Model(
200222
display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME,
201223
)
224+
202225
yield get_model_mock
203226

204227

@@ -1660,3 +1683,29 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model(
16601683
]
16611684
staged_model_file_name = staged_model_file_path.split("/")[-1]
16621685
assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"]
1686+
1687+
@pytest.mark.usefixtures("get_model_mock")
1688+
def test_update(self, update_model_mock, get_model_mock):
1689+
1690+
test_model = models.Model(_TEST_ID)
1691+
1692+
test_model.update(
1693+
display_name=_TEST_MODEL_NAME,
1694+
description=_TEST_DESCRIPTION,
1695+
labels=_TEST_LABEL,
1696+
)
1697+
1698+
current_model_proto = gca_model.Model(
1699+
display_name=_TEST_MODEL_NAME,
1700+
description=_TEST_DESCRIPTION,
1701+
labels=_TEST_LABEL,
1702+
name=_TEST_MODEL_RESOURCE_NAME,
1703+
)
1704+
1705+
update_mask = field_mask_pb2.FieldMask(
1706+
paths=["display_name", "description", "labels"]
1707+
)
1708+
1709+
update_model_mock.assert_called_once_with(
1710+
model=current_model_proto, update_mask=update_mask
1711+
)

0 commit comments

Comments
 (0)