Skip to content

Commit e3a87f0

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: add support for find_neighbors/read_index_datapoints in matching engine public endpoint
PiperOrigin-RevId: 527357229
1 parent a8ba666 commit e3a87f0

File tree

7 files changed

+280
-3
lines changed

7 files changed

+280
-3
lines changed

google/cloud/aiplatform/compat/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
services.specialist_pool_service_client = (
4242
services.specialist_pool_service_client_v1beta1
4343
)
44+
services.match_service_client = services.match_service_client_v1beta1
4445
services.metadata_service_client = services.metadata_service_client_v1beta1
4546
services.tensorboard_service_client = services.tensorboard_service_client_v1beta1
4647
services.index_service_client = services.index_service_client_v1beta1

google/cloud/aiplatform/compat/services/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
from google.cloud.aiplatform_v1beta1.services.job_service import (
4040
client as job_service_client_v1beta1,
4141
)
42+
from google.cloud.aiplatform_v1beta1.services.match_service import (
43+
client as match_service_client_v1beta1,
44+
)
4245
from google.cloud.aiplatform_v1beta1.services.metadata_service import (
4346
client as metadata_service_client_v1beta1,
4447
)
@@ -129,6 +132,7 @@
129132
index_service_client_v1beta1,
130133
index_endpoint_service_client_v1beta1,
131134
job_service_client_v1beta1,
135+
match_service_client_v1beta1,
132136
model_service_client_v1beta1,
133137
pipeline_service_client_v1beta1,
134138
prediction_service_client_v1beta1,

google/cloud/aiplatform/compat/types/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
lineage_subgraph as lineage_subgraph_v1beta1,
5858
machine_resources as machine_resources_v1beta1,
5959
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1,
60+
match_service as match_service_v1beta1,
6061
metadata_schema as metadata_schema_v1beta1,
6162
metadata_service as metadata_service_v1beta1,
6263
metadata_store as metadata_store_v1beta1,
@@ -260,6 +261,7 @@
260261
matching_engine_deployed_index_ref_v1beta1,
261262
index_v1beta1,
262263
index_endpoint_v1beta1,
264+
match_service_v1beta1,
263265
metadata_service_v1beta1,
264266
metadata_schema_v1beta1,
265267
metadata_store_v1beta1,

google/cloud/aiplatform/initializer.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def get_client_options(
279279
location_override: Optional[str] = None,
280280
prediction_client: bool = False,
281281
api_base_path_override: Optional[str] = None,
282+
api_path_override: Optional[str] = None,
282283
) -> client_options.ClientOptions:
283284
"""Creates GAPIC client_options using location and type.
284285
@@ -289,6 +290,7 @@ def get_client_options(
289290
Vertex AI.
290291
prediction_client (str): Optional. flag to use a prediction endpoint.
291292
api_base_path_override (str): Optional. Override default API base path.
293+
api_path_override (str): Optional. Override default api path.
292294
Returns:
293295
clients_options (google.api_core.client_options.ClientOptions):
294296
A ClientOptions object set with regionalized API endpoint, i.e.
@@ -311,9 +313,12 @@ def get_client_options(
311313
else constants.API_BASE_PATH
312314
)
313315

314-
return client_options.ClientOptions(
315-
api_endpoint=f"{region}-{service_base_path}"
316+
api_endpoint = (
317+
f"{region}-{service_base_path}"
318+
if not api_path_override
319+
else api_path_override
316320
)
321+
return client_options.ClientOptions(api_endpoint=api_endpoint)
317322

318323
def common_location_path(
319324
self, project: Optional[str] = None, location: Optional[str] = None
@@ -345,6 +350,7 @@ def create_client(
345350
location_override: Optional[str] = None,
346351
prediction_client: bool = False,
347352
api_base_path_override: Optional[str] = None,
353+
api_path_override: Optional[str] = None,
348354
appended_user_agent: Optional[List[str]] = None,
349355
) -> utils.VertexAiServiceClientWithOverride:
350356
"""Instantiates a given VertexAiServiceClient with optional
@@ -358,6 +364,7 @@ def create_client(
358364
location_override (str): Optional. location override.
359365
prediction_client (str): Optional. flag to use a prediction endpoint.
360366
api_base_path_override (str): Optional. Override default api base path.
367+
api_path_override (str): Optional. Override default api path.
361368
appended_user_agent (List[str]):
362369
Optional. User agent appended in the client info. If more than one, it will be
363370
separated by spaces.
@@ -383,6 +390,7 @@ def create_client(
383390
location_override=location_override,
384391
prediction_client=prediction_client,
385392
api_base_path_override=api_base_path_override,
393+
api_path_override=api_path_override,
386394
),
387395
"client_info": client_info,
388396
}

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+139
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from google.cloud.aiplatform.compat.types import (
2727
machine_resources as gca_machine_resources_compat,
2828
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
29+
match_service_v1beta1 as gca_match_service_v1beta1,
30+
index_v1beta1 as gca_index_v1beta1,
2931
)
3032
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
3133
from google.cloud.aiplatform.matching_engine._protos import (
@@ -127,6 +129,9 @@ def __init__(
127129
)
128130
self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)
129131

