@@ -1387,6 +1387,7 @@ def find_neighbors(
1387
1387
return_full_datapoint : bool = False ,
1388
1388
numeric_filter : Optional [List [NumericNamespace ]] = None ,
1389
1389
embedding_ids : Optional [List [str ]] = None ,
1390
+ signed_jwt : Optional [str ] = None ,
1390
1391
) -> List [List [MatchNeighbor ]]:
1391
1392
"""Retrieves nearest neighbors for the given embedding queries on the
1392
1393
specified deployed index which is deployed to either public or private
@@ -1456,6 +1457,9 @@ def find_neighbors(
1456
1457
`embedding_ids` to lookup embedding values from dataset, if embedding
1457
1458
with `embedding_ids` exists in the dataset, do nearest neighbor search.
1458
1459
1460
+ signed_jwt (str):
1461
+ Optional. A signed JWT for accessing the private endpoint.
1462
+
1459
1463
Returns:
1460
1464
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
1461
1465
"""
@@ -1471,6 +1475,7 @@ def find_neighbors(
1471
1475
approx_num_neighbors = approx_num_neighbors ,
1472
1476
fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1473
1477
numeric_filter = numeric_filter ,
1478
+ signed_jwt = signed_jwt ,
1474
1479
)
1475
1480
1476
1481
# Create the FindNeighbors request
@@ -1570,6 +1575,7 @@ def read_index_datapoints(
1570
1575
* ,
1571
1576
deployed_index_id : str ,
1572
1577
ids : List [str ] = [],
1578
+ signed_jwt : Optional [str ] = None ,
1573
1579
) -> List [gca_index_v1beta1 .IndexDatapoint ]:
1574
1580
"""Reads the datapoints/vectors of the given IDs on the specified
1575
1581
deployed index which is deployed to public or private endpoint.
@@ -1587,6 +1593,8 @@ def read_index_datapoints(
1587
1593
Required. The ID of the DeployedIndex to match the queries against.
1588
1594
ids (List[str]):
1589
1595
Required. IDs of the datapoints to be searched for.
1596
+ signed_jwt (str):
1597
+ Optional. A signed JWT for accessing the private endpoint.
1590
1598
Returns:
1591
1599
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
1592
1600
"""
@@ -1595,6 +1603,7 @@ def read_index_datapoints(
1595
1603
embeddings = self ._batch_get_embeddings (
1596
1604
deployed_index_id = deployed_index_id ,
1597
1605
ids = ids ,
1606
+ signed_jwt = signed_jwt ,
1598
1607
)
1599
1608
1600
1609
response = []
@@ -1641,6 +1650,7 @@ def _batch_get_embeddings(
1641
1650
* ,
1642
1651
deployed_index_id : str ,
1643
1652
ids : List [str ] = [],
1653
+ signed_jwt : Optional [str ] = None ,
1644
1654
) -> List [match_service_pb2 .Embedding ]:
1645
1655
"""
1646
1656
Reads the datapoints/vectors of the given IDs on the specified index
@@ -1651,6 +1661,8 @@ def _batch_get_embeddings(
1651
1661
Required. The ID of the DeployedIndex to match the queries against.
1652
1662
ids (List[str]):
1653
1663
Required. IDs of the datapoints to be searched for.
1664
+ signed_jwt:
1665
+ Optional. A signed JWT for accessing the private endpoint.
1654
1666
Returns:
1655
1667
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
1656
1668
"""
@@ -1665,7 +1677,10 @@ def _batch_get_embeddings(
1665
1677
1666
1678
for id in ids :
1667
1679
batch_request .id .append (id )
1668
- response = stub .BatchGetEmbeddings (batch_request )
1680
+ metadata = None
1681
+ if signed_jwt :
1682
+ metadata = (("authorization" , f"Bearer: { signed_jwt } " ),)
1683
+ response = stub .BatchGetEmbeddings (batch_request , metadata = metadata )
1669
1684
1670
1685
return response .embeddings
1671
1686
@@ -1680,6 +1695,7 @@ def match(
1680
1695
fraction_leaf_nodes_to_search_override : Optional [float ] = None ,
1681
1696
low_level_batch_size : int = 0 ,
1682
1697
numeric_filter : Optional [List [NumericNamespace ]] = None ,
1698
+ signed_jwt : Optional [str ] = None ,
1683
1699
) -> List [List [MatchNeighbor ]]:
1684
1700
"""Retrieves nearest neighbors for the given embedding queries on the
1685
1701
specified deployed index for private endpoint only.
@@ -1729,6 +1745,8 @@ def match(
1729
1745
results. For example:
1730
1746
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
1731
1747
will match datapoints that its cost is greater than 5.
1748
+ signed_jwt (str):
1749
+ Optional. A signed JWT for accessing the private endpoint.
1732
1750
1733
1751
Returns:
1734
1752
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -1809,7 +1827,10 @@ def match(
1809
1827
batch_request .requests .append (batch_request_for_index )
1810
1828
1811
1829
# Perform the request
1812
- response = stub .BatchMatch (batch_request )
1830
+ metadata = None
1831
+ if signed_jwt :
1832
+ metadata = (("authorization" , f"Bearer: { signed_jwt } " ),)
1833
+ response = stub .BatchMatch (batch_request , metadata = metadata )
1813
1834
1814
1835
# Wrap the results in MatchNeighbor objects and return
1815
1836
match_neighbors_response = []
0 commit comments