
Commit 760a025

Ark-kun authored and copybara-github committed
feat: LLM - Added support for async streaming
PiperOrigin-RevId: 573094790
1 parent 7944348 commit 760a025

File tree: 3 files changed (+317, −2 lines)


google/cloud/aiplatform/_streaming_prediction.py

+87 −1
@@ -16,7 +16,7 @@
 #
 """Streaming prediction functions."""

-from typing import Any, Dict, Iterator, List, Optional, Sequence
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence

 from google.cloud.aiplatform_v1.services import prediction_service
 from google.cloud.aiplatform_v1.types import (
from google.cloud.aiplatform_v1.types import (
@@ -108,6 +108,34 @@ def predict_stream_of_tensor_lists_from_single_tensor_list(
         yield response.outputs


+async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
+    prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
+    endpoint_name: str,
+    tensor_list: List[aiplatform_types.Tensor],
+    parameters_tensor: Optional[aiplatform_types.Tensor] = None,
+) -> AsyncIterator[List[aiplatform_types.Tensor]]:
+    """Asynchronously predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
+
+    Args:
+        tensor_list: Model input as a list of `Tensor` objects.
+        parameters_tensor: Optional. Prediction parameters in `Tensor` form.
+        prediction_service_async_client: A PredictionServiceAsyncClient object.
+        endpoint_name: Resource name of Endpoint or PublisherModel.
+
+    Yields:
+        A generator of model prediction `Tensor` lists.
+    """
+    request = prediction_service_types.StreamingPredictRequest(
+        endpoint=endpoint_name,
+        inputs=tensor_list,
+        parameters=parameters_tensor,
+    )
+    async for response in prediction_service_async_client.server_streaming_predict(
+        request=request
+    ):
+        yield response.outputs
+
+
 def predict_stream_of_dict_lists_from_single_dict_list(
     prediction_service_client: prediction_service.PredictionServiceClient,
     endpoint_name: str,
@@ -136,6 +164,34 @@ def predict_stream_of_dict_lists_from_single_dict_list(
         yield [tensor_to_value(tensor._pb) for tensor in tensor_list]


+async def predict_stream_of_dict_lists_from_single_dict_list_async(
+    prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
+    endpoint_name: str,
+    dict_list: List[Dict[str, Any]],
+    parameters: Optional[Dict[str, Any]] = None,
+) -> AsyncIterator[List[Dict[str, Any]]]:
+    """Asynchronously predicts a stream of lists of dicts from a single list of dicts.
+
+    Args:
+        dict_list: Model input as a list of `dict` objects.
+        parameters: Optional. Prediction parameters in `dict` form.
+        prediction_service_async_client: A PredictionServiceAsyncClient object.
+        endpoint_name: Resource name of Endpoint or PublisherModel.
+
+    Yields:
+        A generator of model prediction dict lists.
+    """
+    tensor_list = [value_to_tensor(d) for d in dict_list]
+    parameters_tensor = value_to_tensor(parameters) if parameters else None
+    async for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list_async(
+        prediction_service_async_client=prediction_service_async_client,
+        endpoint_name=endpoint_name,
+        tensor_list=tensor_list,
+        parameters_tensor=parameters_tensor,
+    ):
+        yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
+
+
 def predict_stream_of_dicts_from_single_dict(
     prediction_service_client: prediction_service.PredictionServiceClient,
     endpoint_name: str,
@@ -164,3 +220,33 @@ def predict_stream_of_dicts_from_single_dict(
                 f"Expected to receive a single output, but got {dict_list}"
             )
         yield dict_list[0]
+
+
+async def predict_stream_of_dicts_from_single_dict_async(
+    prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
+    endpoint_name: str,
+    instance: Dict[str, Any],
+    parameters: Optional[Dict[str, Any]] = None,
+) -> AsyncIterator[Dict[str, Any]]:
+    """Asynchronously predicts a stream of dicts from a single instance dict.
+
+    Args:
+        instance: A single input instance `dict`.
+        parameters: Optional. Prediction parameters `dict`.
+        prediction_service_async_client: A PredictionServiceAsyncClient object.
+        endpoint_name: Resource name of Endpoint or PublisherModel.
+
+    Yields:
+        A generator of model prediction dicts.
+    """
+    async for dict_list in predict_stream_of_dict_lists_from_single_dict_list_async(
+        prediction_service_async_client=prediction_service_async_client,
+        endpoint_name=endpoint_name,
+        dict_list=[instance],
+        parameters=parameters,
+    ):
+        if len(dict_list) > 1:
+            raise ValueError(
+                f"Expected to receive a single output, but got {dict_list}"
+            )
+        yield dict_list[0]
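For orientation, a minimal sketch of how the new module-level helper could be driven directly, assuming an already-configured PredictionServiceAsyncClient. The project, location, model name, and the instance/parameter keys below are placeholders, not part of this commit; the higher-level predict_streaming_async methods added in _language_models.py are the intended entry points.

# Illustrative sketch; placeholder project/location/model, and the instance
# and parameter keys are model-specific examples, not fixed by this commit.
import asyncio

from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform_v1.services import prediction_service


async def main() -> None:
    client = prediction_service.PredictionServiceAsyncClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )
    endpoint_name = (  # placeholder publisher-model resource name
        "projects/my-project/locations/us-central1/"
        "publishers/google/models/text-bison@001"
    )
    async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
        prediction_service_async_client=client,
        endpoint_name=endpoint_name,
        instance={"content": "Count to 50"},
        parameters={"maxOutputTokens": 128},
    ):
        print(prediction_dict)


asyncio.run(main())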

tests/unit/aiplatform/test_language_models.py

+35
@@ -1465,6 +1465,41 @@ def test_text_generation_model_predict_streaming(self):
             ):
                 assert len(response.text) > 10

+    @pytest.mark.asyncio
+    async def test_text_generation_model_predict_streaming_async(self):
+        """Tests the TextGenerationModel.predict_streaming_async method."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        async def mock_server_streaming_predict_async(*args, **kwargs):
+            for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING:
+                yield gca_prediction_service.StreamingPredictResponse(
+                    outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+                )
+
+        with mock.patch.object(
+            target=prediction_service_async_client.PredictionServiceAsyncClient,
+            attribute="server_streaming_predict",
+            new=mock_server_streaming_predict_async,
+        ):
+            async for response in model.predict_streaming_async(
+                "Count to 50",
+                max_output_tokens=1000,
+                temperature=0.0,
+                top_p=1.0,
+                top_k=5,
+                stop_sequences=["# %%"],
+            ):
+                assert len(response.text) > 10
+
     def test_text_generation_response_repr(self):
         response = language_models.TextGenerationResponse(
             text="",
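The new test patches the async client's server_streaming_predict with an async generator function; calling an async generator function returns an async iterator, so it can stand in for the streaming RPC without any awaitable wrapper. A self-contained sketch of that pattern, with hypothetical names (FakeAsyncClient, fake_stream) standing in for the real client and fixtures:

# Illustrative sketch of the mocking pattern used above; FakeAsyncClient and
# fake_stream are hypothetical stand-ins, not part of this commit.
import asyncio
from unittest import mock


class FakeAsyncClient:
    async def server_streaming_predict(self, request=None):
        raise NotImplementedError  # the real method would perform a network call


async def fake_stream(*args, **kwargs):
    # An async generator: calling it returns an async iterator, so `async for`
    # works on the patched method just like on the real streaming RPC.
    for chunk in ("partial ", "response"):
        yield chunk


async def main() -> None:
    with mock.patch.object(FakeAsyncClient, "server_streaming_predict", new=fake_stream):
        client = FakeAsyncClient()
        async for chunk in client.server_streaming_predict(request={}):
            print(chunk, end="")


asyncio.run(main())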

vertexai/language_models/_language_models.py

+195 −1
@@ -15,7 +15,7 @@
 """Classes for working with language models."""

 import dataclasses
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Union
 import warnings

 from google.cloud import aiplatform
@@ -871,6 +871,54 @@ def predict_streaming(
             )
             yield _parse_text_generation_model_response(prediction_obj)

+    async def predict_streaming_async(
+        self,
+        prompt: str,
+        *,
+        max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously gets a streaming model response for a single prompt.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prompt: Question to ask the model.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = _create_text_generation_prediction_request(
+            prompt=prompt,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            stop_sequences=stop_sequences,
+        )
+
+        prediction_service_async_client = self._endpoint._prediction_async_client
+        async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
+            prediction_service_async_client=prediction_service_async_client,
+            endpoint_name=self._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield _parse_text_generation_model_response(prediction_obj)
+

 def _create_text_generation_prediction_request(
     prompt: str,
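A hedged usage sketch for the method added above, assuming the class is exposed as vertexai.language_models.TextGenerationModel and that the text-bison@001 publisher model is reachable from the caller's project; the vertexai.init values are placeholders.

# Illustrative usage; project/location are placeholders.
import asyncio

import vertexai
from vertexai.language_models import TextGenerationModel


async def main() -> None:
    vertexai.init(project="my-project", location="us-central1")  # placeholders
    model = TextGenerationModel.from_pretrained("text-bison@001")
    async for response in model.predict_streaming_async(
        "Count to 50",
        max_output_tokens=128,
        temperature=0.0,
    ):
        # Each item is a partial TextGenerationResponse.
        print(response.text, end="")


asyncio.run(main())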
@@ -1928,6 +1976,75 @@ def send_message_streaming(
             ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
         )

+    async def send_message_streaming_async(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            stop_sequences=stop_sequences,
+        )
+
+        prediction_service_async_client = self._model._endpoint._prediction_async_client
+
+        full_response_text = ""
+
+        async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
+            prediction_service_async_client=prediction_service_async_client,
+            endpoint_name=self._model._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_response = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            text_generation_response = self._parse_chat_prediction_response(
+                prediction_response=prediction_response
+            )
+            full_response_text += text_generation_response.text
+            yield text_generation_response
+
+        # We only add the question and answer to the history if/when the answer
+        # was read fully. Otherwise, the answer would have been truncated.
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
+        )
+

 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
@@ -2073,6 +2190,38 @@ def send_message_streaming(
             stop_sequences=stop_sequences,
         )

+    def send_message_streaming_async(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        return super().send_message_streaming_async(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            stop_sequences=stop_sequences,
+        )
+

 class CodeGenerationModel(_LanguageModel):
     """A language model that generates code.
@@ -2255,6 +2404,51 @@ def predict_streaming(
             )
             yield _parse_text_generation_model_response(prediction_obj)

+    async def predict_streaming_async(
+        self,
+        prefix: str,
+        suffix: Optional[str] = None,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously predicts the code based on previous code.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prefix: Code before the current point.
+            suffix: Code after the current point.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
+            temperature: Controls the randomness of predictions. Range: [0, 1].
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._create_prediction_request(
+            prefix=prefix,
+            suffix=suffix,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            stop_sequences=stop_sequences,
+        )
+
+        prediction_service_async_client = self._endpoint._prediction_async_client
+        async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
+            prediction_service_async_client=prediction_service_async_client,
+            endpoint_name=self._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield _parse_text_generation_model_response(prediction_obj)
+

 class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
     __name__ = "CodeGenerationModel"
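And a hedged sketch for the code-generation variant, assuming CodeGenerationModel is exposed from vertexai.language_models and code-bison@001 is an available model; the values passed to vertexai.init are placeholders.

# Illustrative usage; project/location and the model name are placeholders.
import asyncio

import vertexai
from vertexai.language_models import CodeGenerationModel


async def main() -> None:
    vertexai.init(project="my-project", location="us-central1")  # placeholders
    model = CodeGenerationModel.from_pretrained("code-bison@001")
    async for response in model.predict_streaming_async(
        prefix="def fibonacci(n):", max_output_tokens=256
    ):
        print(response.text, end="")


asyncio.run(main())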
