232
232
Namespace (name = "class" , allow_tokens = ["token_1" ], deny_tokens = ["token_2" ])
233
233
]
234
234
_TEST_IDS = ["123" , "456" , "789" ]
235
+ _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
236
+ _TEST_APPROX_NUM_NEIGHBORS = 2
235
237
236
238
237
239
def uuid_mock ():
@@ -853,6 +855,47 @@ def test_delete_index_endpoint_with_force(
853
855
name = _TEST_INDEX_ENDPOINT_NAME
854
856
)
855
857
858
+ @pytest .mark .usefixtures ("get_index_endpoint_mock" )
859
+ def test_index_endpoint_match_queries_backward_compatibility (
860
+ self , index_endpoint_match_queries_mock
861
+ ):
862
+ aiplatform .init (project = _TEST_PROJECT )
863
+
864
+ my_index_endpoint = aiplatform .MatchingEngineIndexEndpoint (
865
+ index_endpoint_name = _TEST_INDEX_ENDPOINT_ID
866
+ )
867
+
868
+ my_index_endpoint .match (
869
+ _TEST_DEPLOYED_INDEX_ID ,
870
+ _TEST_QUERIES ,
871
+ _TEST_NUM_NEIGHBOURS ,
872
+ _TEST_FILTER ,
873
+ )
874
+
875
+ batch_request = match_service_pb2 .BatchMatchRequest (
876
+ requests = [
877
+ match_service_pb2 .BatchMatchRequest .BatchMatchRequestPerIndex (
878
+ deployed_index_id = _TEST_DEPLOYED_INDEX_ID ,
879
+ requests = [
880
+ match_service_pb2 .MatchRequest (
881
+ num_neighbors = _TEST_NUM_NEIGHBOURS ,
882
+ deployed_index_id = _TEST_DEPLOYED_INDEX_ID ,
883
+ float_val = _TEST_QUERIES [0 ],
884
+ restricts = [
885
+ match_service_pb2 .Namespace (
886
+ name = "class" ,
887
+ allow_tokens = ["token_1" ],
888
+ deny_tokens = ["token_2" ],
889
+ )
890
+ ],
891
+ )
892
+ ],
893
+ )
894
+ ]
895
+ )
896
+
897
+ index_endpoint_match_queries_mock .assert_called_with (batch_request )
898
+
856
899
@pytest .mark .usefixtures ("get_index_endpoint_mock" )
857
900
def test_index_endpoint_match_queries (self , index_endpoint_match_queries_mock ):
858
901
aiplatform .init (project = _TEST_PROJECT )
@@ -866,6 +909,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
866
909
queries = _TEST_QUERIES ,
867
910
num_neighbors = _TEST_NUM_NEIGHBOURS ,
868
911
filter = _TEST_FILTER ,
912
+ per_crowding_attribute_num_neighbors = _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS ,
913
+ approx_num_neighbors = _TEST_APPROX_NUM_NEIGHBORS ,
869
914
)
870
915
871
916
batch_request = match_service_pb2 .BatchMatchRequest (
@@ -884,6 +929,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
884
929
deny_tokens = ["token_2" ],
885
930
)
886
931
],
932
+ per_crowding_attribute_num_neighbors = _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS ,
933
+ approx_num_neighbors = _TEST_APPROX_NUM_NEIGHBORS ,
887
934
)
888
935
],
889
936
)
0 commit comments