Skip to content

Commit 5025d03

Browse files
vertex-sdk-bot and copybara-github
authored and committed
feat: Add support for self-signed JWT for queries on private endpoints
PiperOrigin-RevId: 689941402
1 parent 91c2120 commit 5025d03

10 files changed

+504
-11
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ def find_neighbors(
13871387
return_full_datapoint: bool = False,
13881388
numeric_filter: Optional[List[NumericNamespace]] = None,
13891389
embedding_ids: Optional[List[str]] = None,
1390+
signed_jwt: Optional[str] = None,
13901391
) -> List[List[MatchNeighbor]]:
13911392
"""Retrieves nearest neighbors for the given embedding queries on the
13921393
specified deployed index which is deployed to either public or private
@@ -1456,6 +1457,9 @@ def find_neighbors(
14561457
`embedding_ids` to lookup embedding values from dataset, if embedding
14571458
with `embedding_ids` exists in the dataset, do nearest neighbor search.
14581459
1460+
signed_jwt (str):
1461+
Optional. A signed JWT for accessing the private endpoint.
1462+
14591463
Returns:
14601464
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
14611465
"""
@@ -1471,6 +1475,7 @@ def find_neighbors(
14711475
approx_num_neighbors=approx_num_neighbors,
14721476
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
14731477
numeric_filter=numeric_filter,
1478+
signed_jwt=signed_jwt,
14741479
)
14751480

