23
23
24
24
import vertexai
25
25
from google .cloud .aiplatform import initializer
26
+ from google .cloud .aiplatform_v1 import types as types_v1
27
+ from google .cloud .aiplatform_v1 .services import (
28
+ prediction_service as prediction_service_v1 ,
29
+ )
30
+ from google .cloud .aiplatform_v1beta1 import types as types_v1beta1
26
31
from vertexai import generative_models
27
32
from vertexai .preview import (
28
33
generative_models as preview_generative_models ,
@@ -326,6 +331,72 @@ def mock_stream_generate_content(
326
331
yield blocked_chunk
327
332
328
333
334
def mock_generate_content_v1(
    self,
    request: types_v1.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[types_v1.Content]] = None,
) -> types_v1.GenerateContentResponse:
    """Mocks PredictionServiceClient.generate_content for the v1 API surface.

    Delegates to the v1beta1 mock by round-tripping the request and response
    through their wire (serialized proto) representation, which is shared
    between the v1 and v1beta1 message types.
    """
    # v1 -> v1beta1: serialize with the request's own type, deserialize as beta.
    beta_request = types_v1beta1.GenerateContentRequest.deserialize(
        type(request).serialize(request)
    )
    beta_response = mock_generate_content(self=self, request=beta_request)
    # v1beta1 -> v1: convert the mocked response back to the type callers expect.
    return types_v1.GenerateContentResponse.deserialize(
        type(beta_response).serialize(beta_response)
    )
352
+
353
+
354
def mock_stream_generate_content_v1(
    self,
    request: types_v1.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[types_v1.Content]] = None,
) -> Iterable[types_v1.GenerateContentResponse]:
    """Mocks PredictionServiceClient.stream_generate_content for the v1 API.

    Bridges to the v1beta1 streaming mock: the incoming v1 request is
    converted to its v1beta1 equivalent via proto serialization, and every
    streamed v1beta1 chunk is converted back to a v1 response before yielding.
    """
    beta_request = types_v1beta1.GenerateContentRequest.deserialize(
        type(request).serialize(request)
    )
    for beta_chunk in mock_stream_generate_content(self=self, request=beta_request):
        # Each chunk is round-tripped through the serialized form to change type.
        yield types_v1.GenerateContentResponse.deserialize(
            type(beta_chunk).serialize(beta_chunk)
        )
372
+
373
+
374
def patch_genai_services(func: callable):
    """Patches GenAI services (v1 and v1beta1, streaming and non-streaming).

    Decorator that replaces ``generate_content`` and ``stream_generate_content``
    on both the v1beta1 and v1 ``PredictionServiceClient`` classes with the
    local mocks, so a decorated test exercises all four entry points without
    touching the network.

    Args:
        func: The test function (or any callable) to wrap.

    Returns:
        ``func`` wrapped by the four ``mock.patch.object`` decorators.
    """
    # (client class, attribute name, mock replacement) for each patch target.
    # Expressed as data to avoid four near-identical patch invocations.
    patch_specs = (
        (
            prediction_service.PredictionServiceClient,
            "generate_content",
            mock_generate_content,
        ),
        (
            prediction_service_v1.PredictionServiceClient,
            "generate_content",
            mock_generate_content_v1,
        ),
        (
            prediction_service.PredictionServiceClient,
            "stream_generate_content",
            mock_stream_generate_content,
        ),
        (
            prediction_service_v1.PredictionServiceClient,
            "stream_generate_content",
            mock_stream_generate_content_v1,
        ),
    )
    for target, attribute, replacement in patch_specs:
        func = mock.patch.object(
            target=target,
            attribute=attribute,
            new=replacement,
        )(func)
    return func
