Skip to content

Commit ea8ae2d

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: Samples - Add vector search sample for PSC match queries
PiperOrigin-RevId: 702752681
1 parent 7432c2c commit ea8ae2d

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

samples/model-builder/test_constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,4 @@
421421
("test-project", "network1"),
422422
("test-project2", "network2"),
423423
]
424+
VECTOR_SEARCH_PSC_MANUAL_IP_ADDRESS = "1.2.3.4"

samples/model-builder/vector_search/vector_search_match_sample.py

+103
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,106 @@ def vector_search_match_jwt(
128128
return resp
129129

130130
# [END aiplatform_sdk_vector_search_match_jwt_sample]
131+
132+
133+
# [START aiplatform_sdk_vector_search_match_psc_manual_sample]
134+
def vector_search_match_psc_manual(
135+
project: str,
136+
location: str,
137+
index_endpoint_name: str,
138+
deployed_index_id: str,
139+
queries: List[List[float]],
140+
num_neighbors: int,
141+
ip_address: str,
142+
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
143+
"""Query the vector search index deployed with PSC manual configuration.
144+
145+
Args:
146+
project (str): Required. Project ID
147+
location (str): Required. The region name
148+
index_endpoint_name (str): Required. Index endpoint to run the query
149+
against. The endpoint must be a private endpoint.
150+
deployed_index_id (str): Required. The ID of the DeployedIndex to run
151+
the queries against.
152+
queries (List[List[float]]): Required. A list of queries. Each query is
153+
a list of floats, representing a single embedding.
154+
num_neighbors (int): Required. The number of neighbors to return.
155+
ip_address (str): Required. The IP address of the PSC endpoint. Obtained
156+
from the created compute address used in the forwarding rule to the
157+
endpoint's service attachment.
158+
159+
Returns:
160+
List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
161+
"""
162+
# Initialize the Vertex AI client
163+
aiplatform.init(project=project, location=location)
164+
165+
# Create the index endpoint instance from an existing endpoint.
166+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
167+
index_endpoint_name=index_endpoint_name
168+
)
169+
170+
# Set the IP address of the PSC endpoint.
171+
my_index_endpoint.private_service_connect_ip_address = ip_address
172+
173+
# Query the index endpoint for matches.
174+
resp = my_index_endpoint.match(
175+
deployed_index_id=deployed_index_id,
176+
queries=queries,
177+
num_neighbors=num_neighbors
178+
)
179+
return resp
180+
181+
# [END aiplatform_sdk_vector_search_match_psc_manual_sample]
182+
183+
184+
# [START aiplatform_sdk_vector_search_match_psc_automation_sample]
185+
def vector_search_match_psc_automation(
186+
project: str,
187+
location: str,
188+
index_endpoint_name: str,
189+
deployed_index_id: str,
190+
queries: List[List[float]],
191+
num_neighbors: int,
192+
psc_network: str,
193+
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
194+
"""Query the vector search index deployed with PSC automation.
195+
196+
Args:
197+
project (str): Required. Project ID
198+
location (str): Required. The region name
199+
index_endpoint_name (str): Required. Index endpoint to run the query
200+
against. The endpoint must be a private endpoint.
201+
deployed_index_id (str): Required. The ID of the DeployedIndex to run
202+
the queries against.
203+
queries (List[List[float]]): Required. A list of queries. Each query is
204+
a list of floats, representing a single embedding.
205+
num_neighbors (int): Required. The number of neighbors to return.
206+
ip_address (str): Required. The IP address of the PSC endpoint. Obtained
207+
from the created compute address used in the fordwarding rule to the
208+
endpoint's service attachment.
209+
psc_network (str): The network the endpoint was deployed to via PSC
210+
automation configuration. The format is
211+
projects/{project_id}/global/networks/{network_name}.
212+
213+
Returns:
214+
List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
215+
"""
216+
# Initialize the Vertex AI client
217+
aiplatform.init(project=project, location=location)
218+
219+
# Create the index endpoint instance from an existing endpoint.
220+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
221+
index_endpoint_name=index_endpoint_name
222+
)
223+
224+
# Query the index endpoint for matches.
225+
resp = my_index_endpoint.match(
226+
deployed_index_id=deployed_index_id,
227+
queries=queries,
228+
num_neighbors=num_neighbors,
229+
psc_network=psc_network
230+
)
231+
return resp
232+
233+
# [END aiplatform_sdk_vector_search_match_psc_automation_sample]

samples/model-builder/vector_search/vector_search_match_sample_test.py

+69
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,72 @@ def test_vector_search_match_jwt_sample(
7575
num_neighbors=10,
7676
signed_jwt=constants.VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT,
7777
)
78+
79+
80+
def test_vector_search_match_psc_manual_sample(
81+
mock_sdk_init,
82+
mock_index_endpoint,
83+
mock_index_endpoint_init,
84+
mock_index_endpoint_match
85+
):
86+
vector_search_match_sample.vector_search_match_psc_manual(
87+
project=constants.PROJECT,
88+
location=constants.LOCATION,
89+
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
90+
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
91+
queries=constants.VECTOR_SERACH_INDEX_QUERIES,
92+
num_neighbors=10,
93+
ip_address=constants.VECTOR_SEARCH_PSC_MANUAL_IP_ADDRESS,
94+
)
95+
96+
# Check client initialization
97+
mock_sdk_init.assert_called_with(
98+
project=constants.PROJECT, location=constants.LOCATION
99+
)
100+
101+
# Check index endpoint initialization with right index endpoint name
102+
mock_index_endpoint_init.assert_called_with(
103+
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT)
104+
105+
# Check index endpoint PSC IP address is set
106+
assert mock_index_endpoint.private_service_connect_ip_address == (
107+
constants.VECTOR_SEARCH_PSC_MANUAL_IP_ADDRESS
108+
)
109+
110+
# Check index_endpoint.match is called with right params.
111+
mock_index_endpoint_match.assert_called_with(
112+
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
113+
queries=constants.VECTOR_SERACH_INDEX_QUERIES,
114+
num_neighbors=10,
115+
)
116+
117+
118+
def test_vector_search_match_psc_automation_sample(
119+
mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_match
120+
):
121+
vector_search_match_sample.vector_search_match_psc_automation(
122+
project=constants.PROJECT,
123+
location=constants.LOCATION,
124+
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
125+
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
126+
queries=constants.VECTOR_SERACH_INDEX_QUERIES,
127+
num_neighbors=10,
128+
psc_network=constants.VECTOR_SEARCH_VPC_NETWORK,
129+
)
130+
131+
# Check client initialization
132+
mock_sdk_init.assert_called_with(
133+
project=constants.PROJECT, location=constants.LOCATION
134+
)
135+
136+
# Check index endpoint initialization with right index endpoint name
137+
mock_index_endpoint_init.assert_called_with(
138+
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT)
139+
140+
# Check index_endpoint.match is called with right params.
141+
mock_index_endpoint_match.assert_called_with(
142+
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
143+
queries=constants.VECTOR_SERACH_INDEX_QUERIES,
144+
num_neighbors=10,
145+
psc_network=constants.VECTOR_SEARCH_VPC_NETWORK,
146+
)

0 commit comments

Comments
 (0)