Skip to content

Commit c9f7119

Browse files
lingyinwcopybara-github
authored andcommitted
feat: support read_index_datapoints for private network.
PiperOrigin-RevId: 589281257
1 parent a8b24ad commit c9f7119

File tree

2 files changed

+131
-20
lines changed

2 files changed

+131
-20
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+86-20
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
)
217217
self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)
218218

219+
self._public_match_client = None
219220
if self.public_endpoint_domain_name:
220221
self._public_match_client = self._instantiate_public_match_client()
221222

@@ -518,6 +519,36 @@ def _instantiate_public_match_client(
518519
api_path_override=self.public_endpoint_domain_name,
519520
)
520521

522+
def _instantiate_private_match_service_stub(
523+
self,
524+
deployed_index_id: str,
525+
) -> match_service_pb2_grpc.MatchServiceStub:
526+
"""Helper method to instantiate private match service stub.
527+
Args:
528+
deployed_index_id (str):
529+
Required. The user specified ID of the
530+
DeployedIndex.
531+
Returns:
532+
stub (match_service_pb2_grpc.MatchServiceStub):
533+
Initialized match service stub.
534+
"""
535+
# Find the deployed index by id
536+
deployed_indexes = [
537+
deployed_index
538+
for deployed_index in self.deployed_indexes
539+
if deployed_index.id == deployed_index_id
540+
]
541+
542+
if not deployed_indexes:
543+
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
544+
545+
# Retrieve server ip from deployed index
546+
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
547+
548+
# Set up channel and stub
549+
channel = grpc.insecure_channel("{}:10000".format(server_ip))
550+
return match_service_pb2_grpc.MatchServiceStub(channel)
551+
521552
@property
522553
def public_endpoint_domain_name(self) -> Optional[str]:
523554
"""Public endpoint DNS name."""
@@ -1233,7 +1264,8 @@ def read_index_datapoints(
12331264
deployed_index_id: str,
12341265
ids: List[str] = [],
12351266
) -> List[gca_index_v1beta1.IndexDatapoint]:
1236-
"""Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.
1267+
"""Reads the datapoints/vectors of the given IDs on the specified
1268+
deployed index which is deployed to public or private endpoint.
12371269
12381270
```
12391271
Example Usage:
@@ -1252,9 +1284,25 @@ def read_index_datapoints(
12521284
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
12531285
"""
12541286
if not self._public_match_client:
1255-
raise ValueError(
1256-
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
1287+
# Call private match service stub with BatchGetEmbeddings request
1288+
response = self._batch_get_embeddings(
1289+
deployed_index_id=deployed_index_id, ids=ids
12571290
)
1291+
return [
1292+
gca_index_v1beta1.IndexDatapoint(
1293+
datapoint_id=embedding.id,
1294+
feature_vector=embedding.float_val,
1295+
restricts=gca_index_v1beta1.IndexDatapoint.Restriction(
1296+
namespace=embedding.restricts.name,
1297+
allow_list=embedding.restricts.allow_tokens,
1298+
),
1299+
deny_list=embedding.restricts.deny_tokens,
1300+
crowding_attributes=gca_index_v1beta1.CrowdingEmbedding(
1301+
str(embedding.crowding_tag)
1302+
),
1303+
)
1304+
for embedding in response.embeddings
1305+
]
12581306

12591307
# Create the ReadIndexDatapoints request
12601308
read_index_datapoints_request = (
@@ -1273,6 +1321,38 @@ def read_index_datapoints(
12731321
# Wrap the results and return
12741322
return response.datapoints
12751323

1324+
def _batch_get_embeddings(
1325+
self,
1326+
*,
1327+
deployed_index_id: str,
1328+
ids: List[str] = [],
1329+
) -> List[List[match_service_pb2.Embedding]]:
1330+
"""
1331+
Reads the datapoints/vectors of the given IDs on the specified index
1332+
which is deployed to private endpoint.
1333+
1334+
Args:
1335+
deployed_index_id (str):
1336+
Required. The ID of the DeployedIndex to match the queries against.
1337+
ids (List[str]):
1338+
Required. IDs of the datapoints to be searched for.
1339+
Returns:
1340+
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
1341+
"""
1342+
stub = self._instantiate_private_match_service_stub(
1343+
deployed_index_id=deployed_index_id
1344+
)
1345+
1346+
# Create the batch get embeddings request
1347+
batch_request = match_service_pb2.BatchGetEmbeddingsRequest()
1348+
batch_request.deployed_index_id = deployed_index_id
1349+
1350+
for id in ids:
1351+
batch_request.id.append(id)
1352+
response = stub.BatchGetEmbeddings(batch_request)
1353+
1354+
return response.embeddings
1355+
12761356
def match(
12771357
self,
12781358
deployed_index_id: str,
@@ -1310,23 +1390,9 @@ def match(
13101390
Returns:
13111391
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
13121392
"""
1313-
1314-
# Find the deployed index by id
1315-
deployed_indexes = [
1316-
deployed_index
1317-
for deployed_index in self.deployed_indexes
1318-
if deployed_index.id == deployed_index_id
1319-
]
1320-
1321-
if not deployed_indexes:
1322-
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
1323-
1324-
# Retrieve server ip from deployed index
1325-
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
1326-
1327-
# Set up channel and stub
1328-
channel = grpc.insecure_channel("{}:10000".format(server_ip))
1329-
stub = match_service_pb2_grpc.MatchServiceStub(channel)
1393+
stub = self._instantiate_private_match_service_stub(
1394+
deployed_index_id=deployed_index_id
1395+
)
13301396

13311397
# Create the batch match request
13321398
batch_request = match_service_pb2.BatchMatchRequest()

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+45
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,31 @@ def index_endpoint_match_queries_mock():
493493
yield index_endpoint_match_queries_mock
494494

495495

496+
@pytest.fixture
497+
def index_endpoint_batch_get_embeddings_mock():
498+
with patch.object(
499+
grpc._channel._UnaryUnaryMultiCallable,
500+
"__call__",
501+
) as index_endpoint_batch_get_embeddings_mock:
502+
index_endpoint_batch_get_embeddings_mock.return_value = (
503+
match_service_pb2.BatchGetEmbeddingsResponse(
504+
embeddings=[
505+
match_service_pb2.Embedding(
506+
id="1",
507+
float_val=[0.1, 0.2, 0.3],
508+
crowding_attribute=1,
509+
),
510+
match_service_pb2.Embedding(
511+
id="2",
512+
float_val=[0.5, 0.2, 0.3],
513+
crowding_attribute=1,
514+
),
515+
]
516+
)
517+
)
518+
yield index_endpoint_batch_get_embeddings_mock
519+
520+
496521
@pytest.fixture
497522
def index_public_endpoint_match_queries_mock():
498523
with patch.object(
@@ -1204,3 +1229,23 @@ def test_index_public_endpoint_read_index_datapoints(
12041229
index_public_endpoint_read_index_datapoints_mock.assert_called_with(
12051230
read_index_datapoints_request
12061231
)
1232+
1233+
@pytest.mark.usefixtures("get_index_endpoint_mock")
1234+
def test_index_endpoint_batch_get_embeddings(
1235+
self, index_endpoint_batch_get_embeddings_mock
1236+
):
1237+
aiplatform.init(project=_TEST_PROJECT)
1238+
1239+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1240+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1241+
)
1242+
1243+
my_index_endpoint._batch_get_embeddings(
1244+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, ids=["1", "2"]
1245+
)
1246+
1247+
batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
1248+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
1249+
)
1250+
1251+
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

0 commit comments

Comments
 (0)