237
237
]
238
238
_TEST_NUMERIC_FILTER = [
239
239
NumericNamespace (name = "cost" , value_double = 0.3 , op = "EQUAL" ),
240
- NumericNamespace (name = "size" , value_int = 10 , op = "GREATER" ),
241
- NumericNamespace (name = "seconds" , value_float = 20.5 , op = "LESS_EQUAL" ),
240
+ NumericNamespace (name = "size" , value_int = 0 , op = "GREATER" ),
241
+ NumericNamespace (name = "seconds" , value_float = - 20.5 , op = "LESS_EQUAL" ),
242
+ ]
243
+ _TEST_NUMERIC_NAMESPACE = [
244
+ match_service_pb2 .NumericNamespace (name = "cost" , value_double = 0.3 , op = 3 ),
245
+ match_service_pb2 .NumericNamespace (name = "size" , value_int = 0 , op = 5 ),
246
+ match_service_pb2 .NumericNamespace (name = "seconds" , value_float = - 20.5 , op = 2 ),
242
247
]
243
248
_TEST_IDS = ["123" , "456" , "789" ]
244
249
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
@@ -1045,7 +1050,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
1045
1050
index_endpoint_match_queries_mock .assert_called_with (batch_request )
1046
1051
1047
1052
@pytest .mark .usefixtures ("get_index_endpoint_mock" )
1048
- def test_private_service_access_index_endpoint_match_queries (
1053
+ def test_private_service_access_service_access_index_endpoint_match_queries (
1049
1054
self , index_endpoint_match_queries_mock
1050
1055
):
1051
1056
aiplatform .init (project = _TEST_PROJECT )
@@ -1063,6 +1068,7 @@ def test_private_service_access_index_endpoint_match_queries(
1063
1068
approx_num_neighbors = _TEST_APPROX_NUM_NEIGHBORS ,
1064
1069
fraction_leaf_nodes_to_search_override = _TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE ,
1065
1070
low_level_batch_size = _TEST_LOW_LEVEL_BATCH_SIZE ,
1071
+ numeric_filter = _TEST_NUMERIC_FILTER ,
1066
1072
)
1067
1073
1068
1074
batch_request = match_service_pb2 .BatchMatchRequest (
@@ -1085,6 +1091,7 @@ def test_private_service_access_index_endpoint_match_queries(
1085
1091
per_crowding_attribute_num_neighbors = _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS ,
1086
1092
approx_num_neighbors = _TEST_APPROX_NUM_NEIGHBORS ,
1087
1093
fraction_leaf_nodes_to_search_override = _TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE ,
1094
+ numeric_restricts = _TEST_NUMERIC_NAMESPACE ,
1088
1095
)
1089
1096
for i in range (len (_TEST_QUERIES ))
1090
1097
],
@@ -1095,7 +1102,7 @@ def test_private_service_access_index_endpoint_match_queries(
1095
1102
index_endpoint_match_queries_mock .assert_called_with (batch_request )
1096
1103
1097
1104
@pytest .mark .usefixtures ("get_index_endpoint_mock" )
1098
- def test_private_index_endpoint_find_neighbor_queries (
1105
+ def test_index_private_service_access_endpoint_find_neighbor_queries (
1099
1106
self , index_endpoint_match_queries_mock
1100
1107
):
1101
1108
aiplatform .init (project = _TEST_PROJECT )
@@ -1113,6 +1120,7 @@ def test_private_index_endpoint_find_neighbor_queries(
1113
1120
approx_num_neighbors = _TEST_APPROX_NUM_NEIGHBORS ,
1114
1121
fraction_leaf_nodes_to_search_override = _TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE ,
1115
1122
return_full_datapoint = _TEST_RETURN_FULL_DATAPOINT ,
1123
+ numeric_filter = _TEST_NUMERIC_FILTER ,
1116
1124
)
1117
1125
1118
1126
batch_match_request = match_service_pb2 .BatchMatchRequest (
@@ -1134,6 +1142,7 @@ def test_private_index_endpoint_find_neighbor_queries(
1134
1142
per_crowding_attribute_num_neighbors = _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS ,
1135
1143
approx_num_neighbors = _TEST_APPROX_NUM_NEIGHBORS ,
1136
1144
fraction_leaf_nodes_to_search_override = _TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE ,
1145
+ numeric_restricts = _TEST_NUMERIC_NAMESPACE ,
1137
1146
)
1138
1147
for test_query in _TEST_QUERIES
1139
1148
],
@@ -1331,10 +1340,10 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
1331
1340
namespace = "cost" , value_double = 0.3 , op = "EQUAL"
1332
1341
),
1333
1342
gca_index_v1beta1 .IndexDatapoint .NumericRestriction (
1334
- namespace = "size" , value_int = 10 , op = "GREATER"
1343
+ namespace = "size" , value_int = 0 , op = "GREATER"
1335
1344
),
1336
1345
gca_index_v1beta1 .IndexDatapoint .NumericRestriction (
1337
- namespace = "seconds" , value_float = 20.5 , op = "LESS_EQUAL"
1346
+ namespace = "seconds" , value_float = - 20.5 , op = "LESS_EQUAL"
1338
1347
),
1339
1348
],
1340
1349
),
0 commit comments