Skip to content

Commit 7ca484d

Browse files
lingyinw authored and copybara-github committed
feat: add upsert_datapoints() to MatchingEngineIndex to support streaming update index.
PiperOrigin-RevId: 583089201
1 parent ba2fb39 commit 7ca484d

File tree

4 files changed

+105
-14
lines changed

4 files changed

+105
-14
lines changed

google/cloud/aiplatform/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1
8989
types.index = types.index_v1beta1
9090
types.index_endpoint = types.index_endpoint_v1beta1
91+
types.index_service = types.index_service_v1beta1
9192
types.io = types.io_v1beta1
9293
types.job_service = types.job_service_v1beta1
9394
types.job_state = types.job_state_v1beta1
@@ -189,6 +190,7 @@
189190
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1
190191
types.index = types.index_v1
191192
types.index_endpoint = types.index_endpoint_v1
193+
types.index_service = types.index_service_v1
192194
types.io = types.io_v1
193195
types.job_service = types.job_service_v1
194196
types.job_state = types.job_state_v1

google/cloud/aiplatform/compat/types/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
hyperparameter_tuning_job as hyperparameter_tuning_job_v1,
126126
index as index_v1,
127127
index_endpoint as index_endpoint_v1,
128+
index_service as index_service_v1,
128129
io as io_v1,
129130
job_service as job_service_v1,
130131
job_state as job_state_v1,
@@ -204,6 +205,7 @@
204205
matching_engine_deployed_index_ref_v1,
205206
index_v1,
206207
index_endpoint_v1,
208+
index_service_v1,
207209
metadata_service_v1,
208210
metadata_schema_v1,
209211
metadata_store_v1,

google/cloud/aiplatform/matching_engine/matching_engine_index.py

