Skip to content

Commit 66d84af

Browse files
Ark-kun authored and copybara-github committed
feat: GenAI - Switched the GA version of the generative_models classes to use the v1 service APIs instead of v1beta1
PiperOrigin-RevId: 675454152
1 parent f78b953 commit 66d84af

File tree

3 files changed

+329
-132
lines changed

3 files changed

+329
-132
lines changed

tests/unit/vertexai/test_generative_models.py

+90-100
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323

2424
import vertexai
2525
from google.cloud.aiplatform import initializer
26+
from google.cloud.aiplatform_v1 import types as types_v1
27+
from google.cloud.aiplatform_v1.services import (
28+
prediction_service as prediction_service_v1,
29+
)
30+
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
2631
from vertexai import generative_models
2732
from vertexai.preview import (
2833
generative_models as preview_generative_models,
@@ -326,6 +331,72 @@ def mock_stream_generate_content(
326331
yield blocked_chunk
327332

328333

334+
def mock_generate_content_v1(
335+
self,
336+
request: types_v1.GenerateContentRequest,
337+
*,
338+
model: Optional[str] = None,
339+
contents: Optional[MutableSequence[types_v1.Content]] = None,
340+
) -> types_v1.GenerateContentResponse:
341+
request_v1beta1 = types_v1beta1.GenerateContentRequest.deserialize(
342+
type(request).serialize(request)
343+
)
344+
response_v1beta1 = mock_generate_content(
345+
self=self,
346+
request=request_v1beta1,
347+
)
348+
response_v1 = types_v1.GenerateContentResponse.deserialize(
349+
type(response_v1beta1).serialize(response_v1beta1)
350+
)
351+
return response_v1
352+
353+
354+
def mock_stream_generate_content_v1(
355+
self,
356+
request: types_v1.GenerateContentRequest,
357+
*,
358+
model: Optional[str] = None,
359+
contents: Optional[MutableSequence[types_v1.Content]] = None,
360+
) -> Iterable[types_v1.GenerateContentResponse]:
361+
request_v1beta1 = types_v1beta1.GenerateContentRequest.deserialize(
362+
type(request).serialize(request)
363+
)
364+
for response_v1beta1 in mock_stream_generate_content(
365+
self=self,
366+
request=request_v1beta1,
367+
):
368+
response_v1 = types_v1.GenerateContentResponse.deserialize(
369+
type(response_v1beta1).serialize(response_v1beta1)
370+
)
371+
yield response_v1
372+
373+
374+
def patch_genai_services(func: callable):
375+
"""Patches GenAI services (v1 and v1beta1, streaming and non-streaming)."""
376+
377+
func = mock.patch.object(
378+
target=prediction_service.PredictionServiceClient,
379+
attribute="generate_content",
380+
new=mock_generate_content,
381+
)(func)
382+
func = mock.patch.object(
383+
target=prediction_service_v1.PredictionServiceClient,
384+
attribute="generate_content",
385+
new=mock_generate_content_v1,
386+
)(func)
387+
func = mock.patch.object(
388+
target=prediction_service.PredictionServiceClient,
389+
attribute="stream_generate_content",
390+
new=mock_stream_generate_content,
391+
)(func)
392+
func = mock.patch.object(
393+
target=prediction_service_v1.PredictionServiceClient,
394+
attribute="stream_generate_content",
395+
new=mock_stream_generate_content_v1,
396+
)(func)
397+
return func
398+
399+
329400
@pytest.fixture
330401
def mock_get_cached_content_fixture():
331402
"""Mocks GenAiCacheServiceClient.get_cached_content()."""
@@ -376,11 +447,6 @@ def setup_method(self):
376447
def teardown_method(self):
377448
initializer.global_pool.shutdown(wait=True)
378449

379-
@mock.patch.object(
380-
target=prediction_service.PredictionServiceClient,
381-
attribute="generate_content",
382-
new=mock_generate_content,
383-
)
384450
@pytest.mark.parametrize(
385451
"generative_models",
386452
[generative_models, preview_generative_models],
@@ -489,11 +555,7 @@ def test_generative_model_from_cached_content_with_resource_name(
489555
== "cached-content-id-in-from-cached-content-test"
490556
)
491557

492-
@mock.patch.object(
493-
target=prediction_service.PredictionServiceClient,
494-
attribute="generate_content",
495-
new=mock_generate_content,
496-
)
558+
@patch_genai_services
497559
@pytest.mark.parametrize(
498560
"generative_models",
499561
[generative_models, preview_generative_models],
@@ -601,11 +663,7 @@ def test_generate_content_with_cached_content(
601663

602664
assert response.text == "response to " + cached_content.resource_name
603665

604-
@mock.patch.object(
605-
target=prediction_service.PredictionServiceClient,
606-
attribute="stream_generate_content",
607-
new=mock_stream_generate_content,
608-
)
666+
@patch_genai_services
609667
@pytest.mark.parametrize(
610668
"generative_models",
611669
[generative_models, preview_generative_models],
@@ -616,11 +674,7 @@ def test_generate_content_streaming(self, generative_models: generative_models):
616674
for chunk in stream:
617675
assert chunk.text
618676

619-
@mock.patch.object(
620-
target=prediction_service.PredictionServiceClient,
621-
attribute="generate_content",
622-
new=mock_generate_content,
623-
)
677+
@patch_genai_services
624678
@pytest.mark.parametrize(
625679
"generative_models",
626680
[generative_models, preview_generative_models],
@@ -668,11 +722,7 @@ def test_generate_content_response_accessor_errors(
668722
assert e.match("no text")
669723
assert e.match("function_call")
670724

671-
@mock.patch.object(
672-
target=prediction_service.PredictionServiceClient,
673-
attribute="generate_content",
674-
new=mock_generate_content,
675-
)
725+
@patch_genai_services
676726
@pytest.mark.parametrize(
677727
"generative_models",
678728
[generative_models, preview_generative_models],
@@ -685,11 +735,7 @@ def test_chat_send_message(self, generative_models: generative_models):
685735
response2 = chat.send_message("Is sky blue on other planets?")
686736
assert response2.text
687737

688-
@mock.patch.object(
689-
target=prediction_service.PredictionServiceClient,
690-
attribute="stream_generate_content",
691-
new=mock_stream_generate_content,
692-
)
738+
@patch_genai_services
693739
@pytest.mark.parametrize(
694740
"generative_models",
695741
[generative_models, preview_generative_models],
@@ -704,11 +750,7 @@ def test_chat_send_message_streaming(self, generative_models: generative_models)
704750
for chunk in stream2:
705751
assert chunk.candidates
706752

707-
@mock.patch.object(
708-
target=prediction_service.PredictionServiceClient,
709-
attribute="generate_content",
710-
new=mock_generate_content,
711-
)
753+
@patch_genai_services
712754
@pytest.mark.parametrize(
713755
"generative_models",
714756
[generative_models, preview_generative_models],
@@ -727,11 +769,7 @@ def test_chat_send_message_response_validation_errors(
727769
# Checking that history did not get updated
728770
assert len(chat.history) == 2
729771

730-
@mock.patch.object(
731-
target=prediction_service.PredictionServiceClient,
732-
attribute="generate_content",
733-
new=mock_generate_content,
734-
)
772+
@patch_genai_services
735773
@pytest.mark.parametrize(
736774
"generative_models",
737775
[generative_models, preview_generative_models],
@@ -754,11 +792,7 @@ def test_chat_send_message_response_blocked_errors(
754792
# Checking that history did not get updated
755793
assert len(chat.history) == 2
756794

757-
@mock.patch.object(
758-
target=prediction_service.PredictionServiceClient,
759-
attribute="generate_content",
760-
new=mock_generate_content,
761-
)
795+
@patch_genai_services
762796
@pytest.mark.parametrize(
763797
"generative_models",
764798
[generative_models, preview_generative_models],
@@ -775,11 +809,7 @@ def test_chat_send_message_response_candidate_blocked_error(
775809
# Checking that history did not get updated
776810
assert not chat.history
777811

778-
@mock.patch.object(
779-
target=prediction_service.PredictionServiceClient,
780-
attribute="generate_content",
781-
new=mock_generate_content,
782-
)
812+
@patch_genai_services
783813
@pytest.mark.parametrize(
784814
"generative_models",
785815
[generative_models, preview_generative_models],
@@ -808,11 +838,7 @@ def test_finish_reason_max_tokens_in_generate_content_and_send_message(
808838
# Verify that history did not get updated
809839
assert not chat.history
810840

811-
@mock.patch.object(
812-
target=prediction_service.PredictionServiceClient,
813-
attribute="generate_content",
814-
new=mock_generate_content,
815-
)
841+
@patch_genai_services
816842
@pytest.mark.parametrize(
817843
"generative_models",
818844
[generative_models, preview_generative_models],
@@ -861,11 +887,7 @@ def test_chat_function_calling(self, generative_models: generative_models):
861887
assert "nice" in response2.text
862888
assert not response2.candidates[0].function_calls
863889

864-
@mock.patch.object(
865-
target=prediction_service.PredictionServiceClient,
866-
attribute="generate_content",
867-
new=mock_generate_content,
868-
)
890+
@patch_genai_services
869891
@pytest.mark.parametrize(
870892
"generative_models",
871893
[generative_models, preview_generative_models],
@@ -922,11 +944,7 @@ def test_chat_forced_function_calling(self, generative_models: generative_models
922944
assert "nice" in response2.text
923945
assert not response2.candidates[0].function_calls
924946

925-
@mock.patch.object(
926-
target=prediction_service.PredictionServiceClient,
927-
attribute="generate_content",
928-
new=mock_generate_content,
929-
)
947+
@patch_genai_services
930948
@pytest.mark.parametrize(
931949
"generative_models",
932950
[generative_models, preview_generative_models],
@@ -982,11 +1000,7 @@ def test_conversion_methods(self, generative_models: generative_models):
9821000
# Checking that the enums are serialized as strings, not integers.
9831001
assert response.to_dict()["candidates"][0]["finish_reason"] == "STOP"
9841002

985-
@mock.patch.object(
986-
target=prediction_service.PredictionServiceClient,
987-
attribute="generate_content",
988-
new=mock_generate_content,
989-
)
1003+
@patch_genai_services
9901004
def test_generate_content_grounding_google_search_retriever_preview(self):
9911005
model = preview_generative_models.GenerativeModel("gemini-pro")
9921006
google_search_retriever_tool = (
@@ -999,11 +1013,7 @@ def test_generate_content_grounding_google_search_retriever_preview(self):
9991013
)
10001014
assert response.text
10011015

1002-
@mock.patch.object(
1003-
target=prediction_service.PredictionServiceClient,
1004-
attribute="generate_content",
1005-
new=mock_generate_content,
1006-
)
1016+
@patch_genai_services
10071017
def test_generate_content_grounding_google_search_retriever(self):
10081018
model = generative_models.GenerativeModel("gemini-pro")
10091019
google_search_retriever_tool = (
@@ -1016,11 +1026,7 @@ def test_generate_content_grounding_google_search_retriever(self):
10161026
)
10171027
assert response.text
10181028

1019-
@mock.patch.object(
1020-
target=prediction_service.PredictionServiceClient,
1021-
attribute="generate_content",
1022-
new=mock_generate_content,
1023-
)
1029+
@patch_genai_services
10241030
def test_generate_content_grounding_vertex_ai_search_retriever(self):
10251031
model = preview_generative_models.GenerativeModel("gemini-pro")
10261032
vertex_ai_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
@@ -1035,11 +1041,7 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self):
10351041
)
10361042
assert response.text
10371043

1038-
@mock.patch.object(
1039-
target=prediction_service.PredictionServiceClient,
1040-
attribute="generate_content",
1041-
new=mock_generate_content,
1042-
)
1044+
@patch_genai_services
10431045
def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_location(
10441046
self,
10451047
):
@@ -1058,11 +1060,7 @@ def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_
10581060
)
10591061
assert response.text
10601062

1061-
@mock.patch.object(
1062-
target=prediction_service.PredictionServiceClient,
1063-
attribute="generate_content",
1064-
new=mock_generate_content,
1065-
)
1063+
@patch_genai_services
10661064
def test_generate_content_vertex_rag_retriever(self):
10671065
model = preview_generative_models.GenerativeModel("gemini-pro")
10681066
rag_resources = [
@@ -1085,11 +1083,7 @@ def test_generate_content_vertex_rag_retriever(self):
10851083
)
10861084
assert response.text
10871085

1088-
@mock.patch.object(
1089-
target=prediction_service.PredictionServiceClient,
1090-
attribute="generate_content",
1091-
new=mock_generate_content,
1092-
)
1086+
@patch_genai_services
10931087
def test_chat_automatic_function_calling_with_function_returning_dict(self):
10941088
generative_models = preview_generative_models
10951089
get_current_weather_func = generative_models.FunctionDeclaration.from_func(
@@ -1124,11 +1118,7 @@ def test_chat_automatic_function_calling_with_function_returning_dict(self):
11241118
chat2.send_message("What is the weather like in Boston?")
11251119
assert err.match("Exceeded the maximum")
11261120

1127-
@mock.patch.object(
1128-
target=prediction_service.PredictionServiceClient,
1129-
attribute="generate_content",
1130-
new=mock_generate_content,
1131-
)
1121+
@patch_genai_services
11321122
def test_chat_automatic_function_calling_with_function_returning_value(self):
11331123
# Define a new function that returns a value instead of a dict.
11341124
def get_current_weather(location: str):

tests/unit/vertexai/test_prompts.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
# TODO(b/360932655): Use mock_generate_content from test_generative_models
3737
from vertexai.preview import rag
3838
from vertexai.generative_models._generative_models import (
39-
prediction_service,
40-
gapic_prediction_service_types,
41-
gapic_content_types,
42-
gapic_tool_types,
39+
prediction_service_v1 as prediction_service,
40+
types_v1 as gapic_prediction_service_types,
41+
types_v1 as gapic_content_types,
42+
types_v1 as gapic_tool_types,
4343
)
4444

4545

0 commit comments

Comments (0)