15
15
# limitations under the License.
16
16
#
17
17
18
- from dataclasses import dataclass
18
+ from dataclasses import dataclass , field
19
19
from typing import Dict , List , Optional , Sequence , Tuple
20
20
21
21
from google .auth import credentials as auth_credentials
@@ -51,6 +51,25 @@ class MatchNeighbor:
51
51
distance : float
52
52
53
53
54
+ @dataclass
55
+ class Namespace :
56
+ """Namespace specifies the rules for determining the datapoints that are eligible for each matching query, overall query is an AND across namespaces.
57
+ Args:
58
+ name (str):
59
+ Required. The name of this Namespace.
60
+ allow_tokens (List(str)):
61
+ Optional. The allowed tokens in the namespace.
62
+ deny_tokens (List(str)):
63
+ Optional. The denied tokens in the namespace. When a token is denied, then matches will be excluded whenever the other datapoint has that token.
64
+ For example, if a query specifies [Namespace("color", ["red","blue"], ["purple"])], then that query will match datapoints that are red or blue,
65
+ but if those points are also purple, then they will be excluded even if they are red/blue.
66
+ """
67
+
68
+ name : str
69
+ allow_tokens : list = field (default_factory = list )
70
+ deny_tokens : list = field (default_factory = list )
71
+
72
+
54
73
class MatchingEngineIndexEndpoint (base .VertexAiResourceNounWithFutureManager ):
55
74
"""Matching Engine index endpoint resource for Vertex AI."""
56
75
@@ -796,7 +815,11 @@ def description(self) -> str:
796
815
return self ._gca_resource .description
797
816
798
817
def match (
799
- self , deployed_index_id : str , queries : List [List [float ]], num_neighbors : int = 1
818
+ self ,
819
+ deployed_index_id : str ,
820
+ queries : List [List [float ]],
821
+ num_neighbors : int = 1 ,
822
+ filter : Optional [List [Namespace ]] = [],
800
823
) -> List [List [MatchNeighbor ]]:
801
824
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
802
825
@@ -808,6 +831,11 @@ def match(
808
831
num_neighbors (int):
809
832
Required. The number of nearest neighbors to be retrieved from database for
810
833
each query.
834
+ filter (List[Namespace]):
835
+ Optional. A list of Namespaces for filtering the matching results.
836
+ For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
837
+ that satisfy "red color" but not include datapoints with "squared shape".
838
+ Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
811
839
812
840
Returns:
813
841
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -836,16 +864,22 @@ def match(
836
864
match_service_pb2 .BatchMatchRequest .BatchMatchRequestPerIndex ()
837
865
)
838
866
batch_request_for_index .deployed_index_id = deployed_index_id
839
- batch_request_for_index .requests .extend (
840
- [
841
- match_service_pb2 .MatchRequest (
842
- num_neighbors = num_neighbors ,
843
- deployed_index_id = deployed_index_id ,
844
- float_val = query ,
845
- )
846
- for query in queries
847
- ]
848
- )
867
+ requests = []
868
+ for query in queries :
869
+ request = match_service_pb2 .MatchRequest (
870
+ num_neighbors = num_neighbors ,
871
+ deployed_index_id = deployed_index_id ,
872
+ float_val = query ,
873
+ )
874
+ for namespace in filter :
875
+ restrict = match_service_pb2 .Namespace ()
876
+ restrict .name = namespace .name
877
+ restrict .allow_tokens .extend (namespace .allow_tokens )
878
+ restrict .deny_tokens .extend (namespace .deny_tokens )
879
+ request .restricts .append (restrict )
880
+ requests .append (request )
881
+
882
+ batch_request_for_index .requests .extend (requests )
849
883
batch_request .requests .append (batch_request_for_index )
850
884
851
885
# Perform the request
0 commit comments