Skip to content

Commit ad8d9c1

Browse files
lingyinw authored and copybara-github committed
feat: Add return_full_datapoint for MatchEngineIndexEndpoint match().
PiperOrigin-RevId: 597148566
1 parent d0f65fd commit ad8d9c1

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+8
Original file line number | Diff line number | Diff line change
@@ -1262,6 +1262,7 @@ def find_neighbors(
12621262
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
12631263
approx_num_neighbors=approx_num_neighbors,
12641264
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1265+
return_full_datapoint=return_full_datapoint,
12651266
)
12661267

12671268
# Create the FindNeighbors request
@@ -1434,6 +1435,7 @@ def match(
14341435
per_crowding_attribute_num_neighbors: Optional[int] = None,
14351436
approx_num_neighbors: Optional[int] = None,
14361437
fraction_leaf_nodes_to_search_override: Optional[float] = None,
1438+
return_full_datapoint: bool = False,
14371439
) -> List[List[MatchNeighbor]]:
14381440
"""Retrieves nearest neighbors for the given embedding queries on the
14391441
specified deployed index for private endpoint only.
@@ -1465,6 +1467,11 @@ def match(
14651467
query time allows user to tune search performance. Increasing this
14661468
value increases both search accuracy and latency.
14671469
The value should be between 0.0 and 1.0.
1470+
return_full_datapoint (bool):
1471+
Optional. If set to true, the full datapoints (including all
1472+
vector values and restricts) of the nearest neighbors are returned.
1473+
Note that returning full datapoint will significantly increase the
1474+
latency and cost of the query.
14681475
14691476
Returns:
14701477
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -1502,6 +1509,7 @@ def match(
15021509
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
15031510
approx_num_neighbors=approx_num_neighbors,
15041511
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1512+
embedding_enabled=return_full_datapoint,
15051513
)
15061514
requests.append(request)
15071515

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+15 -12
Original file line number | Diff line number | Diff line change
@@ -1060,6 +1060,7 @@ def test_private_index_endpoint_match_queries(
10601060
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
10611061
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
10621062
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1063+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
10631064
)
10641065

10651066
batch_request = match_service_pb2.BatchMatchRequest(
@@ -1081,6 +1082,7 @@ def test_private_index_endpoint_match_queries(
10811082
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
10821083
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
10831084
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1085+
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
10841086
)
10851087
for i in range(len(_TEST_QUERIES))
10861088
],
@@ -1096,11 +1098,11 @@ def test_private_index_endpoint_find_neighbor_queries(
10961098
):
10971099
aiplatform.init(project=_TEST_PROJECT)
10981100

1099-
my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1101+
my_private_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
11001102
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
11011103
)
11021104

1103-
my_pubic_index_endpoint.find_neighbors(
1105+
my_private_index_endpoint.find_neighbors(
11041106
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
11051107
queries=_TEST_QUERIES,
11061108
num_neighbors=_TEST_NUM_NEIGHBOURS,
@@ -1130,6 +1132,7 @@ def test_private_index_endpoint_find_neighbor_queries(
11301132
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
11311133
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
11321134
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1135+
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
11331136
)
11341137
for test_query in _TEST_QUERIES
11351138
],
@@ -1187,16 +1190,16 @@ def test_index_private_service_connect_endpoint_match_queries(
11871190
index_endpoint_match_queries_mock.assert_called_with(batch_request)
11881191

11891192
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
1190-
def test_index_public_endpoint_match_queries(
1193+
def test_index_public_endpoint_find_neighbors_queries(
11911194
self, index_public_endpoint_match_queries_mock
11921195
):
11931196
aiplatform.init(project=_TEST_PROJECT)
11941197

1195-
my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1198+
my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
11961199
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
11971200
)
11981201

1199-
my_pubic_index_endpoint.find_neighbors(
1202+
my_public_index_endpoint.find_neighbors(
12001203
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
12011204
queries=_TEST_QUERIES,
12021205
num_neighbors=_TEST_NUM_NEIGHBOURS,
@@ -1208,7 +1211,7 @@ def test_index_public_endpoint_match_queries(
12081211
)
12091212

12101213
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
1211-
index_endpoint=my_pubic_index_endpoint.resource_name,
1214+
index_endpoint=my_public_index_endpoint.resource_name,
12121215
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
12131216
queries=[
12141217
gca_match_service_v1beta1.FindNeighborsRequest.Query(
@@ -1241,11 +1244,11 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
12411244
):
12421245
aiplatform.init(project=_TEST_PROJECT)
12431246

1244-
my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1247+
my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
12451248
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
12461249
)
12471250

1248-
my_pubic_index_endpoint.find_neighbors(
1251+
my_public_index_endpoint.find_neighbors(
12491252
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
12501253
queries=_TEST_QUERIES,
12511254
num_neighbors=_TEST_NUM_NEIGHBOURS,
@@ -1258,7 +1261,7 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
12581261
)
12591262

12601263
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
1261-
index_endpoint=my_pubic_index_endpoint.resource_name,
1264+
index_endpoint=my_public_index_endpoint.resource_name,
12621265
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
12631266
queries=[
12641267
gca_match_service_v1beta1.FindNeighborsRequest.Query(
@@ -1337,18 +1340,18 @@ def test_index_public_endpoint_read_index_datapoints(
13371340
):
13381341
aiplatform.init(project=_TEST_PROJECT)
13391342

1340-
my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1343+
my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
13411344
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
13421345
)
13431346

1344-
my_pubic_index_endpoint.read_index_datapoints(
1347+
my_public_index_endpoint.read_index_datapoints(
13451348
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
13461349
ids=_TEST_IDS,
13471350
)
13481351

13491352
read_index_datapoints_request = (
13501353
gca_match_service_v1beta1.ReadIndexDatapointsRequest(
1351-
index_endpoint=my_pubic_index_endpoint.resource_name,
1354+
index_endpoint=my_public_index_endpoint.resource_name,
13521355
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
13531356
ids=_TEST_IDS,
13541357
)

0 commit comments

Comments
 (0)