
Commit ce60cf7

Ark-kun authored and copybara-github committed
feat: LLM - Support streaming prediction for chat models
PiperOrigin-RevId: 558246099
1 parent fb527f3 commit ce60cf7
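
For orientation, here is a minimal usage sketch of the new streaming API, modeled on the system test added in this commit. The project, location, context string, and prompt are placeholders, and the import path assumes the public `vertexai.language_models` module touched below.

    from google.cloud import aiplatform
    from vertexai.language_models import ChatModel

    # Placeholder project/location; substitute your own.
    aiplatform.init(project="my-project", location="us-central1")

    chat_model = ChatModel.from_pretrained("chat-bison@001")
    chat = chat_model.start_chat(
        context="You are my personal assistant.",
        temperature=0.0,
    )

    # Each iteration yields a partial TextGenerationResponse; the chat history
    # is only updated after the stream has been fully read.
    for response in chat.send_message_streaming("Are my favorite movies based on a book series?"):
        print(response.text)

`send_message_streaming` accepts the same tuning parameters as `send_message` (max_output_tokens, temperature, top_k, top_p), as shown in the diffs below.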

File tree: 3 files changed, +289 -12 lines


tests/system/aiplatform/test_language_models.py

Lines changed: 38 additions & 0 deletions
@@ -100,6 +100,44 @@ def test_chat_on_chat_model(self):
         assert chat.message_history[2].content == message2
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_send_message_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+        chat = chat_model.start_chat(
+            context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
+            examples=[
+                InputOutputTextPair(
+                    input_text="Who do you work for?",
+                    output_text="I work for Ned.",
+                ),
+                InputOutputTextPair(
+                    input_text="What do I like?",
+                    output_text="Ned likes watching movies.",
+                ),
+            ],
+            temperature=0.0,
+        )
+
+        message1 = "Are my favorite movies based on a book series?"
+        for response in chat.send_message_streaming(message1):
+            assert response.text
+        assert len(chat.message_history) == 2
+        assert chat.message_history[0].author == chat.USER_AUTHOR
+        assert chat.message_history[0].content == message1
+        assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
+        message2 = "When were these books published?"
+        for response2 in chat.send_message_streaming(
+            message2,
+            temperature=0.1,
+        ):
+            assert response2.text
+        assert len(chat.message_history) == 4
+        assert chat.message_history[2].author == chat.USER_AUTHOR
+        assert chat.message_history[2].content == message2
+        assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
     def test_text_embedding(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

tests/unit/aiplatform/test_language_models.py

Lines changed: 107 additions & 0 deletions
@@ -228,6 +228,33 @@
     ],
 }
 
+_TEST_CHAT_PREDICTION_STREAMING = [
+    {
+        "candidates": [
+            {
+                "author": "1",
+                "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.",
+            }
+        ],
+        "safetyAttributes": [{"blocked": False, "categories": None, "scores": None}],
+    },
+    {
+        "candidates": [
+            {
+                "author": "1",
+                "content": " 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.",
+            }
+        ],
+        "safetyAttributes": [
+            {
+                "blocked": True,
+                "categories": ["Finance"],
+                "scores": [0.1],
+            }
+        ],
+    },
+]
+
 _TEST_CODE_GENERATION_PREDICTION = {
     "safetyAttributes": {
         "categories": [],
@@ -1735,6 +1762,86 @@ def test_chat_ga(self):
         assert prediction_parameters["topK"] == message_top_k
         assert prediction_parameters["topP"] == message_top_p
 
+    def test_chat_model_send_message_streaming(self):
+        """Tests the chat generation model."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        chat = model.start_chat(
+            context="""
+            My name is Ned.
+            You are my personal assistant.
+            My favorite movies are Lord of the Rings and Hobbit.
+            """,
+            examples=[
+                language_models.InputOutputTextPair(
+                    input_text="Who do you work for?",
+                    output_text="I work for Ned.",
+                ),
+                language_models.InputOutputTextPair(
+                    input_text="What do I like?",
+                    output_text="Ned likes watching movies.",
+                ),
+            ],
+            message_history=[
+                language_models.ChatMessage(
+                    author=preview_language_models.ChatSession.USER_AUTHOR,
+                    content="Question 1?",
+                ),
+                language_models.ChatMessage(
+                    author=preview_language_models.ChatSession.MODEL_AUTHOR,
+                    content="Answer 1.",
+                ),
+            ],
+            temperature=0.0,
+        )
+
+        # Using list instead of a generator so that it can be reused.
+        response_generator = [
+            gca_prediction_service.StreamingPredictResponse(
+                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+            )
+            for response_dict in _TEST_CHAT_PREDICTION_STREAMING
+        ]
+
+        message_temperature = 0.2
+        message_max_output_tokens = 200
+        message_top_k = 2
+        message_top_p = 0.2
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="server_streaming_predict",
+            return_value=response_generator,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+
+            for idx, response in enumerate(
+                chat.send_message_streaming(
+                    message=message_text1,
+                    max_output_tokens=message_max_output_tokens,
+                    temperature=message_temperature,
+                    top_k=message_top_k,
+                    top_p=message_top_p,
+                )
+            ):
+                assert len(response.text) > 10
+                # New messages are not added until the response is fully read
+                if idx + 1 < len(response_generator):
+                    assert len(chat.message_history) == 2
+
+            # New messages are only added after the response is fully read
+            assert len(chat.message_history) == 4
+            assert chat.message_history[2].author == chat.USER_AUTHOR
+            assert chat.message_history[2].content == message_text1
+            assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
     def test_code_chat(self):
         """Tests the code chat model."""
         aiplatform.init(

vertexai/language_models/_language_models.py

Lines changed: 144 additions & 12 deletions
@@ -88,6 +88,13 @@ def _model_resource_name(self) -> str:
         return self._endpoint.list_models()[0].model
 
 
+@dataclasses.dataclass
+class _PredictionRequest:
+    """A single-instance prediction request."""
+    instance: Dict[str, Any]
+    parameters: Optional[Dict[str, Any]] = None
+
+
 class _TunableModelMixin(_LanguageModel):
     """Model that can be tuned."""
@@ -915,16 +922,16 @@ def message_history(self) -> List[ChatMessage]:
         """List of previous messages."""
         return self._message_history
 
-    def send_message(
+    def _prepare_request(
         self,
         message: str,
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-    ) -> "TextGenerationResponse":
-        """Sends message to the language model and gets a response.
+    ) -> _PredictionRequest:
+        """Prepares a request for the language model.
 
         Args:
             message: Message to send to the model
@@ -938,7 +945,7 @@ def send_message(
             Uses the value specified when calling `ChatModel.start_chat` by default.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `_PredictionRequest` object.
         """
         prediction_parameters = {}
@@ -986,27 +993,87 @@ def send_message(
             for example in self._examples
         ]
 
-        prediction_response = self._model._endpoint.predict(
-            instances=[prediction_instance],
+        return _PredictionRequest(
+            instance=prediction_instance,
             parameters=prediction_parameters,
         )
 
-        prediction = prediction_response.predictions[0]
+    @classmethod
+    def _parse_chat_prediction_response(
+        cls,
+        prediction_response: aiplatform.models.Prediction,
+        prediction_idx: int = 0,
+        candidate_idx: int = 0,
+    ) -> TextGenerationResponse:
+        """Parses prediction response for chat models.
+
+        Args:
+            prediction_response: Prediction response received from the model
+            prediction_idx: Index of the prediction to parse.
+            candidate_idx: Index of the candidate to parse.
+
+        Returns:
+            A `TextGenerationResponse` object.
+        """
+        prediction = prediction_response.predictions[prediction_idx]
         # ! Note: For chat models, the safetyAttributes is a list.
-        safety_attributes = prediction["safetyAttributes"][0]
-        response_obj = TextGenerationResponse(
-            text=prediction["candidates"][0]["content"]
+        safety_attributes = prediction["safetyAttributes"][candidate_idx]
+        return TextGenerationResponse(
+            text=prediction["candidates"][candidate_idx]["content"]
             if prediction.get("candidates")
            else None,
            _prediction_response=prediction_response,
            is_blocked=safety_attributes.get("blocked", False),
            safety_attributes=dict(
                zip(
-                    safety_attributes.get("categories", []),
-                    safety_attributes.get("scores", []),
+                    # Unlike with normal prediction, in streaming prediction
+                    # categories and scores can be None
+                    safety_attributes.get("categories") or [],
+                    safety_attributes.get("scores") or [],
                )
            ),
        )
+
+    def send_message(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> "TextGenerationResponse":
+        """Sends message to the language model and gets a response.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A `TextGenerationResponse` object that contains the text produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_response = self._model._endpoint.predict(
+            instances=[prediction_request.instance],
+            parameters=prediction_request.parameters,
+        )
+        response_obj = self._parse_chat_prediction_response(
+            prediction_response=prediction_response
+        )
         response_text = response_obj.text
 
         self._message_history.append(
@@ -1018,6 +1085,71 @@ def send_message(
 
         return response_obj
 
+    def send_message_streaming(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_service_client = self._model._endpoint._prediction_client
+
+        full_response_text = ""
+
+        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+            prediction_service_client=prediction_service_client,
+            endpoint_name=self._model._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_response = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            text_generation_response = self._parse_chat_prediction_response(
+                prediction_response=prediction_response
+            )
+            full_response_text += text_generation_response.text
+            yield text_generation_response
+
+        # We only add the question and answer to the history if/when the answer
+        # was read fully. Otherwise, the answer would have been truncated.
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
+        )
+
 
 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.

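Since `send_message_streaming` yields ordinary `TextGenerationResponse` objects, a caller can inspect the per-chunk safety fields exposed by `_parse_chat_prediction_response` above. The following sketch assumes a `chat` session created via `ChatModel.start_chat()`; the prompt is a placeholder, and the accumulation mirrors what the method itself does before appending to `message_history`.

    # Sketch: accumulate streamed chunks and stop on a blocked chunk.
    full_text = ""
    for chunk in chat.send_message_streaming("Count from 1 to 30."):
        if chunk.is_blocked:
            # Streaming chunks carry their own safety attributes; categories and
            # scores may be empty, as the parsing code above accounts for.
            print("Blocked chunk:", chunk.safety_attributes)
            break
        full_text += chunk.text or ""
    print(full_text)

Note that, per the comment in the implementation, the question and answer are appended to `message_history` only once the stream has been read to the end; breaking out of the loop early leaves the history unchanged.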