+40-9
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from google.protobuf import field_mask_pb2
2222
from google.cloud.aiplatform import base
2323
from google.cloud.aiplatform.compat.types import (
24-
index_service_v1beta1 as gca_index_service_v1beta1,
24+
index_service as gca_index_service,
2525
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
2626
matching_engine_index as gca_matching_engine_index,
2727
encryption_spec as gca_encryption_spec,
@@ -665,6 +665,42 @@ def create_brute_force_index(
665665
encryption_spec_key_name=encryption_spec_key_name,
666666
)
667667

668+
def upsert_datapoints(
669+
self,
670+
datapoints: Sequence[gca_matching_engine_index.IndexDatapoint],
671+
) -> "MatchingEngineIndex":
672+
"""Upsert datapoints to this index.
673+
674+
Args:
675+
datapoints (Sequence[gca_matching_engine_index.IndexDatapoint]):
676+
Required. Datapoints to be upserted to this index.
677+
678+
Returns:
679+
MatchingEngineIndex - Index resource object
680+
681+
"""
682+
683+
self.wait()
684+
685+
_LOGGER.log_action_start_against_resource(
686+
"Upserting datapoints",
687+
"index",
688+
self,
689+
)
690+
691+
self.api_client.upsert_datapoints(
692+
gca_index_service.UpsertDatapointsRequest(
693+
index=self.resource_name,
694+
datapoints=datapoints,
695+
)
696+
)
697+
698+
_LOGGER.log_action_completed_against_resource(
699+
"index", "Upserted datapoints", self
700+
)
701+
702+
return self
703+
668704
def remove_datapoints(
669705
self,
670706
datapoint_ids: Sequence[str],
@@ -678,6 +714,7 @@ def remove_datapoints(
678714
Returns:
679715
MatchingEngineIndex - Index resource object
680716
"""
717+
681718
self.wait()
682719

683720
_LOGGER.log_action_start_against_resource(
@@ -686,19 +723,13 @@ def remove_datapoints(
686723
self,
687724
)
688725

689-
remove_lro = self.api_client.remove_datapoints(
690-
gca_index_service_v1beta1.RemoveDatapointsRequest(
726+
self.api_client.remove_datapoints(
727+
gca_index_service.RemoveDatapointsRequest(
691728
index=self.resource_name,
692729
datapoint_ids=datapoint_ids,
693730
)
694731
)
695732

696-
_LOGGER.log_action_started_against_resource_with_lro(
697-
"Remove datapoints", "index", self.__class__, remove_lro
698-
)
699-
700-
self._gca_resource = remove_lro.result(timeout=None)
701-
702733
_LOGGER.log_action_completed_against_resource(
703734
"index", "Removed datapoints", self
704735
)

tests/unit/aiplatform/test_matching_engine_index.py

+61-5
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from google.cloud.aiplatform.compat.types import (
3636
index as gca_index,
3737
encryption_spec as gca_encryption_spec,
38-
index_service_v1beta1 as gca_index_service_v1beta1,
38+
index_service as gca_index_service,
3939
)
4040
import constants as test_constants
4141

@@ -111,8 +111,42 @@
111111
# Encryption spec
112112
_TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC"
113113

114-
# Streaming update
115114
_TEST_DATAPOINT_IDS = ("1", "2")
115+
_TEST_DATAPOINT_1 = gca_index.IndexDatapoint(
116+
datapoint_id="0",
117+
feature_vector=[0.00526886899, -0.0198396724],
118+
restricts=[
119+
gca_index.IndexDatapoint.Restriction(namespace="Color", allow_list=["red"])
120+
],
121+
numeric_restricts=[
122+
gca_index.IndexDatapoint.NumericRestriction(
123+
namespace="cost",
124+
value_int=1,
125+
)
126+
],
127+
)
128+
_TEST_DATAPOINT_2 = gca_index.IndexDatapoint(
129+
datapoint_id="1",
130+
feature_vector=[0.00526886899, -0.0198396724],
131+
numeric_restricts=[
132+
gca_index.IndexDatapoint.NumericRestriction(
133+
namespace="cost",
134+
value_double=0.1,
135+
)
136+
],
137+
crowding_tag=gca_index.IndexDatapoint.CrowdingTag(crowding_attribute="crowding"),
138+
)
139+
_TEST_DATAPOINT_3 = gca_index.IndexDatapoint(
140+
datapoint_id="2",
141+
feature_vector=[0.00526886899, -0.0198396724],
142+
numeric_restricts=[
143+
gca_index.IndexDatapoint.NumericRestriction(
144+
namespace="cost",
145+
value_float=1.1,
146+
)
147+
],
148+
)
149+
_TEST_DATAPOINTS = (_TEST_DATAPOINT_1, _TEST_DATAPOINT_2, _TEST_DATAPOINT_3)
116150

117151

118152
def uuid_mock():
@@ -196,13 +230,19 @@ def create_index_mock():
196230
yield create_index_mock
197231

198232

233+
@pytest.fixture
234+
def upsert_datapoints_mock():
235+
with patch.object(
236+
index_service_client.IndexServiceClient, "upsert_datapoints"
237+
) as upsert_datapoints_mock:
238+
yield upsert_datapoints_mock
239+
240+
199241
@pytest.fixture
200242
def remove_datapoints_mock():
201243
with patch.object(
202244
index_service_client.IndexServiceClient, "remove_datapoints"
203245
) as remove_datapoints_mock:
204-
remove_datapoints_lro_mock = mock.Mock(operation.Operation)
205-
remove_datapoints_mock.return_value = remove_datapoints_lro_mock
206246
yield remove_datapoints_mock
207247

208248

@@ -509,6 +549,22 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock
509549
metadata=_TEST_REQUEST_METADATA,
510550
)
511551

552+
@pytest.mark.usefixtures("get_index_mock")
553+
def test_upsert_datapoints(self, upsert_datapoints_mock):
554+
aiplatform.init(project=_TEST_PROJECT)
555+
556+
my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
557+
my_index.upsert_datapoints(
558+
datapoints=_TEST_DATAPOINTS,
559+
)
560+
561+
upsert_datapoints_request = gca_index_service.UpsertDatapointsRequest(
562+
index=_TEST_INDEX_NAME,
563+
datapoints=_TEST_DATAPOINTS,
564+
)
565+
566+
upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request)
567+
512568
@pytest.mark.usefixtures("get_index_mock")
513569
def test_remove_datapoints(self, remove_datapoints_mock):
514570
aiplatform.init(project=_TEST_PROJECT)
@@ -518,7 +574,7 @@ def test_remove_datapoints(self, remove_datapoints_mock):
518574
datapoint_ids=_TEST_DATAPOINT_IDS,
519575
)
520576

521-
remove_datapoints_request = gca_index_service_v1beta1.RemoveDatapointsRequest(
577+
remove_datapoints_request = gca_index_service.RemoveDatapointsRequest(
522578
index=_TEST_INDEX_NAME,
523579
datapoint_ids=_TEST_DATAPOINT_IDS,
524580
)

0 commit comments

Comments
 (0)