Skip to content

Commit 42c7e08

Browse files
lingyinwcopybara-github
authored andcommitted
feat: Add query by id for MatchingEngineIndexEndpoint find_neighbors() public endpoint query.
PiperOrigin-RevId: 599497930
1 parent 67e593b commit 42c7e08

File tree

2 files changed

+122
-53
lines changed

2 files changed

+122
-53
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+70-49
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,13 @@ class MatchNeighbor:
5151
Required. The id of the neighbor.
5252
distance (float):
5353
Required. The distance to the query embedding.
54+
feature_vector (List(float)):
55+
Optional. The feature vector of the matching datapoint.
5456
"""
5557

5658
id: str
5759
distance: float
60+
feature_vector: Optional[List[float]] = None
5861

5962

6063
@dataclass
@@ -1185,14 +1188,15 @@ def find_neighbors(
11851188
self,
11861189
*,
11871190
deployed_index_id: str,
1188-
queries: List[List[float]],
1191+
queries: Optional[List[List[float]]] = None,
11891192
num_neighbors: int = 10,
11901193
filter: Optional[List[Namespace]] = None,
11911194
per_crowding_attribute_neighbor_count: Optional[int] = None,
11921195
approx_num_neighbors: Optional[int] = None,
11931196
fraction_leaf_nodes_to_search_override: Optional[float] = None,
11941197
return_full_datapoint: bool = False,
11951198
numeric_filter: Optional[List[NumericNamespace]] = None,
1199+
embedding_ids: Optional[List[str]] = None,
11961200
) -> List[List[MatchNeighbor]]:
11971201
"""Retrieves nearest neighbors for the given embedding queries on the
11981202
specified deployed index which is deployed to either public or private
@@ -1243,11 +1247,18 @@ def find_neighbors(
12431247
Note that returning full datapoint will significantly increase the
12441248
latency and cost of the query.
12451249
1246-
numeric_filter (Optional[list[NumericNamespace]]):
1250+
numeric_filter (list[NumericNamespace]):
12471251
Optional. A list of NumericNamespaces for filtering the matching
12481252
results. For example:
12491253
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
12501254
will match datapoints that its cost is greater than 5.
1255+
1256+
embedding_ids (str):
1257+
Optional. If `queries` is set, will use `queries` to do nearest
1258+
neighbor search. If `queries` isn't set, will first use
1259+
`embedding_ids` to lookup embedding values from dataset, if embedding
1260+
with `embedding_ids` exists in the dataset, do nearest neighbor search.
1261+
12511262
Returns:
12521263
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
12531264
"""
@@ -1262,7 +1273,6 @@ def find_neighbors(
12621273
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
12631274
approx_num_neighbors=approx_num_neighbors,
12641275
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1265-
return_full_datapoint=return_full_datapoint,
12661276
)
12671277

12681278
# Create the FindNeighbors request
@@ -1271,50 +1281,60 @@ def find_neighbors(
12711281
find_neighbors_request.deployed_index_id = deployed_index_id
12721282
find_neighbors_request.return_full_datapoint = return_full_datapoint
12731283

1274-
for query in queries:
1275-
find_neighbors_query = (
1276-
gca_match_service_v1beta1.FindNeighborsRequest.Query()
1277-
)
1278-
find_neighbors_query.neighbor_count = num_neighbors
1279-
find_neighbors_query.per_crowding_attribute_neighbor_count = (
1280-
per_crowding_attribute_neighbor_count
1281-
)
1282-
find_neighbors_query.approximate_neighbor_count = approx_num_neighbors
1283-
find_neighbors_query.fraction_leaf_nodes_to_search_override = (
1284-
fraction_leaf_nodes_to_search_override
1284+
# Token restricts
1285+
restricts = []
1286+
if filter:
1287+
for namespace in filter:
1288+
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
1289+
restrict.namespace = namespace.name
1290+
restrict.allow_list.extend(namespace.allow_tokens)
1291+
restrict.deny_list.extend(namespace.deny_tokens)
1292+
restricts.append(restrict)
1293+
# Numeric restricts
1294+
numeric_restricts = []
1295+
if numeric_filter:
1296+
for numeric_namespace in numeric_filter:
1297+
numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
1298+
numeric_restrict.namespace = numeric_namespace.name
1299+
numeric_restrict.op = numeric_namespace.op
1300+
numeric_restrict.value_int = numeric_namespace.value_int
1301+
numeric_restrict.value_float = numeric_namespace.value_float
1302+
numeric_restrict.value_double = numeric_namespace.value_double
1303+
numeric_restricts.append(numeric_restrict)
1304+
# Queries
1305+
query_by_id = False if queries else True
1306+
queries = queries if queries else embedding_ids
1307+
if queries:
1308+
for query in queries:
1309+
find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query(
1310+
neighbor_count=num_neighbors,
1311+
per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,
1312+
approximate_neighbor_count=approx_num_neighbors,
1313+
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1314+
)
1315+
datapoint = gca_index_v1beta1.IndexDatapoint(
1316+
datapoint_id=query if query_by_id else None,
1317+
feature_vector=None if query_by_id else query,
1318+
)
1319+
datapoint.restricts.extend(restricts)
1320+
datapoint.numeric_restricts.extend(numeric_restricts)
1321+
find_neighbors_query.datapoint = datapoint
1322+
find_neighbors_request.queries.append(find_neighbors_query)
1323+
else:
1324+
raise ValueError(
1325+
"To find neighbors using matching engine,"
1326+
"please specify `queries` or `embedding_ids`"
12851327
)
1286-
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
1287-
# Token restricts
1288-
if filter:
1289-
for namespace in filter:
1290-
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
1291-
restrict.namespace = namespace.name
1292-
restrict.allow_list.extend(namespace.allow_tokens)
1293-
restrict.deny_list.extend(namespace.deny_tokens)
1294-
datapoint.restricts.append(restrict)
1295-
# Numeric restricts
1296-
if numeric_filter:
1297-
for numeric_namespace in numeric_filter:
1298-
numeric_restrict = (
1299-
gca_index_v1beta1.IndexDatapoint.NumericRestriction()
1300-
)
1301-
numeric_restrict.namespace = numeric_namespace.name
1302-
numeric_restrict.op = numeric_namespace.op
1303-
numeric_restrict.value_int = numeric_namespace.value_int
1304-
numeric_restrict.value_float = numeric_namespace.value_float
1305-
numeric_restrict.value_double = numeric_namespace.value_double
1306-
datapoint.numeric_restricts.append(numeric_restrict)
1307-
1308-
find_neighbors_query.datapoint = datapoint
1309-
find_neighbors_request.queries.append(find_neighbors_query)
13101328

13111329
response = self._public_match_client.find_neighbors(find_neighbors_request)
13121330

13131331
# Wrap the results in MatchNeighbor objects and return
13141332
return [
13151333
[
13161334
MatchNeighbor(
1317-
id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
1335+
id=neighbor.datapoint.datapoint_id,
1336+
distance=neighbor.distance,
1337+
feature_vector=neighbor.datapoint.feature_vector,
13181338
)
13191339
for neighbor in embedding_neighbors.neighbors
13201340
]
@@ -1429,13 +1449,12 @@ def _batch_get_embeddings(
14291449
def match(
14301450
self,
14311451
deployed_index_id: str,
1432-
queries: Optional[List[List[float]]] = None,
1452+
queries: List[List[float]] = None,
14331453
num_neighbors: int = 1,
14341454
filter: Optional[List[Namespace]] = None,
14351455
per_crowding_attribute_num_neighbors: Optional[int] = None,
14361456
approx_num_neighbors: Optional[int] = None,
14371457
fraction_leaf_nodes_to_search_override: Optional[float] = None,
1438-
return_full_datapoint: bool = False,
14391458
low_level_batch_size: int = 0,
14401459
) -> List[List[MatchNeighbor]]:
14411460
"""Retrieves nearest neighbors for the given embedding queries on the
@@ -1468,11 +1487,6 @@ def match(
14681487
query time allows user to tune search performance. This value
14691488
increase result in both search accuracy and latency increase.
14701489
The value should be between 0.0 and 1.0.
1471-
return_full_datapoint (bool):
1472-
Optional. If set to true, the full datapoints (including all
1473-
vector values and of the nearest neighbors are returned.
1474-
Note that returning full datapoint will significantly increase the
1475-
latency and cost of the query.
14761490
low_level_batch_size (int):
14771491
Optional. Selects the optimal batch size to use for low-level
14781492
batching. Queries within each low level batch are executed
@@ -1518,9 +1532,13 @@ def match(
15181532
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
15191533
approx_num_neighbors=approx_num_neighbors,
15201534
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1521-
embedding_enabled=return_full_datapoint,
15221535
)
15231536
requests.append(request)
1537+
else:
1538+
raise ValueError(
1539+
"To find neighbors using matching engine,"
1540+
"please specify `queries` or `embedding_ids`"
1541+
)
15241542

15251543
batch_request_for_index.requests.extend(requests)
15261544
batch_request.requests.append(batch_request_for_index)
@@ -1531,8 +1549,11 @@ def match(
15311549
# Wrap the results in MatchNeighbor objects and return
15321550
return [
15331551
[
1534-
MatchNeighbor(id=neighbor.id, distance=neighbor.distance)
1535-
for neighbor in embedding_neighbors.neighbor
1552+
MatchNeighbor(
1553+
id=embedding_neighbors.neighbor[i].id,
1554+
distance=embedding_neighbors.neighbor[i].distance,
1555+
)
1556+
for i in range(len(embedding_neighbors.neighbor))
15361557
]
15371558
for embedding_neighbors in response.responses[0].responses
15381559
]

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@
230230
-0.021106,
231231
]
232232
]
233+
_TEST_QUERY_IDS = ["1", "2"]
233234
_TEST_NUM_NEIGHBOURS = 1
234235
_TEST_FILTER = [
235236
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
@@ -1044,7 +1045,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
10441045
index_endpoint_match_queries_mock.assert_called_with(batch_request)
10451046

10461047
@pytest.mark.usefixtures("get_index_endpoint_mock")
1047-
def test_private_index_endpoint_match_queries(
1048+
def test_private_service_access_index_endpoint_match_queries(
10481049
self, index_endpoint_match_queries_mock
10491050
):
10501051
aiplatform.init(project=_TEST_PROJECT)
@@ -1061,7 +1062,6 @@ def test_private_index_endpoint_match_queries(
10611062
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
10621063
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
10631064
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1064-
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
10651065
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
10661066
)
10671067

@@ -1085,7 +1085,6 @@ def test_private_index_endpoint_match_queries(
10851085
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
10861086
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
10871087
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1088-
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
10891088
)
10901089
for i in range(len(_TEST_QUERIES))
10911090
],
@@ -1135,7 +1134,6 @@ def test_private_index_endpoint_find_neighbor_queries(
11351134
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
11361135
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
11371136
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1138-
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
11391137
)
11401138
for test_query in _TEST_QUERIES
11411139
],
@@ -1241,6 +1239,56 @@ def test_index_public_endpoint_find_neighbors_queries(
12411239
find_neighbors_request
12421240
)
12431241

1242+
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
1243+
def test_index_public_endpoint_find_neiggbor_query_by_id(
1244+
self, index_public_endpoint_match_queries_mock
1245+
):
1246+
aiplatform.init(project=_TEST_PROJECT)
1247+
1248+
my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1249+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1250+
)
1251+
1252+
my_pubic_index_endpoint.find_neighbors(
1253+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1254+
num_neighbors=_TEST_NUM_NEIGHBOURS,
1255+
filter=_TEST_FILTER,
1256+
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1257+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
1258+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1259+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
1260+
embedding_ids=_TEST_QUERY_IDS,
1261+
)
1262+
1263+
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
1264+
index_endpoint=my_pubic_index_endpoint.resource_name,
1265+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1266+
queries=[
1267+
gca_match_service_v1beta1.FindNeighborsRequest.Query(
1268+
neighbor_count=_TEST_NUM_NEIGHBOURS,
1269+
datapoint=gca_index_v1beta1.IndexDatapoint(
1270+
datapoint_id=_TEST_QUERY_IDS[i],
1271+
restricts=[
1272+
gca_index_v1beta1.IndexDatapoint.Restriction(
1273+
namespace="class",
1274+
allow_list=["token_1"],
1275+
deny_list=["token_2"],
1276+
)
1277+
],
1278+
),
1279+
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1280+
approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS,
1281+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1282+
)
1283+
for i in range(len(_TEST_QUERY_IDS))
1284+
],
1285+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
1286+
)
1287+
1288+
index_public_endpoint_match_queries_mock.assert_called_with(
1289+
find_neighbors_request
1290+
)
1291+
12441292
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
12451293
def test_index_public_endpoint_match_queries_with_numeric_filtering(
12461294
self, index_public_endpoint_match_queries_mock

0 commit comments

Comments
 (0)