@@ -88,6 +88,13 @@ def _model_resource_name(self) -> str:
         return self._endpoint.list_models()[0].model


+@dataclasses.dataclass
+class _PredictionRequest:
+    """A single-instance prediction request."""
+    instance: Dict[str, Any]
+    parameters: Optional[Dict[str, Any]] = None
+
+
 class _TunableModelMixin(_LanguageModel):
     """Model that can be tuned."""

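The new `_PredictionRequest` dataclass pairs a single prediction instance with its optional parameters, so request construction can be shared between the unary and streaming code paths added below. A minimal self-contained sketch of the pattern (the field values are illustrative, not taken from this commit):

```python
import dataclasses
from typing import Any, Dict, Optional


@dataclasses.dataclass
class _PredictionRequest:
    """A single-instance prediction request."""
    instance: Dict[str, Any]
    parameters: Optional[Dict[str, Any]] = None


# Illustrative values only; in the SDK the real instance and parameters
# are built by _prepare_request below.
request = _PredictionRequest(
    instance={"messages": [{"author": "user", "content": "Hello"}]},
    parameters={"temperature": 0.2},
)
print(request.instance, request.parameters)
```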
@@ -915,16 +922,16 @@ def message_history(self) -> List[ChatMessage]:
         """List of previous messages."""
         return self._message_history

-    def send_message(
+    def _prepare_request(
         self,
         message: str,
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-    ) -> "TextGenerationResponse":
-        """Sends message to the language model and gets a response.
+    ) -> _PredictionRequest:
+        """Prepares a request for the language model.

         Args:
             message: Message to send to the model
@@ -938,7 +945,7 @@ def send_message(
                 Uses the value specified when calling `ChatModel.start_chat` by default.

         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `_PredictionRequest` object.
         """
         prediction_parameters = {}

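After this change, `_prepare_request` only assembles the payload and no longer talks to the endpoint. Judging from the visible parts of the method, the returned value has roughly the shape below; the instance keys are an assumption modeled on the PaLM chat request format, since most of the method body is elided from this diff:

```python
# Hypothetical return value of _prepare_request for a chat session.
# Keys such as "context", "examples", and "messages" are assumptions;
# the code that populates prediction_instance is not shown in this diff.
_PredictionRequest(
    instance={
        "context": "You are a concise assistant.",
        "examples": [],
        "messages": [{"author": "user", "content": "Hello"}],
    },
    parameters={"temperature": 0.0, "topP": 0.95},
)
```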
@@ -986,27 +993,87 @@ def send_message(
                 for example in self._examples
             ]

-        prediction_response = self._model._endpoint.predict(
-            instances=[prediction_instance],
+        return _PredictionRequest(
+            instance=prediction_instance,
             parameters=prediction_parameters,
         )

-        prediction = prediction_response.predictions[0]
+    @classmethod
+    def _parse_chat_prediction_response(
+        cls,
+        prediction_response: aiplatform.models.Prediction,
+        prediction_idx: int = 0,
+        candidate_idx: int = 0,
+    ) -> TextGenerationResponse:
+        """Parses a prediction response for chat models.
+
+        Args:
+            prediction_response: Prediction response received from the model
+            prediction_idx: Index of the prediction to parse.
+            candidate_idx: Index of the candidate to parse.
+
+        Returns:
+            A `TextGenerationResponse` object.
+        """
+        prediction = prediction_response.predictions[prediction_idx]
         # ! Note: For chat models, the safetyAttributes is a list.
-        safety_attributes = prediction["safetyAttributes"][0]
-        response_obj = TextGenerationResponse(
-            text=prediction["candidates"][0]["content"]
+        safety_attributes = prediction["safetyAttributes"][candidate_idx]
+        return TextGenerationResponse(
+            text=prediction["candidates"][candidate_idx]["content"]
             if prediction.get("candidates")
             else None,
             _prediction_response=prediction_response,
             is_blocked=safety_attributes.get("blocked", False),
             safety_attributes=dict(
                 zip(
-                    safety_attributes.get("categories", []),
-                    safety_attributes.get("scores", []),
+                    # Unlike with non-streaming prediction, in streaming
+                    # prediction the categories and scores can be None.
+                    safety_attributes.get("categories") or [],
+                    safety_attributes.get("scores") or [],
                 )
             ),
         )
+
+    def send_message(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> "TextGenerationResponse":
+        """Sends a message to the language model and gets a response.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A `TextGenerationResponse` object that contains the text produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_response = self._model._endpoint.predict(
+            instances=[prediction_request.instance],
+            parameters=prediction_request.parameters,
+        )
+        response_obj = self._parse_chat_prediction_response(
+            prediction_response=prediction_response
+        )
         response_text = response_obj.text

         self._message_history.append(
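`send_message` is now a thin composition of the pieces above: `_prepare_request` builds the payload, the endpoint `predict` call executes it, and `_parse_chat_prediction_response` turns the raw prediction into a `TextGenerationResponse`. Calling it is unchanged for SDK users; a usage sketch, assuming the public `ChatModel` entry point and the `chat-bison@001` model name:

```python
from vertexai.preview.language_models import ChatModel

chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat(context="You are a helpful assistant.")

# Unary path: one request, one fully formed response.
response = chat.send_message("What does @dataclass do?", temperature=0.2)
print(response.text)
print(response.is_blocked, response.safety_attributes)
```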
@@ -1018,6 +1085,71 @@ def send_message(

         return response_obj

+    def send_message_streaming(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Sends a message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_service_client = self._model._endpoint._prediction_client
+
+        full_response_text = ""
+
+        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+            prediction_service_client=prediction_service_client,
+            endpoint_name=self._model._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_response = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            text_generation_response = self._parse_chat_prediction_response(
+                prediction_response=prediction_response
+            )
+            full_response_text += text_generation_response.text
+            yield text_generation_response
+
+        # We only add the question and answer to the history if/when the answer
+        # was read fully. Otherwise, the answer would have been truncated.
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
+        )
+

 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.