Skip to content

Commit e5c20c3

Browse files
lingyinwcopybara-github
authored andcommitted
feat: add support for per_crowding_attribute_num_neighbors approx_num_neighborsto MatchingEngineIndexEndpoint match()
PiperOrigin-RevId: 579015694
1 parent 650fa62 commit e5c20c3

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+13
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,8 @@ def match(
10731073
queries: List[List[float]],
10741074
num_neighbors: int = 1,
10751075
filter: Optional[List[Namespace]] = [],
1076+
per_crowding_attribute_num_neighbors: Optional[int] = None,
1077+
approx_num_neighbors: Optional[int] = None,
10761078
) -> List[List[MatchNeighbor]]:
10771079
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
10781080
@@ -1089,6 +1091,15 @@ def match(
10891091
For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
10901092
that satisfy "red color" but not include datapoints with "squared shape".
10911093
Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
1094+
per_crowding_attribute_num_neighbors (int):
1095+
Optional. Crowding is a constraint on a neighbor list produced by nearest neighbor
1096+
search requiring that no more than some value k' of the k neighbors
1097+
returned have the same value of crowding_attribute.
1098+
It's used for improving result diversity.
1099+
This field is the maximum number of matches with the same crowding tag.
1100+
approx_num_neighbors (int):
1101+
The number of neighbors to find via approximate search before exact reordering is performed.
1102+
If not set, the default value from scam config is used; if set, this value must be > 0.
10921103
10931104
Returns:
10941105
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -1123,6 +1134,8 @@ def match(
11231134
num_neighbors=num_neighbors,
11241135
deployed_index_id=deployed_index_id,
11251136
float_val=query,
1137+
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
1138+
approx_num_neighbors=approx_num_neighbors,
11261139
)
11271140
for namespace in filter:
11281141
restrict = match_service_pb2.Namespace()

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+47
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@
232232
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
233233
]
234234
_TEST_IDS = ["123", "456", "789"]
235+
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
236+
_TEST_APPROX_NUM_NEIGHBORS = 2
235237

236238

237239
def uuid_mock():
@@ -853,6 +855,47 @@ def test_delete_index_endpoint_with_force(
853855
name=_TEST_INDEX_ENDPOINT_NAME
854856
)
855857

858+
@pytest.mark.usefixtures("get_index_endpoint_mock")
859+
def test_index_endpoint_match_queries_backward_compatibility(
860+
self, index_endpoint_match_queries_mock
861+
):
862+
aiplatform.init(project=_TEST_PROJECT)
863+
864+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
865+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
866+
)
867+
868+
my_index_endpoint.match(
869+
_TEST_DEPLOYED_INDEX_ID,
870+
_TEST_QUERIES,
871+
_TEST_NUM_NEIGHBOURS,
872+
_TEST_FILTER,
873+
)
874+
875+
batch_request = match_service_pb2.BatchMatchRequest(
876+
requests=[
877+
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
878+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
879+
requests=[
880+
match_service_pb2.MatchRequest(
881+
num_neighbors=_TEST_NUM_NEIGHBOURS,
882+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
883+
float_val=_TEST_QUERIES[0],
884+
restricts=[
885+
match_service_pb2.Namespace(
886+
name="class",
887+
allow_tokens=["token_1"],
888+
deny_tokens=["token_2"],
889+
)
890+
],
891+
)
892+
],
893+
)
894+
]
895+
)
896+
897+
index_endpoint_match_queries_mock.assert_called_with(batch_request)
898+
856899
@pytest.mark.usefixtures("get_index_endpoint_mock")
857900
def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
858901
aiplatform.init(project=_TEST_PROJECT)
@@ -866,6 +909,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
866909
queries=_TEST_QUERIES,
867910
num_neighbors=_TEST_NUM_NEIGHBOURS,
868911
filter=_TEST_FILTER,
912+
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
913+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
869914
)
870915

871916
batch_request = match_service_pb2.BatchMatchRequest(
@@ -884,6 +929,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
884929
deny_tokens=["token_2"],
885930
)
886931
],
932+
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
933+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
887934
)
888935
],
889936
)

0 commit comments

Comments
 (0)