132+
if self.public_endpoint_domain_name:
133+
self._public_match_client = self._instantiate_public_match_client()
134+
130135
@classmethod
131136
def create(
132137
cls,
@@ -344,6 +349,22 @@ def _create(
344349

345350
return index_obj
346351

352+
def _instantiate_public_match_client(
    self,
) -> utils.MatchClientWithOverride:
    """Builds the match service client used for this endpoint's public domain.

    The client is created through the global initializer config using this
    endpoint's credentials and location, with the endpoint's public domain
    name passed as the API path override.

    Returns:
        match_client (match_service_client.MatchServiceClient):
            Initialized match client with optional overrides.
    """
    # Collect the overrides first so the create_client call stays on one line.
    client_kwargs = dict(
        client_class=utils.MatchClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
        api_path_override=self.public_endpoint_domain_name,
    )
    return initializer.global_config.create_client(**client_kwargs)
367+
347368
@property
348369
def public_endpoint_domain_name(self) -> Optional[str]:
349370
"""Public endpoint DNS name."""
@@ -928,6 +949,124 @@ def description(self) -> str:
928949
self._assert_gca_resource_is_available()
929950
return self._gca_resource.description
930951

952+
def find_neighbors(
    self,
    *,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int = 10,
    filter: Optional[List[Namespace]] = None,
) -> List[List[MatchNeighbor]]:
    """Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.

    ```
    Example usage:
        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id'
        )
        my_index_endpoint.find_neighbors(deployed_index_id="public_test1", queries= [[1, 1]],)
    ```
    Args:
        deployed_index_id (str):
            Required. The ID of the DeployedIndex to match the queries against.
        queries (List[List[float]]):
            Required. A list of queries. Each query is a list of floats, representing a single embedding.
        num_neighbors (int):
            Optional. The number of nearest neighbors to be retrieved from database for
            each query. Defaults to 10.
        filter (List[Namespace]):
            Optional. A list of Namespaces for filtering the matching results.
            For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
            that satisfy "red color" but not include datapoints with "squared shape".
            Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
            ``None`` (the default) and an empty list both mean "no filtering".
    Returns:
        List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
    Raises:
        ValueError: If this endpoint has no public match client, i.e. the
            index endpoint is not a public endpoint.
    """
    # The client attribute is only assigned when the endpoint exposes a public
    # domain name, so it may be missing entirely; getattr keeps the failure a
    # ValueError with an actionable message instead of an AttributeError.
    public_match_client = getattr(self, "_public_match_client", None)
    if not public_match_client:
        raise ValueError(
            "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
        )

    # Treat a missing filter as "no restrictions" without a shared mutable default.
    namespaces = filter if filter is not None else []

    # Create the FindNeighbors request
    find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest()
    find_neighbors_request.index_endpoint = self.resource_name
    find_neighbors_request.deployed_index_id = deployed_index_id

    for query in queries:
        find_neighbors_query = (
            gca_match_service_v1beta1.FindNeighborsRequest.Query()
        )
        find_neighbors_query.neighbor_count = num_neighbors
        datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
        # Every query carries the same set of namespace restrictions.
        for namespace in namespaces:
            restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
            restrict.namespace = namespace.name
            restrict.allow_list.extend(namespace.allow_tokens)
            restrict.deny_list.extend(namespace.deny_tokens)
            datapoint.restricts.append(restrict)
        find_neighbors_query.datapoint = datapoint
        find_neighbors_request.queries.append(find_neighbors_query)

    response = public_match_client.find_neighbors(find_neighbors_request)

    # Wrap the results in MatchNeighbor objects and return
    return [
        [
            MatchNeighbor(
                id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
            )
            for neighbor in embedding_neighbors.neighbors
        ]
        for embedding_neighbors in response.nearest_neighbors
    ]
1023+
1024+
def read_index_datapoints(
1025+
self,
1026+
*,
1027+
deployed_index_id: str,
1028+
ids: List[str] = [],
1029+
) -> List[gca_index_v1beta1.IndexDatapoint]:
1030+
"""Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.
1031+
1032+
```
1033+
Example Usage:
1034+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1035+
index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id'
1036+
)
1037+
my_index_endpoint.read_index_datapoints(deployed_index_id="public_test1", ids= ["606431", "896688"],)
1038+
```
1039+
1040+
Args:
1041+
deployed_index_id (str):
1042+
Required. The ID of the DeployedIndex to match the queries against.
1043+
ids (List[str]):
1044+
Required. IDs of the datapoints to be searched for.
1045+
Returns:
1046+
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
1047+
"""
1048+
if not self._public_match_client:
1049+
raise ValueError(
1050+
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
1051+
)
1052+
1053+
# Create the ReadIndexDatapoints request
1054+
read_index_datapoints_request = (
1055+
gca_match_service_v1beta1.ReadIndexDatapointsRequest()
1056+
)
1057+
read_index_datapoints_request.index_endpoint = self.resource_name
1058+
read_index_datapoints_request.deployed_index_id = deployed_index_id
1059+
1060+
for id in ids:
1061+
read_index_datapoints_request.ids.append(id)
1062+
1063+
response = self._public_match_client.read_index_datapoints(
1064+
read_index_datapoints_request
1065+
)
1066+
1067+
# Wrap the results and return
1068+
return response.datapoints
1069+
9311070
def match(
9321071
self,
9331072
deployed_index_id: str,

google/cloud/aiplatform/utils/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
index_service_client_v1beta1,
4545
index_endpoint_service_client_v1beta1,
4646
job_service_client_v1beta1,
47+
match_service_client_v1beta1,
4748
metadata_service_client_v1beta1,
4849
model_service_client_v1beta1,
4950
pipeline_service_client_v1beta1,
@@ -85,6 +86,7 @@
8586
prediction_service_client_v1beta1.PredictionServiceClient,
8687
pipeline_service_client_v1beta1.PipelineServiceClient,
8788
job_service_client_v1beta1.JobServiceClient,
89+
match_service_client_v1beta1.MatchServiceClient,
8890
metadata_service_client_v1beta1.MetadataServiceClient,
8991
tensorboard_service_client_v1beta1.TensorboardServiceClient,
9092
vizier_service_client_v1beta1.VizierServiceClient,
@@ -598,6 +600,12 @@ class PredictionClientWithOverride(ClientWithOverride):
598600
)
599601

600602

603+
class MatchClientWithOverride(ClientWithOverride):
    """Client-with-override wrapper for the Match service.

    Only the v1beta1 ``MatchServiceClient`` is registered, and v1beta1 is
    also the default version.
    """

    # Single-entry map: the match service client is wired up for v1beta1 only.
    _version_map = (
        (compat.V1BETA1, match_service_client_v1beta1.MatchServiceClient),
    )
    _default_version = compat.V1BETA1
    _is_temporary = False
607+
608+
601609
class MetadataClientWithOverride(ClientWithOverride):
602610
_is_temporary = True
603611
_default_version = compat.DEFAULT_VERSION
@@ -632,6 +640,7 @@ class VizierClientWithOverride(ClientWithOverride):
632640
FeaturestoreClientWithOverride,
633641
JobClientWithOverride,
634642
ModelClientWithOverride,
643+
MatchClientWithOverride,
635644
PipelineClientWithOverride,
636645
PipelineJobClientWithOverride,
637646
PredictionClientWithOverride,

0 commit comments

Comments
 (0)