@@ -216,6 +216,7 @@ def __init__(
216
216
)
217
217
self ._gca_resource = self ._get_gca_resource (resource_name = index_endpoint_name )
218
218
219
+ self ._public_match_client = None
219
220
if self .public_endpoint_domain_name :
220
221
self ._public_match_client = self ._instantiate_public_match_client ()
221
222
@@ -518,6 +519,36 @@ def _instantiate_public_match_client(
518
519
api_path_override = self .public_endpoint_domain_name ,
519
520
)
520
521
522
+ def _instantiate_private_match_service_stub (
523
+ self ,
524
+ deployed_index_id : str ,
525
+ ) -> match_service_pb2_grpc .MatchServiceStub :
526
+ """Helper method to instantiate private match service stub.
527
+ Args:
528
+ deployed_index_id (str):
529
+ Required. The user specified ID of the
530
+ DeployedIndex.
531
+ Returns:
532
+ stub (match_service_pb2_grpc.MatchServiceStub):
533
+ Initialized match service stub.
534
+ """
535
+ # Find the deployed index by id
536
+ deployed_indexes = [
537
+ deployed_index
538
+ for deployed_index in self .deployed_indexes
539
+ if deployed_index .id == deployed_index_id
540
+ ]
541
+
542
+ if not deployed_indexes :
543
+ raise RuntimeError (f"No deployed index with id '{ deployed_index_id } ' found" )
544
+
545
+ # Retrieve server ip from deployed index
546
+ server_ip = deployed_indexes [0 ].private_endpoints .match_grpc_address
547
+
548
+ # Set up channel and stub
549
+ channel = grpc .insecure_channel ("{}:10000" .format (server_ip ))
550
+ return match_service_pb2_grpc .MatchServiceStub (channel )
551
+
521
552
@property
522
553
def public_endpoint_domain_name (self ) -> Optional [str ]:
523
554
"""Public endpoint DNS name."""
@@ -1233,7 +1264,8 @@ def read_index_datapoints(
1233
1264
deployed_index_id : str ,
1234
1265
ids : List [str ] = [],
1235
1266
) -> List [gca_index_v1beta1 .IndexDatapoint ]:
1236
- """Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.
1267
+ """Reads the datapoints/vectors of the given IDs on the specified
1268
+ deployed index which is deployed to public or private endpoint.
1237
1269
1238
1270
```
1239
1271
Example Usage:
@@ -1252,9 +1284,25 @@ def read_index_datapoints(
1252
1284
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
1253
1285
"""
1254
1286
if not self ._public_match_client :
1255
- raise ValueError (
1256
- "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
1287
+ # Call private match service stub with BatchGetEmbeddings request
1288
+ response = self ._batch_get_embeddings (
1289
+ deployed_index_id = deployed_index_id , ids = ids
1257
1290
)
1291
+ return [
1292
+ gca_index_v1beta1 .IndexDatapoint (
1293
+ datapoint_id = embedding .id ,
1294
+ feature_vector = embedding .float_val ,
1295
+ restricts = gca_index_v1beta1 .IndexDatapoint .Restriction (
1296
+ namespace = embedding .restricts .name ,
1297
+ allow_list = embedding .restricts .allow_tokens ,
1298
+ ),
1299
+ deny_list = embedding .restricts .deny_tokens ,
1300
+ crowding_attributes = gca_index_v1beta1 .CrowdingEmbedding (
1301
+ str (embedding .crowding_tag )
1302
+ ),
1303
+ )
1304
+ for embedding in response .embeddings
1305
+ ]
1258
1306
1259
1307
# Create the ReadIndexDatapoints request
1260
1308
read_index_datapoints_request = (
@@ -1273,6 +1321,38 @@ def read_index_datapoints(
1273
1321
# Wrap the results and return
1274
1322
return response .datapoints
1275
1323
1324
+ def _batch_get_embeddings (
1325
+ self ,
1326
+ * ,
1327
+ deployed_index_id : str ,
1328
+ ids : List [str ] = [],
1329
+ ) -> List [List [match_service_pb2 .Embedding ]]:
1330
+ """
1331
+ Reads the datapoints/vectors of the given IDs on the specified index
1332
+ which is deployed to private endpoint.
1333
+
1334
+ Args:
1335
+ deployed_index_id (str):
1336
+ Required. The ID of the DeployedIndex to match the queries against.
1337
+ ids (List[str]):
1338
+ Required. IDs of the datapoints to be searched for.
1339
+ Returns:
1340
+ List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
1341
+ """
1342
+ stub = self ._instantiate_private_match_service_stub (
1343
+ deployed_index_id = deployed_index_id
1344
+ )
1345
+
1346
+ # Create the batch get embeddings request
1347
+ batch_request = match_service_pb2 .BatchGetEmbeddingsRequest ()
1348
+ batch_request .deployed_index_id = deployed_index_id
1349
+
1350
+ for id in ids :
1351
+ batch_request .id .append (id )
1352
+ response = stub .BatchGetEmbeddings (batch_request )
1353
+
1354
+ return response .embeddings
1355
+
1276
1356
def match (
1277
1357
self ,
1278
1358
deployed_index_id : str ,
@@ -1310,23 +1390,9 @@ def match(
1310
1390
Returns:
1311
1391
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
1312
1392
"""
1313
-
1314
- # Find the deployed index by id
1315
- deployed_indexes = [
1316
- deployed_index
1317
- for deployed_index in self .deployed_indexes
1318
- if deployed_index .id == deployed_index_id
1319
- ]
1320
-
1321
- if not deployed_indexes :
1322
- raise RuntimeError (f"No deployed index with id '{ deployed_index_id } ' found" )
1323
-
1324
- # Retrieve server ip from deployed index
1325
- server_ip = deployed_indexes [0 ].private_endpoints .match_grpc_address
1326
-
1327
- # Set up channel and stub
1328
- channel = grpc .insecure_channel ("{}:10000" .format (server_ip ))
1329
- stub = match_service_pb2_grpc .MatchServiceStub (channel )
1393
+ stub = self ._instantiate_private_match_service_stub (
1394
+ deployed_index_id = deployed_index_id
1395
+ )
1330
1396
1331
1397
# Create the batch match request
1332
1398
batch_request = match_service_pb2 .BatchMatchRequest ()
0 commit comments