14761481
# Create the FindNeighbors request
@@ -1570,6 +1575,7 @@ def read_index_datapoints(
15701575
*,
15711576
deployed_index_id: str,
15721577
ids: List[str] = [],
1578+
signed_jwt: Optional[str] = None,
15731579
) -> List[gca_index_v1beta1.IndexDatapoint]:
15741580
"""Reads the datapoints/vectors of the given IDs on the specified
15751581
deployed index which is deployed to public or private endpoint.
@@ -1587,6 +1593,8 @@ def read_index_datapoints(
15871593
Required. The ID of the DeployedIndex to match the queries against.
15881594
ids (List[str]):
15891595
Required. IDs of the datapoints to be searched for.
1596+
signed_jwt (str):
1597+
Optional. A signed JWT for accessing the private endpoint.
15901598
Returns:
15911599
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
15921600
"""
@@ -1595,6 +1603,7 @@ def read_index_datapoints(
15951603
embeddings = self._batch_get_embeddings(
15961604
deployed_index_id=deployed_index_id,
15971605
ids=ids,
1606+
signed_jwt=signed_jwt,
15981607
)
15991608

16001609
response = []
@@ -1641,6 +1650,7 @@ def _batch_get_embeddings(
16411650
*,
16421651
deployed_index_id: str,
16431652
ids: List[str] = [],
1653+
signed_jwt: Optional[str] = None,
16441654
) -> List[match_service_pb2.Embedding]:
16451655
"""
16461656
Reads the datapoints/vectors of the given IDs on the specified index
@@ -1651,6 +1661,8 @@ def _batch_get_embeddings(
16511661
Required. The ID of the DeployedIndex to match the queries against.
16521662
ids (List[str]):
16531663
Required. IDs of the datapoints to be searched for.
1664+
signed_jwt:
1665+
Optional. A signed JWT for accessing the private endpoint.
16541666
Returns:
16551667
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
16561668
"""
@@ -1665,7 +1677,10 @@ def _batch_get_embeddings(
16651677

16661678
for id in ids:
16671679
batch_request.id.append(id)
1668-
response = stub.BatchGetEmbeddings(batch_request)
1680+
metadata = None
1681+
if signed_jwt:
1682+
metadata = (("authorization", f"Bearer: {signed_jwt}"),)
1683+
response = stub.BatchGetEmbeddings(batch_request, metadata=metadata)
16691684

16701685
return response.embeddings
16711686

@@ -1680,6 +1695,7 @@ def match(
16801695
fraction_leaf_nodes_to_search_override: Optional[float] = None,
16811696
low_level_batch_size: int = 0,
16821697
numeric_filter: Optional[List[NumericNamespace]] = None,
1698+
signed_jwt: Optional[str] = None,
16831699
) -> List[List[MatchNeighbor]]:
16841700
"""Retrieves nearest neighbors for the given embedding queries on the
16851701
specified deployed index for private endpoint only.
@@ -1729,6 +1745,8 @@ def match(
17291745
results. For example:
17301746
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
17311747
will match datapoints that its cost is greater than 5.
1748+
signed_jwt (str):
1749+
Optional. A signed JWT for accessing the private endpoint.
17321750
17331751
Returns:
17341752
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -1809,7 +1827,10 @@ def match(
18091827
batch_request.requests.append(batch_request_for_index)
18101828

18111829
# Perform the request
1812-
response = stub.BatchMatch(batch_request)
1830+
metadata = None
1831+
if signed_jwt:
1832+
metadata = (("authorization", f"Bearer: {signed_jwt}"),)
1833+
response = stub.BatchMatch(batch_request, metadata=metadata)
18131834

18141835
# Wrap the results in MatchNeighbor objects and return
18151836
match_neighbors_response = []

samples/model-builder/conftest.py

+14
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,20 @@ def mock_index_endpoint_find_neighbors(mock_index_endpoint):
13821382
yield mock_find_neighbors
13831383

13841384

1385+
@pytest.fixture
def mock_index_endpoint_match(mock_index_endpoint):
    """Patch ``match`` on the mocked index endpoint and yield the mock.

    The replacement method returns ``None``. The patch stays active only
    while the fixture is in use and is undone afterwards.
    """
    patcher = patch.object(mock_index_endpoint, "match")
    with patcher as mock_match:
        mock_match.return_value = None
        yield mock_match
1390+
1391+
1392+
@pytest.fixture
def mock_index_endpoint_read_index_datapoints(mock_index_endpoint):
    """Patch ``read_index_datapoints`` on the mocked endpoint and yield the mock.

    The replacement method returns ``None``. The patch stays active only
    while the fixture is in use and is undone afterwards.
    """
    patcher = patch.object(mock_index_endpoint, "read_index_datapoints")
    with patcher as mock_read_index_datapoints:
        mock_read_index_datapoints.return_value = None
        yield mock_read_index_datapoints
1397+
1398+
13851399
@pytest.fixture
13861400
def mock_index_create_tree_ah_index(mock_index):
13871401
with patch.object(

samples/model-builder/test_constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,4 @@
407407
VECTOR_SEARCH_INDEX_LABELS = {"my_key": "my_value"}
408408
VECTOR_SEARCH_GCS_URI = "gs://fake-dir"
409409
VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME = "my-vector-search-index-endpoint"
410+
VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT = "fake-signed-jwt"

samples/model-builder/vector_search/vector_search_find_neighbors_sample.py

+50
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,53 @@ def vector_search_find_neighbors(
8484
print(hybrid_resp)
8585

8686
# [END aiplatform_sdk_vector_search_find_neighbors_sample]
87+
88+
89+
# [START aiplatform_sdk_vector_search_find_neighbors_jwt_sample]
90+
def vector_search_find_neighbors_jwt(
    project: str,
    location: str,
    index_endpoint_name: str,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int,
    signed_jwt: str,
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
    """Query the vector search index, passing a signed JWT with the request.

    Args:
        project (str): Required. Project ID
        location (str): Required. The region name
        index_endpoint_name (str): Required. Index endpoint to run the query
            against.
        deployed_index_id (str): Required. The ID of the DeployedIndex to run
            the queries against.
        queries (List[List[float]]): Required. A list of queries. Each query
            is a list of floats, representing a single embedding.
        num_neighbors (int): Required. The number of neighbors to return.
        signed_jwt (str): Required. The signed JWT token for the private
            endpoint. The endpoint must be configured to accept tokens from
            JWT's issuer and encoded audience.

    Returns:
        List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
    """
    # Point the SDK at the requested project and region.
    aiplatform.init(project=project, location=location)

    # Wrap the already-existing index endpoint resource.
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=index_endpoint_name
    )

    # Run the nearest-neighbor query, forwarding the signed JWT.
    return endpoint.find_neighbors(
        deployed_index_id=deployed_index_id,
        queries=queries,
        num_neighbors=num_neighbors,
        signed_jwt=signed_jwt,
    )
134+
135+
# [END aiplatform_sdk_vector_search_find_neighbors_jwt_sample]
136+

samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,34 @@ def test_vector_search_find_neighbors_sample(
5555
],
5656
any_order=False,
5757
)
58+
59+
60+
def test_vector_search_find_neighbors_jwt_sample(
    mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_find_neighbors
):
    """Run the JWT find-neighbors sample and check the calls it makes."""
    # Arguments the sample is expected to pass through to find_neighbors.
    query_kwargs = dict(
        deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
        queries=constants.VECTOR_SERACH_INDEX_QUERIES,
        num_neighbors=10,
        signed_jwt=constants.VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT,
    )

    vector_search_find_neighbors_sample.vector_search_find_neighbors_jwt(
        project=constants.PROJECT,
        location=constants.LOCATION,
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
        **query_kwargs,
    )

    # The SDK must be initialized with the project and location.
    mock_sdk_init.assert_called_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    # The endpoint wrapper must be built from the given endpoint name.
    mock_index_endpoint_init.assert_called_with(
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT
    )

    # find_neighbors must receive the query arguments, including the JWT.
    mock_index_endpoint_find_neighbors.assert_called_with(**query_kwargs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_vector_search_match_jwt_sample]
21+
def vector_search_match_jwt(
    project: str,
    location: str,
    index_endpoint_name: str,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int,
    signed_jwt: str,
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
    """Run a match query on the vector search index, passing a signed JWT.

    Args:
        project (str): Required. Project ID
        location (str): Required. The region name
        index_endpoint_name (str): Required. Index endpoint to run the query
            against. The endpoint must be a private endpoint.
        deployed_index_id (str): Required. The ID of the DeployedIndex to run
            the queries against.
        queries (List[List[float]]): Required. A list of queries. Each query
            is a list of floats, representing a single embedding.
        num_neighbors (int): Required. The number of neighbors to return.
        signed_jwt (str): Required. The signed JWT token for the private
            endpoint. The endpoint must be configured to accept tokens from
            JWT's issuer and encoded audience.

    Returns:
        List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
    """
    # Point the SDK at the requested project and region.
    aiplatform.init(project=project, location=location)

    # Wrap the already-existing index endpoint resource.
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=index_endpoint_name
    )

    # Run the match query, forwarding the signed JWT.
    return endpoint.match(
        deployed_index_id=deployed_index_id,
        queries=queries,
        num_neighbors=num_neighbors,
        signed_jwt=signed_jwt,
    )
65+
66+
# [END aiplatform_sdk_vector_search_match_jwt_sample]
67+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import test_constants as constants
16+
from vector_search import vector_search_match_sample
17+
18+
19+
def test_vector_search_match_jwt_sample(
    mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_match
):
    """Run the JWT match sample and check the calls it makes."""
    # Arguments the sample is expected to pass through to match.
    match_kwargs = dict(
        deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
        queries=constants.VECTOR_SERACH_INDEX_QUERIES,
        num_neighbors=10,
        signed_jwt=constants.VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT,
    )

    vector_search_match_sample.vector_search_match_jwt(
        project=constants.PROJECT,
        location=constants.LOCATION,
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
        **match_kwargs,
    )

    # The SDK must be initialized with the project and location.
    mock_sdk_init.assert_called_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    # The endpoint wrapper must be built from the given endpoint name.
    mock_index_endpoint_init.assert_called_with(
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT
    )

    # match must receive the query arguments, including the JWT.
    mock_index_endpoint_match.assert_called_with(**match_kwargs)

0 commit comments

Comments
 (0)