Skip to content

Commit 679646a

Browse files
lingyinw and copybara-github
authored and committed
feat: Add numeric_filter to MatchingEngineIndexEndpoint match() and find_neighbors() private endpoint queries.
PiperOrigin-RevId: 602661540
1 parent 512b82d commit 679646a

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+25
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ def find_neighbors(
12731273
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
12741274
approx_num_neighbors=approx_num_neighbors,
12751275
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1276+
numeric_filter=numeric_filter,
12761277
)
12771278

12781279
# Create the FindNeighbors request
@@ -1456,6 +1457,7 @@ def match(
14561457
approx_num_neighbors: Optional[int] = None,
14571458
fraction_leaf_nodes_to_search_override: Optional[float] = None,
14581459
low_level_batch_size: int = 0,
1460+
numeric_filter: Optional[List[NumericNamespace]] = None,
14591461
) -> List[List[MatchNeighbor]]:
14601462
"""Retrieves nearest neighbors for the given embedding queries on the
14611463
specified deployed index for private endpoint only.
@@ -1494,6 +1496,11 @@ def match(
14941496
This field is optional, defaults to 0 if not set. A non-positive
14951497
number disables low level batching, i.e. all queries are
14961498
executed sequentially.
1499+
numeric_filter (Optional[list[NumericNamespace]]):
1500+
Optional. A list of NumericNamespaces for filtering the matching
1501+
results. For example:
1502+
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
1503+
will match datapoints whose cost is greater than 5.
14971504
14981505
Returns:
14991506
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -1513,13 +1520,30 @@ def match(
15131520

15141521
# Preprocess restricts to be used for each request
15151522
restricts = []
1523+
# Token restricts
15161524
if filter:
15171525
for namespace in filter:
15181526
restrict = match_service_pb2.Namespace()
15191527
restrict.name = namespace.name
15201528
restrict.allow_tokens.extend(namespace.allow_tokens)
15211529
restrict.deny_tokens.extend(namespace.deny_tokens)
15221530
restricts.append(restrict)
1531+
numeric_restricts = []
1532+
# Numeric restricts
1533+
if numeric_filter:
1534+
for numeric_namespace in numeric_filter:
1535+
numeric_restrict = match_service_pb2.NumericNamespace()
1536+
numeric_restrict.name = numeric_namespace.name
1537+
numeric_restrict.op = match_service_pb2.NumericNamespace.Operator.Value(
1538+
numeric_namespace.op
1539+
)
1540+
if numeric_namespace.value_int is not None:
1541+
numeric_restrict.value_int = numeric_namespace.value_int
1542+
if numeric_namespace.value_float is not None:
1543+
numeric_restrict.value_float = numeric_namespace.value_float
1544+
if numeric_namespace.value_double is not None:
1545+
numeric_restrict.value_double = numeric_namespace.value_double
1546+
numeric_restricts.append(numeric_restrict)
15231547

15241548
requests = []
15251549
if queries:
@@ -1532,6 +1556,7 @@ def match(
15321556
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
15331557
approx_num_neighbors=approx_num_neighbors,
15341558
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1559+
numeric_restricts=numeric_restricts,
15351560
)
15361561
requests.append(request)
15371562
else:

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,13 @@
237237
]
238238
_TEST_NUMERIC_FILTER = [
239239
NumericNamespace(name="cost", value_double=0.3, op="EQUAL"),
240-
NumericNamespace(name="size", value_int=10, op="GREATER"),
241-
NumericNamespace(name="seconds", value_float=20.5, op="LESS_EQUAL"),
240+
NumericNamespace(name="size", value_int=0, op="GREATER"),
241+
NumericNamespace(name="seconds", value_float=-20.5, op="LESS_EQUAL"),
242+
]
243+
_TEST_NUMERIC_NAMESPACE = [
244+
match_service_pb2.NumericNamespace(name="cost", value_double=0.3, op=3),
245+
match_service_pb2.NumericNamespace(name="size", value_int=0, op=5),
246+
match_service_pb2.NumericNamespace(name="seconds", value_float=-20.5, op=2),
242247
]
243248
_TEST_IDS = ["123", "456", "789"]
244249
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
@@ -1045,7 +1050,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
10451050
index_endpoint_match_queries_mock.assert_called_with(batch_request)
10461051

10471052
@pytest.mark.usefixtures("get_index_endpoint_mock")
1048-
def test_private_service_access_index_endpoint_match_queries(
1053+
def test_private_service_access_service_access_index_endpoint_match_queries(
10491054
self, index_endpoint_match_queries_mock
10501055
):
10511056
aiplatform.init(project=_TEST_PROJECT)
@@ -1063,6 +1068,7 @@ def test_private_service_access_index_endpoint_match_queries(
10631068
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
10641069
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
10651070
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
1071+
numeric_filter=_TEST_NUMERIC_FILTER,
10661072
)
10671073

10681074
batch_request = match_service_pb2.BatchMatchRequest(
@@ -1085,6 +1091,7 @@ def test_private_service_access_index_endpoint_match_queries(
10851091
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
10861092
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
10871093
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1094+
numeric_restricts=_TEST_NUMERIC_NAMESPACE,
10881095
)
10891096
for i in range(len(_TEST_QUERIES))
10901097
],
@@ -1095,7 +1102,7 @@ def test_private_service_access_index_endpoint_match_queries(
10951102
index_endpoint_match_queries_mock.assert_called_with(batch_request)
10961103

10971104
@pytest.mark.usefixtures("get_index_endpoint_mock")
1098-
def test_private_index_endpoint_find_neighbor_queries(
1105+
def test_index_private_service_access_endpoint_find_neighbor_queries(
10991106
self, index_endpoint_match_queries_mock
11001107
):
11011108
aiplatform.init(project=_TEST_PROJECT)
@@ -1113,6 +1120,7 @@ def test_private_index_endpoint_find_neighbor_queries(
11131120
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
11141121
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
11151122
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
1123+
numeric_filter=_TEST_NUMERIC_FILTER,
11161124
)
11171125

11181126
batch_match_request = match_service_pb2.BatchMatchRequest(
@@ -1134,6 +1142,7 @@ def test_private_index_endpoint_find_neighbor_queries(
11341142
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
11351143
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
11361144
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1145+
numeric_restricts=_TEST_NUMERIC_NAMESPACE,
11371146
)
11381147
for test_query in _TEST_QUERIES
11391148
],
@@ -1331,10 +1340,10 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
13311340
namespace="cost", value_double=0.3, op="EQUAL"
13321341
),
13331342
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
1334-
namespace="size", value_int=10, op="GREATER"
1343+
namespace="size", value_int=0, op="GREATER"
13351344
),
13361345
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
1337-
namespace="seconds", value_float=20.5, op="LESS_EQUAL"
1346+
namespace="seconds", value_float=-20.5, op="LESS_EQUAL"
13381347
),
13391348
],
13401349
),

0 commit comments

Comments
 (0)