Skip to content

Commit 0359f1d

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Support streaming prediction for code chat models
PiperOrigin-RevId: 558364254
1 parent 3a8348b commit 0359f1d

File tree

3 files changed

+85
-8
lines changed

3 files changed

+85
-8
lines changed

tests/system/aiplatform/test_language_models.py

+12-1
Original file line number | Diff line number | Diff line change
@@ -260,8 +260,19 @@ def test_code_generation_streaming(self):
260260

261261
for response in model.predict_streaming(
262262
prefix="def reverse_string(s):",
263-
suffix=" return s",
263+
# code-bison does not support suffix
264+
# suffix=" return s",
264265
max_output_tokens=128,
265266
temperature=0,
266267
):
267268
assert response.text
269+
270+
def test_code_chat_model_send_message_streaming(self):
271+
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
272+
273+
chat_model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
274+
chat = chat_model.start_chat()
275+
276+
message1 = "Please help write a function to calculate the max of two numbers"
277+
for response in chat.send_message_streaming(message1):
278+
assert response.text

tests/unit/aiplatform/test_language_models.py

+45
Original file line number | Diff line number | Diff line change
@@ -1938,6 +1938,51 @@ def test_code_chat(self):
19381938
assert prediction_parameters["temperature"] == message_temperature
19391939
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
19401940

1941+
def test_code_chat_model_send_message_streaming(self):
1942+
"""Tests the chat generation model."""
1943+
aiplatform.init(
1944+
project=_TEST_PROJECT,
1945+
location=_TEST_LOCATION,
1946+
)
1947+
with mock.patch.object(
1948+
target=model_garden_service_client.ModelGardenServiceClient,
1949+
attribute="get_publisher_model",
1950+
return_value=gca_publisher_model.PublisherModel(
1951+
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
1952+
),
1953+
):
1954+
model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
1955+
1956+
chat = model.start_chat(temperature=0.0)
1957+
1958+
# Using list instead of a generator so that it can be reused.
1959+
response_generator = [
1960+
gca_prediction_service.StreamingPredictResponse(
1961+
outputs=[_streaming_prediction.value_to_tensor(response_dict)]
1962+
)
1963+
for response_dict in _TEST_CHAT_PREDICTION_STREAMING
1964+
]
1965+
1966+
with mock.patch.object(
1967+
target=prediction_service_client.PredictionServiceClient,
1968+
attribute="server_streaming_predict",
1969+
return_value=response_generator,
1970+
):
1971+
message_text1 = (
1972+
"Please help write a function to calculate the max of two numbers"
1973+
)
1974+
# New messages are not added until the response is fully read
1975+
assert not chat.message_history
1976+
for response in chat.send_message_streaming(message_text1):
1977+
assert len(response.text) > 10
1978+
# New messages are only added after the response is fully read
1979+
assert chat.message_history
1980+
1981+
assert len(chat.message_history) == 2
1982+
assert chat.message_history[0].author == chat.USER_AUTHOR
1983+
assert chat.message_history[0].content == message_text1
1984+
assert chat.message_history[1].author == chat.MODEL_AUTHOR
1985+
19411986
def test_code_generation(self):
19421987
"""Tests code generation with the code generation model."""
19431988
aiplatform.init(

vertexai/language_models/_language_models.py

+28-7
Original file line number | Diff line number | Diff line change
@@ -59,13 +59,6 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
5959
return f"publishers/google/models/{model_name}@{version}"
6060

6161

62-
@dataclasses.dataclass
63-
class _PredictionRequest:
64-
"""A single-instance prediction request."""
65-
instance: Dict[str, Any]
66-
parameters: Optional[Dict[str, Any]] = None
67-
68-
6962
class _LanguageModel(_model_garden_models._ModelGardenModel):
7063
"""_LanguageModel is a base class for all language models."""
7164

@@ -1234,6 +1227,34 @@ def send_message(
12341227
temperature=temperature,
12351228
)
12361229

1230+
def send_message_streaming(
1231+
self,
1232+
message: str,
1233+
*,
1234+
max_output_tokens: Optional[int] = None,
1235+
temperature: Optional[float] = None,
1236+
) -> Iterator[TextGenerationResponse]:
1237+
"""Sends message to the language model and gets a streamed response.
1238+
1239+
The response is only added to the history once it's fully read.
1240+
1241+
Args:
1242+
message: Message to send to the model
1243+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
1244+
Uses the value specified when calling `ChatModel.start_chat` by default.
1245+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
1246+
Uses the value specified when calling `ChatModel.start_chat` by default.
1247+
1248+
Returns:
1249+
A stream of `TextGenerationResponse` objects that contain partial
1250+
responses produced by the model.
1251+
"""
1252+
return super().send_message_streaming(
1253+
message=message,
1254+
max_output_tokens=max_output_tokens,
1255+
temperature=temperature,
1256+
)
1257+
12371258

12381259
class CodeGenerationModel(_LanguageModel):
12391260
"""A language model that generates code.

0 commit comments

Comments
 (0)