398
+
399
+
329
400
@pytest .fixture
330
401
def mock_get_cached_content_fixture ():
331
402
"""Mocks GenAiCacheServiceClient.get_cached_content()."""
@@ -376,11 +447,6 @@ def setup_method(self):
376
447
def teardown_method (self ):
377
448
initializer .global_pool .shutdown (wait = True )
378
449
379
- @mock .patch .object (
380
- target = prediction_service .PredictionServiceClient ,
381
- attribute = "generate_content" ,
382
- new = mock_generate_content ,
383
- )
384
450
@pytest .mark .parametrize (
385
451
"generative_models" ,
386
452
[generative_models , preview_generative_models ],
@@ -489,11 +555,7 @@ def test_generative_model_from_cached_content_with_resource_name(
489
555
== "cached-content-id-in-from-cached-content-test"
490
556
)
491
557
492
- @mock .patch .object (
493
- target = prediction_service .PredictionServiceClient ,
494
- attribute = "generate_content" ,
495
- new = mock_generate_content ,
496
- )
558
+ @patch_genai_services
497
559
@pytest .mark .parametrize (
498
560
"generative_models" ,
499
561
[generative_models , preview_generative_models ],
@@ -601,11 +663,7 @@ def test_generate_content_with_cached_content(
601
663
602
664
assert response .text == "response to " + cached_content .resource_name
603
665
604
- @mock .patch .object (
605
- target = prediction_service .PredictionServiceClient ,
606
- attribute = "stream_generate_content" ,
607
- new = mock_stream_generate_content ,
608
- )
666
+ @patch_genai_services
609
667
@pytest .mark .parametrize (
610
668
"generative_models" ,
611
669
[generative_models , preview_generative_models ],
@@ -616,11 +674,7 @@ def test_generate_content_streaming(self, generative_models: generative_models):
616
674
for chunk in stream :
617
675
assert chunk .text
618
676
619
- @mock .patch .object (
620
- target = prediction_service .PredictionServiceClient ,
621
- attribute = "generate_content" ,
622
- new = mock_generate_content ,
623
- )
677
+ @patch_genai_services
624
678
@pytest .mark .parametrize (
625
679
"generative_models" ,
626
680
[generative_models , preview_generative_models ],
@@ -668,11 +722,7 @@ def test_generate_content_response_accessor_errors(
668
722
assert e .match ("no text" )
669
723
assert e .match ("function_call" )
670
724
671
- @mock .patch .object (
672
- target = prediction_service .PredictionServiceClient ,
673
- attribute = "generate_content" ,
674
- new = mock_generate_content ,
675
- )
725
+ @patch_genai_services
676
726
@pytest .mark .parametrize (
677
727
"generative_models" ,
678
728
[generative_models , preview_generative_models ],
@@ -685,11 +735,7 @@ def test_chat_send_message(self, generative_models: generative_models):
685
735
response2 = chat .send_message ("Is sky blue on other planets?" )
686
736
assert response2 .text
687
737
688
- @mock .patch .object (
689
- target = prediction_service .PredictionServiceClient ,
690
- attribute = "stream_generate_content" ,
691
- new = mock_stream_generate_content ,
692
- )
738
+ @patch_genai_services
693
739
@pytest .mark .parametrize (
694
740
"generative_models" ,
695
741
[generative_models , preview_generative_models ],
@@ -704,11 +750,7 @@ def test_chat_send_message_streaming(self, generative_models: generative_models)
704
750
for chunk in stream2 :
705
751
assert chunk .candidates
706
752
707
- @mock .patch .object (
708
- target = prediction_service .PredictionServiceClient ,
709
- attribute = "generate_content" ,
710
- new = mock_generate_content ,
711
- )
753
+ @patch_genai_services
712
754
@pytest .mark .parametrize (
713
755
"generative_models" ,
714
756
[generative_models , preview_generative_models ],
@@ -727,11 +769,7 @@ def test_chat_send_message_response_validation_errors(
727
769
# Checking that history did not get updated
728
770
assert len (chat .history ) == 2
729
771
730
- @mock .patch .object (
731
- target = prediction_service .PredictionServiceClient ,
732
- attribute = "generate_content" ,
733
- new = mock_generate_content ,
734
- )
772
+ @patch_genai_services
735
773
@pytest .mark .parametrize (
736
774
"generative_models" ,
737
775
[generative_models , preview_generative_models ],
@@ -754,11 +792,7 @@ def test_chat_send_message_response_blocked_errors(
754
792
# Checking that history did not get updated
755
793
assert len (chat .history ) == 2
756
794
757
- @mock .patch .object (
758
- target = prediction_service .PredictionServiceClient ,
759
- attribute = "generate_content" ,
760
- new = mock_generate_content ,
761
- )
795
+ @patch_genai_services
762
796
@pytest .mark .parametrize (
763
797
"generative_models" ,
764
798
[generative_models , preview_generative_models ],
@@ -775,11 +809,7 @@ def test_chat_send_message_response_candidate_blocked_error(
775
809
# Checking that history did not get updated
776
810
assert not chat .history
777
811
778
- @mock .patch .object (
779
- target = prediction_service .PredictionServiceClient ,
780
- attribute = "generate_content" ,
781
- new = mock_generate_content ,
782
- )
812
+ @patch_genai_services
783
813
@pytest .mark .parametrize (
784
814
"generative_models" ,
785
815
[generative_models , preview_generative_models ],
@@ -808,11 +838,7 @@ def test_finish_reason_max_tokens_in_generate_content_and_send_message(
808
838
# Verify that history did not get updated
809
839
assert not chat .history
810
840
811
- @mock .patch .object (
812
- target = prediction_service .PredictionServiceClient ,
813
- attribute = "generate_content" ,
814
- new = mock_generate_content ,
815
- )
841
+ @patch_genai_services
816
842
@pytest .mark .parametrize (
817
843
"generative_models" ,
818
844
[generative_models , preview_generative_models ],
@@ -861,11 +887,7 @@ def test_chat_function_calling(self, generative_models: generative_models):
861
887
assert "nice" in response2 .text
862
888
assert not response2 .candidates [0 ].function_calls
863
889
864
- @mock .patch .object (
865
- target = prediction_service .PredictionServiceClient ,
866
- attribute = "generate_content" ,
867
- new = mock_generate_content ,
868
- )
890
+ @patch_genai_services
869
891
@pytest .mark .parametrize (
870
892
"generative_models" ,
871
893
[generative_models , preview_generative_models ],
@@ -922,11 +944,7 @@ def test_chat_forced_function_calling(self, generative_models: generative_models
922
944
assert "nice" in response2 .text
923
945
assert not response2 .candidates [0 ].function_calls
924
946
925
- @mock .patch .object (
926
- target = prediction_service .PredictionServiceClient ,
927
- attribute = "generate_content" ,
928
- new = mock_generate_content ,
929
- )
947
+ @patch_genai_services
930
948
@pytest .mark .parametrize (
931
949
"generative_models" ,
932
950
[generative_models , preview_generative_models ],
@@ -982,11 +1000,7 @@ def test_conversion_methods(self, generative_models: generative_models):
982
1000
# Checking that the enums are serialized as strings, not integers.
983
1001
assert response .to_dict ()["candidates" ][0 ]["finish_reason" ] == "STOP"
984
1002
985
- @mock .patch .object (
986
- target = prediction_service .PredictionServiceClient ,
987
- attribute = "generate_content" ,
988
- new = mock_generate_content ,
989
- )
1003
+ @patch_genai_services
990
1004
def test_generate_content_grounding_google_search_retriever_preview (self ):
991
1005
model = preview_generative_models .GenerativeModel ("gemini-pro" )
992
1006
google_search_retriever_tool = (
@@ -999,11 +1013,7 @@ def test_generate_content_grounding_google_search_retriever_preview(self):
999
1013
)
1000
1014
assert response .text
1001
1015
1002
- @mock .patch .object (
1003
- target = prediction_service .PredictionServiceClient ,
1004
- attribute = "generate_content" ,
1005
- new = mock_generate_content ,
1006
- )
1016
+ @patch_genai_services
1007
1017
def test_generate_content_grounding_google_search_retriever (self ):
1008
1018
model = generative_models .GenerativeModel ("gemini-pro" )
1009
1019
google_search_retriever_tool = (
@@ -1016,11 +1026,7 @@ def test_generate_content_grounding_google_search_retriever(self):
1016
1026
)
1017
1027
assert response .text
1018
1028
1019
- @mock .patch .object (
1020
- target = prediction_service .PredictionServiceClient ,
1021
- attribute = "generate_content" ,
1022
- new = mock_generate_content ,
1023
- )
1029
+ @patch_genai_services
1024
1030
def test_generate_content_grounding_vertex_ai_search_retriever (self ):
1025
1031
model = preview_generative_models .GenerativeModel ("gemini-pro" )
1026
1032
vertex_ai_search_retriever_tool = preview_generative_models .Tool .from_retrieval (
@@ -1035,11 +1041,7 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self):
1035
1041
)
1036
1042
assert response .text
1037
1043
1038
- @mock .patch .object (
1039
- target = prediction_service .PredictionServiceClient ,
1040
- attribute = "generate_content" ,
1041
- new = mock_generate_content ,
1042
- )
1044
+ @patch_genai_services
1043
1045
def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_location (
1044
1046
self ,
1045
1047
):
@@ -1058,11 +1060,7 @@ def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_
1058
1060
)
1059
1061
assert response .text
1060
1062
1061
- @mock .patch .object (
1062
- target = prediction_service .PredictionServiceClient ,
1063
- attribute = "generate_content" ,
1064
- new = mock_generate_content ,
1065
- )
1063
+ @patch_genai_services
1066
1064
def test_generate_content_vertex_rag_retriever (self ):
1067
1065
model = preview_generative_models .GenerativeModel ("gemini-pro" )
1068
1066
rag_resources = [
@@ -1085,11 +1083,7 @@ def test_generate_content_vertex_rag_retriever(self):
1085
1083
)
1086
1084
assert response .text
1087
1085
1088
- @mock .patch .object (
1089
- target = prediction_service .PredictionServiceClient ,
1090
- attribute = "generate_content" ,
1091
- new = mock_generate_content ,
1092
- )
1086
+ @patch_genai_services
1093
1087
def test_chat_automatic_function_calling_with_function_returning_dict (self ):
1094
1088
generative_models = preview_generative_models
1095
1089
get_current_weather_func = generative_models .FunctionDeclaration .from_func (
@@ -1124,11 +1118,7 @@ def test_chat_automatic_function_calling_with_function_returning_dict(self):
1124
1118
chat2 .send_message ("What is the weather like in Boston?" )
1125
1119
assert err .match ("Exceeded the maximum" )
1126
1120
1127
- @mock .patch .object (
1128
- target = prediction_service .PredictionServiceClient ,
1129
- attribute = "generate_content" ,
1130
- new = mock_generate_content ,
1131
- )
1121
+ @patch_genai_services
1132
1122
def test_chat_automatic_function_calling_with_function_returning_value (self ):
1133
1123
# Define a new function that returns a value instead of a dict.
1134
1124
def get_current_weather (location : str ):
0 commit comments