|
26 | 26 | from google.cloud.aiplatform.compat.types import (
|
27 | 27 | machine_resources as gca_machine_resources_compat,
|
28 | 28 | matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
|
| 29 | + match_service_v1beta1 as gca_match_service_v1beta1, |
| 30 | + index_v1beta1 as gca_index_v1beta1, |
29 | 31 | )
|
30 | 32 | from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
|
31 | 33 | from google.cloud.aiplatform.matching_engine._protos import (
|
@@ -127,6 +129,9 @@ def __init__(
|
127 | 129 | )
|
128 | 130 | self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)
|
129 | 131 |
|
| 132 | + if self.public_endpoint_domain_name: |
| 133 | + self._public_match_client = self._instantiate_public_match_client() |
| 134 | + |
130 | 135 | @classmethod
|
131 | 136 | def create(
|
132 | 137 | cls,
|
@@ -344,6 +349,22 @@ def _create(
|
344 | 349 |
|
345 | 350 | return index_obj
|
346 | 351 |
|
def _instantiate_public_match_client(
    self,
) -> utils.MatchClientWithOverride:
    """Build the match client that targets this endpoint's public domain.

    The client is created through the SDK's global configuration using this
    endpoint's credentials and location, with the API path overridden by the
    endpoint's public endpoint domain name.

    Returns:
        utils.MatchClientWithOverride:
            The client returned by
            ``initializer.global_config.create_client`` for this endpoint.
    """
    create_kwargs = dict(
        client_class=utils.MatchClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
        # Route calls to the endpoint's own public DNS name.
        api_path_override=self.public_endpoint_domain_name,
    )
    public_match_client = initializer.global_config.create_client(**create_kwargs)
    return public_match_client
| 367 | + |
347 | 368 | @property
|
348 | 369 | def public_endpoint_domain_name(self) -> Optional[str]:
|
349 | 370 | """Public endpoint DNS name."""
|
@@ -928,6 +949,124 @@ def description(self) -> str:
|
928 | 949 | self._assert_gca_resource_is_available()
|
929 | 950 | return self._gca_resource.description
|
930 | 951 |
|
def find_neighbors(
    self,
    *,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int = 10,
    filter: Optional[List[Namespace]] = None,
) -> List[List[MatchNeighbor]]:
    """Retrieves nearest neighbors for the given embedding queries.

    The queries are sent to the specified deployed index through this
    endpoint's public match client, so the index must be deployed to a
    public endpoint.

    Example usage:

        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name='projects/123/locations/us-central1/indexEndpoints/my_index_id'
        )
        my_index_endpoint.find_neighbors(
            deployed_index_id="public_test1", queries=[[1, 1]]
        )

    Args:
        deployed_index_id (str):
            Required. The ID of the DeployedIndex to match the queries against.
        queries (List[List[float]]):
            Required. A list of queries. Each query is a list of floats,
            representing a single embedding.
        num_neighbors (int):
            Optional. The number of nearest neighbors to be retrieved from
            the database for each query. Defaults to 10.
        filter (List[Namespace]):
            Optional. A list of Namespaces for filtering the matching results.
            For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])]
            will match datapoints that satisfy "red color" but not include
            datapoints with "squared shape". ``None`` or an empty list applies
            no filtering. Please refer to
            https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json
            for more detail.

    Returns:
        List[List[MatchNeighbor]] - A list of nearest neighbors for each query.

    Raises:
        ValueError: If this endpoint has no public match client.
    """
    # The attribute may be absent when the endpoint has no public domain
    # name, so look it up defensively and raise the intended ValueError
    # instead of an AttributeError.
    public_match_client = getattr(self, "_public_match_client", None)
    if not public_match_client:
        raise ValueError(
            "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
        )

    # The restrictions are identical for every query, so build them once.
    # `filter or []` also tolerates an explicit filter=None.
    restricts = []
    for namespace in filter or []:
        restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
        restrict.namespace = namespace.name
        restrict.allow_list.extend(namespace.allow_tokens)
        restrict.deny_list.extend(namespace.deny_tokens)
        restricts.append(restrict)

    # Create the FindNeighbors request
    find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest()
    find_neighbors_request.index_endpoint = self.resource_name
    find_neighbors_request.deployed_index_id = deployed_index_id

    for query in queries:
        datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
        for restrict in restricts:
            datapoint.restricts.append(restrict)

        find_neighbors_query = (
            gca_match_service_v1beta1.FindNeighborsRequest.Query()
        )
        find_neighbors_query.neighbor_count = num_neighbors
        find_neighbors_query.datapoint = datapoint
        find_neighbors_request.queries.append(find_neighbors_query)

    response = public_match_client.find_neighbors(find_neighbors_request)

    # Wrap the results in MatchNeighbor objects and return
    return [
        [
            MatchNeighbor(
                id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
            )
            for neighbor in embedding_neighbors.neighbors
        ]
        for embedding_neighbors in response.nearest_neighbors
    ]
| 1023 | + |
def read_index_datapoints(
    self,
    *,
    deployed_index_id: str,
    ids: Optional[List[str]] = None,
) -> List[gca_index_v1beta1.IndexDatapoint]:
    """Reads the datapoints/vectors of the given IDs from a deployed index.

    The request is sent through this endpoint's public match client, so the
    index must be deployed to a public endpoint.

    Example usage:

        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name='projects/123/locations/us-central1/indexEndpoints/my_index_id'
        )
        my_index_endpoint.read_index_datapoints(
            deployed_index_id="public_test1", ids=["606431", "896688"]
        )

    Args:
        deployed_index_id (str):
            Required. The ID of the DeployedIndex to read the datapoints from.
        ids (List[str]):
            Optional. IDs of the datapoints to be read. ``None`` or an empty
            list sends a request with no IDs.

    Returns:
        List[gca_index_v1beta1.IndexDatapoint] - The datapoints/vectors
        returned by the service for the given IDs.

    Raises:
        ValueError: If this endpoint has no public match client.
    """
    # The attribute may be absent when the endpoint has no public domain
    # name, so look it up defensively and raise the intended ValueError
    # instead of an AttributeError.
    public_match_client = getattr(self, "_public_match_client", None)
    if not public_match_client:
        raise ValueError(
            "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
        )

    # Create the ReadIndexDatapoints request
    read_index_datapoints_request = (
        gca_match_service_v1beta1.ReadIndexDatapointsRequest()
    )
    read_index_datapoints_request.index_endpoint = self.resource_name
    read_index_datapoints_request.deployed_index_id = deployed_index_id
    # Bulk-extend instead of appending one by one; `ids or []` tolerates None.
    read_index_datapoints_request.ids.extend(ids or [])

    response = public_match_client.read_index_datapoints(
        read_index_datapoints_request
    )

    # Return the datapoints exactly as provided by the service response
    return response.datapoints
| 1069 | + |
931 | 1070 | def match(
|
932 | 1071 | self,
|
933 | 1072 | deployed_index_id: str,
|
|
0 commit comments