Skip to content

Commit c5121da

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: use per-method mocks for gRPC service stubs for test_matching_engine_index_endpoint unit test
PiperOrigin-RevId: 696998319
1 parent 91f85ac commit c5121da

File tree

1 file changed

+42
-25
lines changed

1 file changed

+42
-25
lines changed

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
from google.cloud import aiplatform
2525
from google.cloud.aiplatform import base
2626
from google.cloud.aiplatform import initializer
27-
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
27+
from google.cloud.aiplatform.matching_engine._protos import (
28+
match_service_pb2,
29+
match_service_pb2_grpc,
30+
)
2831
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
2932
Namespace,
3033
NumericNamespace,
@@ -272,6 +275,9 @@
272275
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
273276
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
274277
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5"
278+
_TEST_PRIVATE_SERVICE_CONNECT_URI = "{}:10000".format(
279+
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
280+
)
275281
_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
276282
gca_index_v1beta1.IndexDatapoint(
277283
datapoint_id="1",
@@ -543,13 +549,20 @@ def create_index_endpoint_mock():
543549
yield create_index_endpoint_mock
544550

545551

552+
@pytest.fixture
553+
def grpc_insecure_channel_mock():
554+
with patch.object(grpc, "insecure_channel") as grpc_insecure_channel_mock:
555+
grpc_insecure_channel_mock.return_value = mock.Mock()
556+
yield grpc_insecure_channel_mock
557+
558+
546559
@pytest.fixture
547560
def index_endpoint_match_queries_mock():
548561
with patch.object(
549-
grpc._channel._UnaryUnaryMultiCallable,
550-
"__call__",
551-
) as index_endpoint_match_queries_mock:
552-
index_endpoint_match_queries_mock.return_value = (
562+
match_service_pb2_grpc, "MatchServiceStub"
563+
) as match_service_stub_mock:
564+
match_service_stub_mock = match_service_stub_mock.return_value
565+
match_service_stub_mock.BatchMatch.return_value = (
553566
match_service_pb2.BatchMatchResponse(
554567
responses=[
555568
match_service_pb2.BatchMatchResponse.BatchMatchResponsePerIndex(
@@ -595,16 +608,16 @@ def index_endpoint_match_queries_mock():
595608
]
596609
)
597610
)
598-
yield index_endpoint_match_queries_mock
611+
yield match_service_stub_mock
599612

600613

601614
@pytest.fixture
602615
def index_endpoint_batch_get_embeddings_mock():
603616
with patch.object(
604-
grpc._channel._UnaryUnaryMultiCallable,
605-
"__call__",
606-
) as index_endpoint_batch_get_embeddings_mock:
607-
index_endpoint_batch_get_embeddings_mock.return_value = (
617+
match_service_pb2_grpc, "MatchServiceStub"
618+
) as match_service_stub_mock:
619+
match_service_stub_mock = match_service_stub_mock.return_value
620+
match_service_stub_mock.BatchGetEmbeddings.return_value = (
608621
match_service_pb2.BatchGetEmbeddingsResponse(
609622
embeddings=[
610623
match_service_pb2.Embedding(
@@ -626,7 +639,7 @@ def index_endpoint_batch_get_embeddings_mock():
626639
]
627640
)
628641
)
629-
yield index_endpoint_batch_get_embeddings_mock
642+
yield match_service_stub_mock
630643

631644

632645
@pytest.fixture
@@ -1136,7 +1149,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
11361149
]
11371150
)
11381151

1139-
index_endpoint_match_queries_mock.assert_called_with(
1152+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
11401153
batch_request, metadata=mock.ANY
11411154
)
11421155

@@ -1203,7 +1216,7 @@ def test_private_service_access_hybrid_search_match_queries(
12031216
]
12041217
)
12051218

1206-
index_endpoint_match_queries_mock.assert_called_with(
1219+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
12071220
batch_request, metadata=mock.ANY
12081221
)
12091222

@@ -1257,7 +1270,7 @@ def test_private_service_access_index_endpoint_match_queries(
12571270
]
12581271
)
12591272

1260-
index_endpoint_match_queries_mock.assert_called_with(
1273+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
12611274
batch_request, metadata=mock.ANY
12621275
)
12631276

@@ -1312,7 +1325,7 @@ def test_private_service_access_index_endpoint_match_queries_with_jwt(
13121325
]
13131326
)
13141327

1315-
index_endpoint_match_queries_mock.assert_called_with(
1328+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
13161329
batch_request, metadata=_TEST_AUTHORIZATION_METADATA
13171330
)
13181331

