Skip to content

Commit 3d8835e

Browse files
lingyinwcopybara-github
authored andcommitted
fix: read_index_endpoint private endpoint support.
PiperOrigin-RevId: 589880026
1 parent 0a4d772 commit 3d8835e

File tree

2 files changed

+71
-15
lines changed

2 files changed

+71
-15
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -1285,24 +1285,32 @@ def read_index_datapoints(
12851285
"""
12861286
if not self._public_match_client:
12871287
# Call private match service stub with BatchGetEmbeddings request
1288-
response = self._batch_get_embeddings(
1288+
embeddings = self._batch_get_embeddings(
12891289
deployed_index_id=deployed_index_id, ids=ids
12901290
)
1291-
return [
1292-
gca_index_v1beta1.IndexDatapoint(
1291+
1292+
response = []
1293+
for embedding in embeddings:
1294+
index_datapoint = gca_index_v1beta1.IndexDatapoint(
12931295
datapoint_id=embedding.id,
12941296
feature_vector=embedding.float_val,
1295-
restricts=gca_index_v1beta1.IndexDatapoint.Restriction(
1296-
namespace=embedding.restricts.name,
1297-
allow_list=embedding.restricts.allow_tokens,
1298-
),
1299-
deny_list=embedding.restricts.deny_tokens,
1300-
crowding_attributes=gca_index_v1beta1.CrowdingEmbedding(
1301-
str(embedding.crowding_tag)
1302-
),
1297+
restricts=[
1298+
gca_index_v1beta1.IndexDatapoint.Restriction(
1299+
namespace=restrict.name,
1300+
allow_list=restrict.allow_tokens,
1301+
deny_list=restrict.deny_tokens,
1302+
)
1303+
for restrict in embedding.restricts
1304+
],
13031305
)
1304-
for embedding in response.embeddings
1305-
]
1306+
if embedding.crowding_attribute:
1307+
index_datapoint.crowding_tag = (
1308+
gca_index_v1beta1.IndexDatapoint.CrowdingTag(
1309+
crowding_attribute=str(embedding.crowding_attribute)
1310+
)
1311+
)
1312+
response.append(index_datapoint)
1313+
return response
13061314

13071315
# Create the ReadIndexDatapoints request
13081316
read_index_datapoints_request = (
@@ -1326,7 +1334,7 @@ def _batch_get_embeddings(
13261334
*,
13271335
deployed_index_id: str,
13281336
ids: List[str] = [],
1329-
) -> List[List[match_service_pb2.Embedding]]:
1337+
) -> List[match_service_pb2.Embedding]:
13301338
"""
13311339
Reads the datapoints/vectors of the given IDs on the specified index
13321340
which is deployed to private endpoint.

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,26 @@
246246
_TEST_RETURN_FULL_DATAPOINT = True
247247
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
248248
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
249+
_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
250+
gca_index_v1beta1.IndexDatapoint(
251+
datapoint_id="1",
252+
feature_vector=[0.1, 0.2, 0.3],
253+
restricts=[
254+
gca_index_v1beta1.IndexDatapoint.Restriction(
255+
namespace="class",
256+
allow_list=["token_1"],
257+
deny_list=["token_2"],
258+
)
259+
],
260+
),
261+
gca_index_v1beta1.IndexDatapoint(
262+
datapoint_id="2",
263+
feature_vector=[0.5, 0.2, 0.3],
264+
crowding_tag=gca_index_v1beta1.IndexDatapoint.CrowdingTag(
265+
crowding_attribute="1"
266+
),
267+
),
268+
]
249269

250270

251271
def uuid_mock():
@@ -505,7 +525,13 @@ def index_endpoint_batch_get_embeddings_mock():
505525
match_service_pb2.Embedding(
506526
id="1",
507527
float_val=[0.1, 0.2, 0.3],
508-
crowding_attribute=1,
528+
restricts=[
529+
match_service_pb2.Namespace(
530+
name="class",
531+
allow_tokens=["token_1"],
532+
deny_tokens=["token_2"],
533+
)
534+
],
509535
),
510536
match_service_pb2.Embedding(
511537
id="2",
@@ -1249,3 +1275,25 @@ def test_index_endpoint_batch_get_embeddings(
12491275
)
12501276

12511277
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
1278+
1279+
@pytest.mark.usefixtures("get_index_endpoint_mock")
1280+
def test_index_endpoint_find_neighbors_for_private(
1281+
self, index_endpoint_batch_get_embeddings_mock
1282+
):
1283+
aiplatform.init(project=_TEST_PROJECT)
1284+
1285+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1286+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1287+
)
1288+
1289+
response = my_index_endpoint.read_index_datapoints(
1290+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, ids=["1", "2"]
1291+
)
1292+
1293+
batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
1294+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
1295+
)
1296+
1297+
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
1298+
1299+
assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE

0 commit comments

Comments
 (0)