
Commit 94f7cd9

Ark-kun authored and copybara-github committed
feat: GenAI - Added the GenerativeModel.start_chat(response_validation: bool = True) parameter
The error messages are now more informative. The `raise_on_blocked` parameter is deprecated; use `response_validation` instead.

PiperOrigin-RevId: 607188491
1 parent 0c3e294 commit 94f7cd9

File tree: 4 files changed, +145 -36 lines changed

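A minimal usage sketch of the new parameter, based on the commit message and the diffs below (the project ID, model name, and prompts are illustrative, not taken from this commit):

import vertexai
from vertexai.generative_models import GenerativeModel, ResponseValidationError

vertexai.init(project="my-project", location="us-central1")  # illustrative project

model = GenerativeModel("gemini-pro")

# Default behavior: responses are validated before being added to history.
chat = model.start_chat()
try:
    response = chat.send_message("Why is the sky blue?")
    print(response.text)
except ResponseValidationError:
    # The failed request/response pair was not added to chat history,
    # so the session stays usable.
    pass

# Opting out: history accumulates even blocked or incomplete responses,
# which can leave the chat session in an unusable state.
permissive_chat = model.start_chat(response_validation=False)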

tests/unit/vertexai/test_generative_models.py (+40)

@@ -120,6 +120,23 @@ def mock_generate_content(
     model: Optional[str] = None,
     contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
 ) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
+    last_message_part = request.contents[-1].parts[0]
+    should_fail = last_message_part.text and "Please fail" in last_message_part.text
+    if should_fail:
+        response = gapic_prediction_service_types.GenerateContentResponse(
+            candidates=[
+                gapic_content_types.Candidate(
+                    finish_reason=gapic_content_types.Candidate.FinishReason.SAFETY,
+                    finish_message="Failed due to: " + last_message_part.text,
+                    safety_ratings=[
+                        gapic_content_types.SafetyRating(rating)
+                        for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
+                    ],
+                ),
+            ],
+        )
+        return response
+
     is_continued_chat = len(request.contents) > 1
     has_retrieval = any(
         tool.retrieval or tool.google_search_retrieval for tool in request.tools
@@ -281,6 +298,29 @@ def test_chat_send_message(self, generative_models: generative_models):
         response2 = chat.send_message("Is sky blue on other planets?")
         assert response2.text
 
+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="generate_content",
+        new=mock_generate_content,
+    )
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_send_message_response_validation_errors(
+        self, generative_models: generative_models
+    ):
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+        response1 = chat.send_message("Why is sky blue?")
+        assert response1.text
+        assert len(chat.history) == 2
+
+        with pytest.raises(generative_models.ResponseValidationError):
+            chat.send_message("Please fail!")
+        # Checking that history did not get updated
+        assert len(chat.history) == 2
+
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
         attribute="generate_content",

vertexai/generative_models/__init__.py (+2 -2)

@@ -29,7 +29,7 @@
     HarmBlockThreshold,
     Image,
     Part,
-    ResponseBlockedError,
+    ResponseValidationError,
     Tool,
 )
 
@@ -46,6 +46,6 @@
     "HarmBlockThreshold",
     "Image",
     "Part",
-    "ResponseBlockedError",
+    "ResponseValidationError",
     "Tool",
 ]

vertexai/generative_models/_generative_models.py (+96 -33)

@@ -45,6 +45,7 @@
 from vertexai.language_models import (
     _language_models as tunable_models,
 )
+import warnings
 
 try:
     from PIL import Image as PIL_Image  # pylint: disable=g-import-not-at-top
@@ -606,18 +607,28 @@ def start_chat(
         self,
         *,
         history: Optional[List["Content"]] = None,
+        response_validation: bool = True,
     ) -> "ChatSession":
         """Creates a stateful chat session.
 
         Args:
             history: Previous history to initialize the chat session.
+            response_validation: Whether to validate responses before adding
+                them to chat history. By default, `send_message` will raise
+                an error if the request or response is blocked or if the
+                response is incomplete due to going over the max token limit.
+                If set to `False`, the chat session history will always
+                accumulate the request and response messages even if the
+                response is blocked or incomplete. This can result in an
+                unusable chat session state.
 
         Returns:
             A ChatSession object.
         """
         return ChatSession(
             model=self,
             history=history,
+            response_validation=response_validation,
         )
 
 
@@ -628,6 +639,29 @@ def start_chat(
 ]
 
 
+def _validate_response(
+    response: "GenerationResponse",
+    request_contents: Optional[List["Content"]] = None,
+    response_chunks: Optional[List["GenerationResponse"]] = None,
+) -> None:
+    candidate = response.candidates[0]
+    if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
+        message = (
+            "The model response did not complete successfully.\n"
+            f"Finish reason: {candidate.finish_reason}.\n"
+            f"Finish message: {candidate.finish_message}.\n"
+            f"Safety ratings: {candidate.safety_ratings}.\n"
+            "To protect the integrity of the chat session, the request and response were not added to chat history.\n"
+            "To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
+            "Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
+        )
+        raise ResponseValidationError(
+            message=message,
+            request_contents=request_contents,
+            responses=response_chunks,
+        )
+
+
 class ChatSession:
     """Chat session holds the chat history."""
 
@@ -639,15 +673,15 @@ def __init__(
         model: _GenerativeModel,
         *,
         history: Optional[List["Content"]] = None,
-        raise_on_blocked: bool = True,
+        response_validation: bool = True,
     ):
         if history:
             if not all(isinstance(item, Content) for item in history):
                 raise ValueError("history must be a list of Content objects.")
 
         self._model = model
         self._history = history or []
-        self._raise_on_blocked = raise_on_blocked
+        self._response_validator = _validate_response if response_validation else None
 
     @property
     def history(self) -> List["Content"]:
@@ -784,13 +818,12 @@ def _send_message(
             tools=tools,
         )
         # By default we're not adding incomplete interactions to history.
-        if self._raise_on_blocked:
-            if response.candidates[0].finish_reason not in _SUCCESSFUL_FINISH_REASONS:
-                raise ResponseBlockedError(
-                    message="The response was blocked.",
-                    request_contents=request_history,
-                    responses=[response],
-                )
+        if self._response_validator is not None:
+            self._response_validator(
+                response=response,
+                request_contents=request_history,
+                response_chunks=[response],
+            )
 
         # Adding the request and the first response candidate to history
         response_message = response.candidates[0].content
@@ -841,13 +874,13 @@ async def _send_message_async(
             tools=tools,
         )
         # By default we're not adding incomplete interactions to history.
-        if self._raise_on_blocked:
-            if response.candidates[0].finish_reason not in _SUCCESSFUL_FINISH_REASONS:
-                raise ResponseBlockedError(
-                    message="The response was blocked.",
-                    request_contents=request_history,
-                    responses=[response],
-                )
+        if self._response_validator is not None:
+            self._response_validator(
+                response=response,
+                request_contents=request_history,
+                response_chunks=[response],
+            )
+
         # Adding the request and the first response candidate to history
         response_message = response.candidates[0].content
         # Response role is NOT set by the model.
@@ -905,13 +938,12 @@ def _send_message_streaming(
             else:
                 full_response = chunk
             # By default we're not adding incomplete interactions to history.
-            if self._raise_on_blocked:
-                if chunk.candidates[0].finish_reason not in _SUCCESSFUL_FINISH_REASONS:
-                    raise ResponseBlockedError(
-                        message="The response was blocked.",
-                        request_contents=request_history,
-                        responses=chunks,
-                    )
+            if self._response_validator is not None:
+                self._response_validator(
+                    response=chunk,
+                    request_contents=request_history,
+                    response_chunks=chunks,
+                )
             yield chunk
         if not full_response:
             return
@@ -973,16 +1005,13 @@ async def async_generator():
                 else:
                     full_response = chunk
                 # By default we're not adding incomplete interactions to history.
-                if self._raise_on_blocked:
-                    if (
-                        chunk.candidates[0].finish_reason
-                        not in _SUCCESSFUL_FINISH_REASONS
-                    ):
-                        raise ResponseBlockedError(
-                            message="The response was blocked.",
-                            request_contents=request_history,
-                            responses=chunks,
-                        )
+                if self._response_validator is not None:
+                    self._response_validator(
+                        response=chunk,
+                        request_contents=request_history,
+                        response_chunks=chunks,
+                    )
+
                 yield chunk
             if not full_response:
                 return
@@ -996,6 +1025,36 @@ async def async_generator():
         return async_generator()
 
 
+class _PreviewChatSession(ChatSession):
+    __doc__ = ChatSession.__doc__
+
+    # This class preserves backwards compatibility with the `raise_on_blocked` parameter.
+
+    def __init__(
+        self,
+        model: _GenerativeModel,
+        *,
+        history: Optional[List["Content"]] = None,
+        response_validation: bool = True,
+        # Deprecated
+        raise_on_blocked: Optional[bool] = None,
+    ):
+        if raise_on_blocked is not None:
+            warnings.warn(
+                message="Use `response_validation` instead of `raise_on_blocked`."
+            )
+            if response_validation is not None:
+                raise ValueError(
+                    "Cannot use `response_validation` when `raise_on_blocked` is set."
+                )
+            response_validation = raise_on_blocked
+        super().__init__(
+            model=model,
+            history=history,
+            response_validation=response_validation,
+        )
+
+
 class ResponseBlockedError(Exception):
     def __init__(
         self,
@@ -1008,6 +1067,10 @@ def __init__(
         self.responses = responses
 
 
+class ResponseValidationError(ResponseBlockedError):
+    pass
+
+
 ### Structures
 
 
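Since `ResponseValidationError` subclasses `ResponseBlockedError`, handlers written for the old exception keep working. A sketch of inspecting the new error, using only the attributes visible in the diff above (model name and prompt are illustrative):

from vertexai.generative_models import GenerativeModel, ResponseValidationError

model = GenerativeModel("gemini-pro")
chat = model.start_chat()  # response_validation=True by default

try:
    chat.send_message("A prompt that the service blocks")
except ResponseValidationError as error:
    # `responses` holds the offending response chunk(s), per the
    # `__init__` shown in the diff.
    for response in error.responses:
        print(response.candidates[0].finish_reason)
    # History was not updated, so it still contains only complete
    # request/response pairs.
    assert len(chat.history) % 2 == 0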

vertexai/preview/generative_models.py (+7 -1)

@@ -19,10 +19,10 @@
 from vertexai.generative_models._generative_models import (
     grounding,
     _PreviewGenerativeModel,
+    _PreviewChatSession,
     GenerationConfig,
     GenerationResponse,
     Candidate,
-    ChatSession,
     Content,
     FinishReason,
     FunctionDeclaration,
@@ -31,6 +31,7 @@
     Image,
     Part,
     ResponseBlockedError,
+    ResponseValidationError,
     Tool,
 )
 
@@ -39,6 +40,10 @@ class GenerativeModel(_PreviewGenerativeModel):
     __doc__ = _PreviewGenerativeModel.__doc__
 
 
+class ChatSession(_PreviewChatSession):
+    __doc__ = _PreviewChatSession.__doc__
+
+
 __all__ = [
     "grounding",
     "GenerationConfig",
@@ -54,5 +59,6 @@ class GenerativeModel(_PreviewGenerativeModel):
     "Image",
     "Part",
     "ResponseBlockedError",
+    "ResponseValidationError",
     "Tool",
 ]
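The preview namespace now exposes the backward-compatible `_PreviewChatSession` as `ChatSession`. A sketch of constructing it directly with the new parameter (model name and history contents are illustrative; the `Content` and `Part` constructors are assumed from the library's public API, not shown in this diff):

from vertexai.preview.generative_models import (
    ChatSession,
    Content,
    GenerativeModel,
    Part,
)

model = GenerativeModel("gemini-pro")  # illustrative model name
chat = ChatSession(
    model=model,
    history=[
        Content(role="user", parts=[Part.from_text("Hi")]),
        Content(role="model", parts=[Part.from_text("Hello! How can I help you?")]),
    ],
    response_validation=False,  # replaces the deprecated raise_on_blocked
)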