@@ -1364,7 +1377,7 @@ def test_index_private_service_access_endpoint_find_neighbor_queries(
13641377
)
13651378
]
13661379
)
1367-
index_endpoint_match_queries_mock.assert_called_with(
1380+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
13681381
batch_match_request, metadata=mock.ANY
13691382
)
13701383

@@ -1417,13 +1430,13 @@ def test_index_private_service_access_endpoint_find_neighbor_queries_with_jwt(
14171430
)
14181431
]
14191432
)
1420-
index_endpoint_match_queries_mock.assert_called_with(
1433+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
14211434
batch_match_request, metadata=_TEST_AUTHORIZATION_METADATA
14221435
)
14231436

14241437
@pytest.mark.usefixtures("get_index_endpoint_mock")
14251438
def test_index_private_service_connect_endpoint_match_queries(
1426-
self, index_endpoint_match_queries_mock
1439+
self, grpc_insecure_channel_mock, index_endpoint_match_queries_mock
14271440
):
14281441
aiplatform.init(project=_TEST_PROJECT)
14291442

@@ -1467,10 +1480,12 @@ def test_index_private_service_connect_endpoint_match_queries(
14671480
]
14681481
)
14691482

1470-
index_endpoint_match_queries_mock.assert_called_with(
1483+
index_endpoint_match_queries_mock.BatchMatch.assert_called_with(
14711484
batch_request, metadata=mock.ANY
14721485
)
14731486

1487+
grpc_insecure_channel_mock.assert_called_with(_TEST_PRIVATE_SERVICE_CONNECT_URI)
1488+
14741489
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
14751490
def test_index_public_endpoint_find_neighbors_queries_backward_compatibility(
14761491
self, index_public_endpoint_match_queries_mock
@@ -1787,7 +1802,7 @@ def test_index_endpoint_batch_get_embeddings(
17871802
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
17881803
)
17891804

1790-
index_endpoint_batch_get_embeddings_mock.assert_called_with(
1805+
index_endpoint_batch_get_embeddings_mock.BatchGetEmbeddings.assert_called_with(
17911806
batch_request, metadata=mock.ANY
17921807
)
17931808

@@ -1809,7 +1824,7 @@ def test_index_endpoint_read_index_datapoints_for_private_service_access(
18091824
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
18101825
)
18111826

1812-
index_endpoint_batch_get_embeddings_mock.assert_called_with(
1827+
index_endpoint_batch_get_embeddings_mock.BatchGetEmbeddings.assert_called_with(
18131828
batch_request, metadata=mock.ANY
18141829
)
18151830

@@ -1835,23 +1850,23 @@ def test_index_endpoint_read_index_datapoints_for_private_service_access_with_jw
18351850
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
18361851
)
18371852

1838-
index_endpoint_batch_get_embeddings_mock.assert_called_with(
1853+
index_endpoint_batch_get_embeddings_mock.BatchGetEmbeddings.assert_called_with(
18391854
batch_request, metadata=_TEST_AUTHORIZATION_METADATA
18401855
)
18411856

18421857
assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE
18431858

18441859
@pytest.mark.usefixtures("get_index_endpoint_mock")
18451860
def test_index_endpoint_read_index_datapoints_for_private_service_connect(
1846-
self, index_endpoint_batch_get_embeddings_mock
1861+
self, grpc_insecure_channel_mock, index_endpoint_batch_get_embeddings_mock
18471862
):
18481863
aiplatform.init(project=_TEST_PROJECT)
18491864

18501865
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
18511866
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
18521867
)
18531868

1854-
my_index_endpoint.private_service_connect_ip = (
1869+
my_index_endpoint.private_service_connect_ip_address = (
18551870
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
18561871
)
18571872
response = my_index_endpoint.read_index_datapoints(
@@ -1863,10 +1878,12 @@ def test_index_endpoint_read_index_datapoints_for_private_service_connect(
18631878
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
18641879
)
18651880

1866-
index_endpoint_batch_get_embeddings_mock.assert_called_with(
1881+
index_endpoint_batch_get_embeddings_mock.BatchGetEmbeddings.assert_called_with(
18671882
batch_request, metadata=mock.ANY
18681883
)
18691884

1885+
grpc_insecure_channel_mock.assert_called_with(_TEST_PRIVATE_SERVICE_CONNECT_URI)
1886+
18701887
assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE
18711888

18721889

0 commit comments

Comments
 (0)