45
45
from vertexai .language_models import (
46
46
_language_models as tunable_models ,
47
47
)
48
+ import warnings
48
49
49
50
try :
50
51
from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top
@@ -606,18 +607,28 @@ def start_chat(
606
607
self ,
607
608
* ,
608
609
history : Optional [List ["Content" ]] = None ,
610
+ response_validation : bool = True ,
609
611
) -> "ChatSession" :
610
612
"""Creates a stateful chat session.
611
613
612
614
Args:
613
615
history: Previous history to initialize the chat session.
616
+ response_validation: Whether to validate responses before adding
617
+ them to chat history. By default, `send_message` will raise
618
+ error if the request or response is blocked or if the response
619
+ is incomplete due to going over the max token limit.
620
+ If set to `False`, the chat session history will always
621
+ accumulate the request and response messages even if the
622
+ reponse if blocked or incomplete. This can result in an unusable
623
+ chat session state.
614
624
615
625
Returns:
616
626
A ChatSession object.
617
627
"""
618
628
return ChatSession (
619
629
model = self ,
620
630
history = history ,
631
+ response_validation = response_validation ,
621
632
)
622
633
623
634
@@ -628,6 +639,29 @@ def start_chat(
628
639
]
629
640
630
641
642
def _validate_response(
    response: "GenerationResponse",
    request_contents: Optional[List["Content"]] = None,
    response_chunks: Optional[List["GenerationResponse"]] = None,
) -> None:
    """Validates a model response before it is added to chat history.

    Args:
        response: The response (or final streamed chunk) to validate.
        request_contents: The request contents, attached to the raised
            error for debugging.
        response_chunks: All response chunks received so far, attached
            to the raised error for debugging.

    Raises:
        ResponseValidationError: If the first candidate's finish reason
            is not one of ``_SUCCESSFUL_FINISH_REASONS`` (e.g. the
            response was blocked or cut off at the max token limit).
    """
    candidate = response.candidates[0]
    if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
        # Fixed grammar in the user-facing message
        # ("did not completed" -> "did not complete").
        message = (
            "The model response did not complete successfully.\n"
            f"Finish reason: {candidate.finish_reason}.\n"
            f"Finish message: {candidate.finish_message}.\n"
            f"Safety ratings: {candidate.safety_ratings}.\n"
            "To protect the integrity of the chat session, the request and response were not added to chat history.\n"
            "To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
            "Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
        )
        raise ResponseValidationError(
            message=message,
            request_contents=request_contents,
            responses=response_chunks,
        )
