Skip to content

Commit 8621e24

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add update_version to Model Registry
PiperOrigin-RevId: 506139032
1 parent 7ab6e0b commit 8621e24

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

google/cloud/aiplatform/models.py

+54
Original file line numberDiff line numberDiff line change
@@ -4905,6 +4905,60 @@ def delete_version(
49054905

49064906
_LOGGER.info(f"Deleted version {version} for {self.model_resource_name}")
49074907

4908+
def update_version(
4909+
self,
4910+
version: str,
4911+
version_description: Optional[str] = None,
4912+
labels: Optional[Dict[str, str]] = None,
4913+
) -> None:
4914+
"""Updates a model version.
4915+
4916+
Args:
4917+
version (str): Required. The version ID to receive the new alias(es).
4918+
version_description (str):
4919+
The description of the model version.
4920+
labels (Dict[str, str]):
4921+
Optional. The labels with user-defined metadata to
4922+
organize your Model versions.
4923+
Label keys and values can be no longer than 64
4924+
characters (Unicode codepoints), can only
4925+
contain lowercase letters, numeric characters,
4926+
underscores and dashes. International characters
4927+
are allowed.
4928+
See https://goo.gl/xmQnxf for more information
4929+
and examples of labels.
4930+
4931+
Raises:
4932+
ValueError: If `labels` is not the correct format.
4933+
"""
4934+
4935+
current_model_proto = self.get_model(version).gca_resource
4936+
copied_model_proto = current_model_proto.__class__(current_model_proto)
4937+
4938+
update_mask: List[str] = []
4939+
4940+
if version_description:
4941+
copied_model_proto.version_description = version_description
4942+
update_mask.append("version_description")
4943+
4944+
if labels:
4945+
utils.validate_labels(labels)
4946+
4947+
copied_model_proto.labels = labels
4948+
update_mask.append("labels")
4949+
4950+
update_mask = field_mask_pb2.FieldMask(paths=update_mask)
4951+
versioned_name = self._get_versioned_name(self.model_resource_name, version)
4952+
4953+
_LOGGER.info(f"Updating model {versioned_name}")
4954+
4955+
self.client.update_model(
4956+
model=copied_model_proto,
4957+
update_mask=update_mask,
4958+
)
4959+
4960+
_LOGGER.info(f"Completed updating model {versioned_name}")
4961+
49084962
def add_version_aliases(
49094963
self,
49104964
new_aliases: List[str],

tests/unit/aiplatform/test_models.py

+21
Original file line numberDiff line numberDiff line change
@@ -2714,6 +2714,27 @@ def test_delete_version(self, delete_model_version_mock, get_model_with_version)
27142714
)
27152715
)
27162716

2717+
@pytest.mark.usefixtures("get_model_mock")
2718+
def test_update_version(
2719+
self, update_model_mock, get_model_mock, get_model_with_version
2720+
):
2721+
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
2722+
my_model.versioning_registry.update_version(
2723+
_TEST_VERSION_ALIAS_1,
2724+
version_description="update version",
2725+
labels=_TEST_LABEL,
2726+
)
2727+
2728+
model_to_update = _TEST_MODEL_OBJ_WITH_VERSION
2729+
model_to_update.version_description = "update version"
2730+
model_to_update.labels = _TEST_LABEL
2731+
2732+
update_mask = field_mask_pb2.FieldMask(paths=["version_description", "labels"])
2733+
2734+
update_model_mock.assert_called_once_with(
2735+
model=model_to_update, update_mask=update_mask
2736+
)
2737+
27172738
def test_add_versions(self, merge_version_aliases_mock, get_model_with_version):
27182739
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
27192740
my_model.versioning_registry.add_version_aliases(

0 commit comments

Comments
 (0)