24
24
from google .cloud import aiplatform
25
25
from google .cloud .aiplatform import base
26
26
from google .cloud .aiplatform import initializer
27
- from google .cloud .aiplatform .matching_engine ._protos import match_service_pb2
27
+ from google .cloud .aiplatform .matching_engine ._protos import (
28
+ match_service_pb2 ,
29
+ match_service_pb2_grpc ,
30
+ )
28
31
from google .cloud .aiplatform .matching_engine .matching_engine_index_endpoint import (
29
32
Namespace ,
30
33
NumericNamespace ,
272
275
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
273
276
_TEST_PROJECT_ALLOWLIST = ["project-1" , "project-2" ]
274
277
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5"
278
+ _TEST_PRIVATE_SERVICE_CONNECT_URI = "{}:10000" .format (
279
+ _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
280
+ )
275
281
_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
276
282
gca_index_v1beta1 .IndexDatapoint (
277
283
datapoint_id = "1" ,
@@ -543,13 +549,20 @@ def create_index_endpoint_mock():
543
549
yield create_index_endpoint_mock
544
550
545
551
552
+ @pytest .fixture
553
+ def grpc_insecure_channel_mock ():
554
+ with patch .object (grpc , "insecure_channel" ) as grpc_insecure_channel_mock :
555
+ grpc_insecure_channel_mock .return_value = mock .Mock ()
556
+ yield grpc_insecure_channel_mock
557
+
558
+
546
559
@pytest .fixture
547
560
def index_endpoint_match_queries_mock ():
548
561
with patch .object (
549
- grpc . _channel . _UnaryUnaryMultiCallable ,
550
- "__call__" ,
551
- ) as index_endpoint_match_queries_mock :
552
- index_endpoint_match_queries_mock .return_value = (
562
+ match_service_pb2_grpc , "MatchServiceStub"
563
+ ) as match_service_stub_mock :
564
+ match_service_stub_mock = match_service_stub_mock . return_value
565
+ match_service_stub_mock . BatchMatch .return_value = (
553
566
match_service_pb2 .BatchMatchResponse (
554
567
responses = [
555
568
match_service_pb2 .BatchMatchResponse .BatchMatchResponsePerIndex (
@@ -595,16 +608,16 @@ def index_endpoint_match_queries_mock():
595
608
]
596
609
)
597
610
)
598
- yield index_endpoint_match_queries_mock
611
+ yield match_service_stub_mock
599
612
600
613
601
614
@pytest .fixture
602
615
def index_endpoint_batch_get_embeddings_mock ():
603
616
with patch .object (
604
- grpc . _channel . _UnaryUnaryMultiCallable ,
605
- "__call__" ,
606
- ) as index_endpoint_batch_get_embeddings_mock :
607
- index_endpoint_batch_get_embeddings_mock .return_value = (
617
+ match_service_pb2_grpc , "MatchServiceStub"
618
+ ) as match_service_stub_mock :
619
+ match_service_stub_mock = match_service_stub_mock . return_value
620
+ match_service_stub_mock . BatchGetEmbeddings .return_value = (
608
621
match_service_pb2 .BatchGetEmbeddingsResponse (
609
622
embeddings = [
610
623
match_service_pb2 .Embedding (
@@ -626,7 +639,7 @@ def index_endpoint_batch_get_embeddings_mock():
626
639
]
627
640
)
628
641
)
629
- yield index_endpoint_batch_get_embeddings_mock
642
+ yield match_service_stub_mock
630
643
631
644
632
645
@pytest .fixture
@@ -1136,7 +1149,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
1136
1149
]
1137
1150
)
1138
1151
1139
- index_endpoint_match_queries_mock .assert_called_with (
1152
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1140
1153
batch_request , metadata = mock .ANY
1141
1154
)
1142
1155
@@ -1203,7 +1216,7 @@ def test_private_service_access_hybrid_search_match_queries(
1203
1216
]
1204
1217
)
1205
1218
1206
- index_endpoint_match_queries_mock .assert_called_with (
1219
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1207
1220
batch_request , metadata = mock .ANY
1208
1221
)
1209
1222
@@ -1257,7 +1270,7 @@ def test_private_service_access_index_endpoint_match_queries(
1257
1270
]
1258
1271
)
1259
1272
1260
- index_endpoint_match_queries_mock .assert_called_with (
1273
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1261
1274
batch_request , metadata = mock .ANY
1262
1275
)
1263
1276
@@ -1312,7 +1325,7 @@ def test_private_service_access_index_endpoint_match_queries_with_jwt(
1312
1325
]
1313
1326
)
1314
1327
1315
- index_endpoint_match_queries_mock .assert_called_with (
1328
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1316
1329
batch_request , metadata = _TEST_AUTHORIZATION_METADATA
1317
1330
)
1318
1331
@@ -1364,7 +1377,7 @@ def test_index_private_service_access_endpoint_find_neighbor_queries(
1364
1377
)
1365
1378
]
1366
1379
)
1367
- index_endpoint_match_queries_mock .assert_called_with (
1380
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1368
1381
batch_match_request , metadata = mock .ANY
1369
1382
)
1370
1383
@@ -1417,13 +1430,13 @@ def test_index_private_service_access_endpoint_find_neighbor_queries_with_jwt(
1417
1430
)
1418
1431
]
1419
1432
)
1420
- index_endpoint_match_queries_mock .assert_called_with (
1433
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1421
1434
batch_match_request , metadata = _TEST_AUTHORIZATION_METADATA
1422
1435
)
1423
1436
1424
1437
@pytest .mark .usefixtures ("get_index_endpoint_mock" )
1425
1438
def test_index_private_service_connect_endpoint_match_queries (
1426
- self , index_endpoint_match_queries_mock
1439
+ self , grpc_insecure_channel_mock , index_endpoint_match_queries_mock
1427
1440
):
1428
1441
aiplatform .init (project = _TEST_PROJECT )
1429
1442
@@ -1467,10 +1480,12 @@ def test_index_private_service_connect_endpoint_match_queries(
1467
1480
]
1468
1481
)
1469
1482
1470
- index_endpoint_match_queries_mock .assert_called_with (
1483
+ index_endpoint_match_queries_mock .BatchMatch . assert_called_with (
1471
1484
batch_request , metadata = mock .ANY
1472
1485
)
1473
1486
1487
+ grpc_insecure_channel_mock .assert_called_with (_TEST_PRIVATE_SERVICE_CONNECT_URI )
1488
+
1474
1489
@pytest .mark .usefixtures ("get_index_public_endpoint_mock" )
1475
1490
def test_index_public_endpoint_find_neighbors_queries_backward_compatibility (
1476
1491
self , index_public_endpoint_match_queries_mock
@@ -1787,7 +1802,7 @@ def test_index_endpoint_batch_get_embeddings(
1787
1802
deployed_index_id = _TEST_DEPLOYED_INDEX_ID , id = ["1" , "2" ]
1788
1803
)
1789
1804
1790
- index_endpoint_batch_get_embeddings_mock .assert_called_with (
1805
+ index_endpoint_batch_get_embeddings_mock .BatchGetEmbeddings . assert_called_with (
1791
1806
batch_request , metadata = mock .ANY
1792
1807
)
1793
1808
@@ -1809,7 +1824,7 @@ def test_index_endpoint_read_index_datapoints_for_private_service_access(
1809
1824
deployed_index_id = _TEST_DEPLOYED_INDEX_ID , id = ["1" , "2" ]
1810
1825
)
1811
1826
1812
- index_endpoint_batch_get_embeddings_mock .assert_called_with (
1827
+ index_endpoint_batch_get_embeddings_mock .BatchGetEmbeddings . assert_called_with (
1813
1828
batch_request , metadata = mock .ANY
1814
1829
)
1815
1830
@@ -1835,23 +1850,23 @@ def test_index_endpoint_read_index_datapoints_for_private_service_access_with_jw
1835
1850
deployed_index_id = _TEST_DEPLOYED_INDEX_ID , id = ["1" , "2" ]
1836
1851
)
1837
1852
1838
- index_endpoint_batch_get_embeddings_mock .assert_called_with (
1853
+ index_endpoint_batch_get_embeddings_mock .BatchGetEmbeddings . assert_called_with (
1839
1854
batch_request , metadata = _TEST_AUTHORIZATION_METADATA
1840
1855
)
1841
1856
1842
1857
assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE
1843
1858
1844
1859
@pytest .mark .usefixtures ("get_index_endpoint_mock" )
1845
1860
def test_index_endpoint_read_index_datapoints_for_private_service_connect (
1846
- self , index_endpoint_batch_get_embeddings_mock
1861
+ self , grpc_insecure_channel_mock , index_endpoint_batch_get_embeddings_mock
1847
1862
):
1848
1863
aiplatform .init (project = _TEST_PROJECT )
1849
1864
1850
1865
my_index_endpoint = aiplatform .MatchingEngineIndexEndpoint (
1851
1866
index_endpoint_name = _TEST_INDEX_ENDPOINT_ID
1852
1867
)
1853
1868
1854
- my_index_endpoint .private_service_connect_ip = (
1869
+ my_index_endpoint .private_service_connect_ip_address = (
1855
1870
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
1856
1871
)
1857
1872
response = my_index_endpoint .read_index_datapoints (
@@ -1863,10 +1878,12 @@ def test_index_endpoint_read_index_datapoints_for_private_service_connect(
1863
1878
deployed_index_id = _TEST_DEPLOYED_INDEX_ID , id = ["1" , "2" ]
1864
1879
)
1865
1880
1866
- index_endpoint_batch_get_embeddings_mock .assert_called_with (
1881
+ index_endpoint_batch_get_embeddings_mock .BatchGetEmbeddings . assert_called_with (
1867
1882
batch_request , metadata = mock .ANY
1868
1883
)
1869
1884
1885
+ grpc_insecure_channel_mock .assert_called_with (_TEST_PRIVATE_SERVICE_CONNECT_URI )
1886
+
1870
1887
assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE
1871
1888
1872
1889
0 commit comments