Skip to content

Commit d591d3e

Browse files
feat: Support filters in matching engine vector matching (#1608)
* feat: support filter in index_endpoint.match() * fix type error * Add unit test for index_endpoint.match() * update docstring example * Update docstring Co-authored-by: nayaknishant <[email protected]>
1 parent 66b5471 commit d591d3e

File tree

3 files changed

+136
-12
lines changed

3 files changed

+136
-12
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+46-12
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
from dataclasses import dataclass
18+
from dataclasses import dataclass, field
1919
from typing import Dict, List, Optional, Sequence, Tuple
2020

2121
from google.auth import credentials as auth_credentials
@@ -51,6 +51,25 @@ class MatchNeighbor:
5151
distance: float
5252

5353

54+
@dataclass
class Namespace:
    """Namespace specifies the rules for determining the datapoints that are eligible for each matching query.

    When several namespaces are supplied, the overall query is an AND across namespaces.

    Args:
        name (str):
            Required. The name of this Namespace.
        allow_tokens (List[str]):
            Optional. The allowed tokens in the namespace.
        deny_tokens (List[str]):
            Optional. The denied tokens in the namespace. When a token is denied, then matches will be
            excluded whenever the other datapoint has that token.

    For example, if a query specifies [Namespace("color", ["red", "blue"], ["purple"])], then that query
    will match datapoints that are red or blue, but if those points are also purple, then they will be
    excluded even if they are red/blue.
    """

    name: str
    # default_factory gives every instance its own list, so tokens appended to
    # one Namespace never leak into another.
    allow_tokens: List[str] = field(default_factory=list)
    deny_tokens: List[str] = field(default_factory=list)
71+
72+
5473
class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
5574
"""Matching Engine index endpoint resource for Vertex AI."""
5675

@@ -796,7 +815,11 @@ def description(self) -> str:
796815
return self._gca_resource.description
797816

798817
def match(
799-
self, deployed_index_id: str, queries: List[List[float]], num_neighbors: int = 1
818+
self,
819+
deployed_index_id: str,
820+
queries: List[List[float]],
821+
num_neighbors: int = 1,
822+
filter: Optional[List[Namespace]] = [],
800823
) -> List[List[MatchNeighbor]]:
801824
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
802825
@@ -808,6 +831,11 @@ def match(
808831
num_neighbors (int):
809832
Required. The number of nearest neighbors to be retrieved from database for
810833
each query.
834+
filter (List[Namespace]):
835+
Optional. A list of Namespaces for filtering the matching results.
836+
For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
837+
that satisfy "red color" but not include datapoints with "squared shape".
838+
Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
811839
812840
Returns:
813841
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -836,16 +864,22 @@ def match(
836864
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
837865
)
838866
batch_request_for_index.deployed_index_id = deployed_index_id
839-
batch_request_for_index.requests.extend(
840-
[
841-
match_service_pb2.MatchRequest(
842-
num_neighbors=num_neighbors,
843-
deployed_index_id=deployed_index_id,
844-
float_val=query,
845-
)
846-
for query in queries
847-
]
848-
)
867+
requests = []
868+
for query in queries:
869+
request = match_service_pb2.MatchRequest(
870+
num_neighbors=num_neighbors,
871+
deployed_index_id=deployed_index_id,
872+
float_val=query,
873+
)
874+
for namespace in filter:
875+
restrict = match_service_pb2.Namespace()
876+
restrict.name = namespace.name
877+
restrict.allow_tokens.extend(namespace.allow_tokens)
878+
restrict.deny_tokens.extend(namespace.deny_tokens)
879+
request.restricts.append(restrict)
880+
requests.append(request)
881+
882+
batch_request_for_index.requests.extend(requests)
849883
batch_request.requests.append(batch_request_for_index)
850884

851885
# Perform the request

tests/system/aiplatform/test_matching_engine_index.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import uuid
1919

2020
from google.cloud import aiplatform
21+
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
22+
Namespace,
23+
)
2124

2225
from tests.system.aiplatform import e2e_base
2326

@@ -161,6 +164,8 @@
161164
-0.021106,
162165
]
163166

