Skip to content

Commit 81f6a25

Browse files
lingyinwcopybara-github
authored andcommitted
feat: Add update_mask to MatchingEngineIndex upsert_datapoints() to support dynamic metadata update.
PiperOrigin-RevId: 609006985
1 parent 09d1946 commit 81f6a25

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

+12
Original file line numberDiff line numberDiff line change
@@ -692,12 +692,21 @@ def create_brute_force_index(
692692
def upsert_datapoints(
693693
self,
694694
datapoints: Sequence[gca_matching_engine_index.IndexDatapoint],
695+
update_mask: Optional[Sequence[str]] = None,
695696
) -> "MatchingEngineIndex":
696697
"""Upsert datapoints to this index.
697698
698699
Args:
699700
datapoints (Sequence[gca_matching_engine_index.IndexDatapoint]):
700701
Required. Datapoints to be upserted to this index.
702+
update_mask (Sequence[str]):
703+
Optional. Update mask is used to specify the fields to be
704+
overwritten in the datapoints by the update. The fields
705+
specified in the update_mask are relative to each IndexDatapoint
706+
inside datapoints, not the full request.
707+
Updatable fields:
708+
Use `all_restricts` to update both `restricts` and
709+
`numeric_restricts`.
701710
702711
Returns:
703712
MatchingEngineIndex - Index resource object
@@ -716,6 +725,9 @@ def upsert_datapoints(
716725
gca_index_service.UpsertDatapointsRequest(
717726
index=self.resource_name,
718727
datapoints=datapoints,
728+
update_mask=(
729+
field_mask_pb2.FieldMask(paths=update_mask) if update_mask else None
730+
),
719731
)
720732
)
721733

tests/unit/aiplatform/test_matching_engine_index.py

+19
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
)
149149
_TEST_DATAPOINTS = (_TEST_DATAPOINT_1, _TEST_DATAPOINT_2, _TEST_DATAPOINT_3)
150150
_TEST_TIMEOUT = 1800.0
151+
_TEST_UPDATE_MASK = ["all_restricts"]
151152

152153

153154
def uuid_mock():
@@ -706,6 +707,24 @@ def test_upsert_datapoints(self, upsert_datapoints_mock):
706707

707708
upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request)
708709

710+
@pytest.mark.usefixtures("get_index_mock")
711+
def test_upsert_datapoints_dynamic_metadata_update(self, upsert_datapoints_mock):
712+
aiplatform.init(project=_TEST_PROJECT)
713+
714+
my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
715+
my_index.upsert_datapoints(
716+
datapoints=_TEST_DATAPOINTS,
717+
update_mask=_TEST_UPDATE_MASK,
718+
)
719+
720+
upsert_datapoints_request = gca_index_service.UpsertDatapointsRequest(
721+
index=_TEST_INDEX_NAME,
722+
datapoints=_TEST_DATAPOINTS,
723+
update_mask=field_mask_pb2.FieldMask(paths=_TEST_UPDATE_MASK),
724+
)
725+
726+
upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request)
727+
709728
@pytest.mark.usefixtures("get_index_mock")
710729
def test_remove_datapoints(self, remove_datapoints_mock):
711730
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)