Skip to content

Commit b86a404

Browse files
lingyinwcopybara-github
authored andcommitted
feat: add remove_datapoints() to MatchingEngineIndex.
PiperOrigin-RevId: 582498139
1 parent 568d833 commit b86a404

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

google/cloud/aiplatform/compat/types/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
index_endpoint as index_endpoint_v1beta1,
5353
hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1,
5454
io as io_v1beta1,
55+
index_service as index_service_v1beta1,
5556
job_service as job_service_v1beta1,
5657
job_state as job_state_v1beta1,
5758
lineage_subgraph as lineage_subgraph_v1beta1,
@@ -275,6 +276,7 @@
275276
matching_engine_deployed_index_ref_v1beta1,
276277
index_v1beta1,
277278
index_endpoint_v1beta1,
279+
index_service_v1beta1,
278280
match_service_v1beta1,
279281
metadata_service_v1beta1,
280282
metadata_schema_v1beta1,

google/cloud/aiplatform/matching_engine/matching_engine_index.py

+41
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.protobuf import field_mask_pb2
2222
from google.cloud.aiplatform import base
2323
from google.cloud.aiplatform.compat.types import (
24+
index_service_v1beta1 as gca_index_service_v1beta1,
2425
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
2526
matching_engine_index as gca_matching_engine_index,
2627
encryption_spec as gca_encryption_spec,
@@ -661,6 +662,46 @@ def create_brute_force_index(
661662
encryption_spec_key_name=encryption_spec_key_name,
662663
)
663664

665+
def remove_datapoints(
666+
self,
667+
datapoint_ids: Sequence[str],
668+
) -> "MatchingEngineIndex":
669+
"""Remove datapoints for this index.
670+
671+
Args:
672+
datapoints_ids (Sequence[str]):
673+
Required. The list of datapoints ids to be deleted.
674+
675+
Returns:
676+
MatchingEngineIndex - Index resource object
677+
"""
678+
self.wait()
679+
680+
_LOGGER.log_action_start_against_resource(
681+
"Removing datapoints",
682+
"index",
683+
self,
684+
)
685+
686+
remove_lro = self.api_client.remove_datapoints(
687+
gca_index_service_v1beta1.RemoveDatapointsRequest(
688+
index=self.resource_name,
689+
datapoint_ids=datapoint_ids,
690+
)
691+
)
692+
693+
_LOGGER.log_action_started_against_resource_with_lro(
694+
"Remove datapoints", "index", self.__class__, remove_lro
695+
)
696+
697+
self._gca_resource = remove_lro.result(timeout=None)
698+
699+
_LOGGER.log_action_completed_against_resource(
700+
"index", "Removed datapoints", self
701+
)
702+
703+
return self
704+
664705

665706
_INDEX_UPDATE_METHOD_TO_ENUM_VALUE = {
666707
"STREAM_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.STREAM_UPDATE,

tests/unit/aiplatform/test_matching_engine_index.py

+30
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from google.cloud.aiplatform.compat.types import (
3636
index as gca_index,
3737
encryption_spec as gca_encryption_spec,
38+
index_service_v1beta1 as gca_index_service_v1beta1,
3839
)
3940
import constants as test_constants
4041

@@ -110,6 +111,9 @@
110111
# Encryption spec
111112
_TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC"
112113

114+
# Streaming update
115+
_TEST_DATAPOINT_IDS = ("1", "2")
116+
113117

114118
def uuid_mock():
115119
return uuid.UUID(int=1)
@@ -192,6 +196,16 @@ def create_index_mock():
192196
yield create_index_mock
193197

194198

199+
@pytest.fixture
200+
def remove_datapoints_mock():
201+
with patch.object(
202+
index_service_client.IndexServiceClient, "remove_datapoints"
203+
) as remove_datapoints_mock:
204+
remove_datapoints_lro_mock = mock.Mock(operation.Operation)
205+
remove_datapoints_mock.return_value = remove_datapoints_lro_mock
206+
yield remove_datapoints_mock
207+
208+
195209
@pytest.mark.usefixtures("google_auth_mock")
196210
class TestMatchingEngineIndex:
197211
def setup_method(self):
@@ -414,3 +428,19 @@ def test_create_brute_force_index(
414428
index=expected,
415429
metadata=_TEST_REQUEST_METADATA,
416430
)
431+
432+
@pytest.mark.usefixtures("get_index_mock")
433+
def test_remove_datapoints(self, remove_datapoints_mock):
434+
aiplatform.init(project=_TEST_PROJECT)
435+
436+
my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
437+
my_index.remove_datapoints(
438+
datapoint_ids=_TEST_DATAPOINT_IDS,
439+
)
440+
441+
remove_datapoints_request = gca_index_service_v1beta1.RemoveDatapointsRequest(
442+
index=_TEST_INDEX_NAME,
443+
datapoint_ids=_TEST_DATAPOINT_IDS,
444+
)
445+
446+
remove_datapoints_mock.assert_called_once_with(remove_datapoints_request)

0 commit comments

Comments
 (0)