631
665
class ChatSession :
632
666
"""Chat session holds the chat history."""
633
667
@@ -639,15 +673,15 @@ def __init__(
639
673
model : _GenerativeModel ,
640
674
* ,
641
675
history : Optional [List ["Content" ]] = None ,
642
- raise_on_blocked : bool = True ,
676
+ response_validation : bool = True ,
643
677
):
644
678
if history :
645
679
if not all (isinstance (item , Content ) for item in history ):
646
680
raise ValueError ("history must be a list of Content objects." )
647
681
648
682
self ._model = model
649
683
self ._history = history or []
650
- self ._raise_on_blocked = raise_on_blocked
684
+ self ._response_validator = _validate_response if response_validation else None
651
685
652
686
@property
653
687
def history (self ) -> List ["Content" ]:
@@ -784,13 +818,12 @@ def _send_message(
784
818
tools = tools ,
785
819
)
786
820
# By default we're not adding incomplete interactions to history.
787
- if self ._raise_on_blocked :
788
- if response .candidates [0 ].finish_reason not in _SUCCESSFUL_FINISH_REASONS :
789
- raise ResponseBlockedError (
790
- message = "The response was blocked." ,
791
- request_contents = request_history ,
792
- responses = [response ],
793
- )
821
+ if self ._response_validator is not None :
822
+ self ._response_validator (
823
+ response = response ,
824
+ request_contents = request_history ,
825
+ response_chunks = [response ],
826
+ )
794
827
795
828
# Adding the request and the first response candidate to history
796
829
response_message = response .candidates [0 ].content
@@ -841,13 +874,13 @@ async def _send_message_async(
841
874
tools = tools ,
842
875
)
843
876
# By default we're not adding incomplete interactions to history.
844
- if self ._raise_on_blocked :
845
- if response . candidates [ 0 ]. finish_reason not in _SUCCESSFUL_FINISH_REASONS :
846
- raise ResponseBlockedError (
847
- message = "The response was blocked." ,
848
- request_contents = request_history ,
849
- responses = [ response ],
850
- )
877
+ if self ._response_validator is not None :
878
+ self . _response_validator (
879
+ response = response ,
880
+ request_contents = request_history ,
881
+ response_chunks = [ response ] ,
882
+ )
883
+
851
884
# Adding the request and the first response candidate to history
852
885
response_message = response .candidates [0 ].content
853
886
# Response role is NOT set by the model.
@@ -905,13 +938,12 @@ def _send_message_streaming(
905
938
else :
906
939
full_response = chunk
907
940
# By default we're not adding incomplete interactions to history.
908
- if self ._raise_on_blocked :
909
- if chunk .candidates [0 ].finish_reason not in _SUCCESSFUL_FINISH_REASONS :
910
- raise ResponseBlockedError (
911
- message = "The response was blocked." ,
912
- request_contents = request_history ,
913
- responses = chunks ,
914
- )
941
+ if self ._response_validator is not None :
942
+ self ._response_validator (
943
+ response = chunk ,
944
+ request_contents = request_history ,
945
+ response_chunks = chunks ,
946
+ )
915
947
yield chunk
916
948
if not full_response :
917
949
return
@@ -973,16 +1005,13 @@ async def async_generator():
973
1005
else :
974
1006
full_response = chunk
975
1007
# By default we're not adding incomplete interactions to history.
976
- if self ._raise_on_blocked :
977
- if (
978
- chunk .candidates [0 ].finish_reason
979
- not in _SUCCESSFUL_FINISH_REASONS
980
- ):
981
- raise ResponseBlockedError (
982
- message = "The response was blocked." ,
983
- request_contents = request_history ,
984
- responses = chunks ,
985
- )
1008
+ if self ._response_validator is not None :
1009
+ self ._response_validator (
1010
+ response = chunk ,
1011
+ request_contents = request_history ,
1012
+ response_chunks = chunks ,
1013
+ )
1014
+
986
1015
yield chunk
987
1016
if not full_response :
988
1017
return
@@ -996,6 +1025,36 @@ async def async_generator():
996
1025
return async_generator ()
997
1026
998
1027
1028
class _PreviewChatSession(ChatSession):
    __doc__ = ChatSession.__doc__

    # This class preserves backwards compatibility with the deprecated
    # `raise_on_blocked` parameter by mapping it onto `response_validation`.

    def __init__(
        self,
        model: _GenerativeModel,
        *,
        history: Optional[List["Content"]] = None,
        # `None` means "not explicitly passed" and resolves to the
        # previous default of `True` below. The sentinel is required:
        # with the old default of `True`, the `is not None` conflict
        # check always fired, so passing the deprecated
        # `raise_on_blocked` unconditionally raised ValueError even when
        # the caller never set `response_validation`.
        response_validation: Optional[bool] = None,
        # Deprecated
        raise_on_blocked: Optional[bool] = None,
    ):
        if raise_on_blocked is not None:
            warnings.warn(
                message="Use `response_validation` instead of `raise_on_blocked`."
            )
            # Only an *explicit* `response_validation` conflicts with
            # the deprecated parameter.
            if response_validation is not None:
                raise ValueError(
                    "Cannot use `response_validation` when `raise_on_blocked` is set."
                )
            response_validation = raise_on_blocked
        if response_validation is None:
            response_validation = True
        super().__init__(
            model=model,
            history=history,
            response_validation=response_validation,
        )
999
1058
class ResponseBlockedError (Exception ):
1000
1059
def __init__ (
1001
1060
self ,
@@ -1008,6 +1067,10 @@ def __init__(
1008
1067
self .responses = responses
1009
1068
1010
1069
1070
class ResponseValidationError(ResponseBlockedError):
    """Raised when response validation rejects a model response.

    Emitted by the chat-session response validator when a response's
    finish reason indicates it was blocked or otherwise incomplete, so
    the request/response pair was not added to chat history. Subclasses
    ``ResponseBlockedError`` so existing callers catching that type
    keep working.
    """

    pass
1073
+
1011
1074
### Structures
1012
1075
1013
1076
0 commit comments