167+
_TEST_FILTER = [Namespace("name", ["allow_token"], ["deny_token"])]
168+
164169

165170
class TestMatchingEngine(e2e_base.TestEndToEnd):
166171

@@ -283,6 +288,16 @@ def test_create_get_list_matching_engine_index(self, shared_state):
283288

284289
# assert results[0][0].id == 870
285290

291+
# TODO: Test `my_index_endpoint.match` with filter.
292+
# This requires uploading a new content of the Matching Engine Index to Cloud Storage.
293+
# results = my_index_endpoint.match(
294+
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
295+
# queries=[_TEST_MATCH_QUERY],
296+
# num_neighbors=1,
297+
# filter=_TEST_FILTER,
298+
# )
299+
# assert results[0][0].id == 9999
300+
286301
# Undeploy index
287302
my_index_endpoint = my_index_endpoint.undeploy_index(
288303
deployed_index_id=deployed_index.id

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+75
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from google.cloud import aiplatform
2525
from google.cloud.aiplatform import base
2626
from google.cloud.aiplatform import initializer
27+
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
28+
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
29+
Namespace,
30+
)
2731
from google.cloud.aiplatform.compat.types import (
2832
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
2933
index_endpoint as gca_index_endpoint,
@@ -37,6 +41,8 @@
3741

3842
from google.protobuf import field_mask_pb2
3943

44+
import grpc
45+
4046
import pytest
4147

4248
# project
@@ -210,6 +216,9 @@
210216
]
211217
]
212218
_TEST_NUM_NEIGHBOURS = 1
219+
_TEST_FILTER = [
220+
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
221+
]
213222

214223

215224
def uuid_mock():
@@ -380,6 +389,33 @@ def create_index_endpoint_mock():
380389
yield create_index_endpoint_mock
381390

382391

392+
@pytest.fixture
def index_endpoint_match_queries_mock():
    """Patch ``grpc._channel._UnaryUnaryMultiCallable.__call__`` with a canned response.

    Yields:
        The mock replacing ``__call__``. Its return value is a
        ``BatchMatchResponse`` for deployed index "1" holding a single
        neighbor with id "1" and distance 0.1.
    """
    # Assemble the canned response from the innermost message outwards.
    neighbor = match_service_pb2.MatchResponse.Neighbor(id="1", distance=0.1)
    query_response = match_service_pb2.MatchResponse(neighbor=[neighbor])
    index_response = match_service_pb2.BatchMatchResponse.BatchMatchResponsePerIndex(
        deployed_index_id="1",
        responses=[query_response],
    )
    canned_response = match_service_pb2.BatchMatchResponse(
        responses=[index_response]
    )

    with patch.object(
        grpc._channel._UnaryUnaryMultiCallable, "__call__"
    ) as call_mock:
        call_mock.return_value = canned_response
        yield call_mock
417+
418+
383419
@pytest.mark.usefixtures("google_auth_mock")
384420
class TestMatchingEngineIndexEndpoint:
385421
def setup_method(self):
@@ -617,3 +653,42 @@ def test_delete_index_endpoint_with_force(
617653
delete_index_endpoint_mock.assert_called_once_with(
618654
name=_TEST_INDEX_ENDPOINT_NAME
619655
)
656+
657+
@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
    """Check that ``match()`` with a filter issues the expected ``BatchMatchRequest``."""
    aiplatform.init(project=_TEST_PROJECT)
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )

    endpoint.match(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        queries=_TEST_QUERIES,
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        filter=_TEST_FILTER,
    )

    # Build the expected request from the innermost message outwards.
    expected_restrict = match_service_pb2.Namespace(
        name="class",
        allow_tokens=["token_1"],
        deny_tokens=["token_2"],
    )
    expected_query = match_service_pb2.MatchRequest(
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        float_val=_TEST_QUERIES[0],
        restricts=[expected_restrict],
    )
    expected_per_index = match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        requests=[expected_query],
    )
    expected_batch = match_service_pb2.BatchMatchRequest(
        requests=[expected_per_index]
    )

    index_endpoint_match_queries_mock.assert_called_with(expected_batch)

0 commit comments

Comments
 (0)