 """Classes for working with language models."""

 import dataclasses
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Union
 import warnings

 from google.cloud import aiplatform
@@ -871,6 +871,54 @@ def predict_streaming(
             )
             yield _parse_text_generation_model_response(prediction_obj)

+    async def predict_streaming_async(
+        self,
+        prompt: str,
+        *,
+        max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously gets a streaming model response for a single prompt.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prompt: Question to ask the model.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = _create_text_generation_prediction_request(
+            prompt=prompt,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            stop_sequences=stop_sequences,
+        )
+
+        prediction_service_async_client = self._endpoint._prediction_async_client
+        async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
+            prediction_service_async_client=prediction_service_async_client,
+            endpoint_name=self._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield _parse_text_generation_model_response(prediction_obj)
+

 def _create_text_generation_prediction_request(
     prompt: str,
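For reviewers, a minimal consumption sketch of the new text-generation async API. The `text-bison@001` model name, the prompt, and the asyncio scaffolding are illustrative assumptions, not part of this diff; it also assumes `vertexai.init(...)` has already been called.

```python
import asyncio

from vertexai.language_models import TextGenerationModel


async def main() -> None:
    # Assumes vertexai.init(project=..., location=...) was called earlier.
    model = TextGenerationModel.from_pretrained("text-bison@001")  # placeholder model name
    async for chunk in model.predict_streaming_async(
        "Count to 50 by fives.",
        max_output_tokens=128,
        temperature=0.0,
    ):
        # Each chunk is a TextGenerationResponse holding a partial `.text`.
        print(chunk.text, end="")


asyncio.run(main())
```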
@@ -1928,6 +1976,75 @@ def send_message_streaming(
             ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
         )

+    async def send_message_streaming_async(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            stop_sequences=stop_sequences,
+        )
+
+        prediction_service_async_client = self._model._endpoint._prediction_async_client
+
+        full_response_text = ""
+
+        async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
+            prediction_service_async_client=prediction_service_async_client,
+            endpoint_name=self._model._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_response = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            text_generation_response = self._parse_chat_prediction_response(
+                prediction_response=prediction_response
+            )
+            full_response_text += text_generation_response.text
+            yield text_generation_response
+
+        # We only add the question and answer to the history if/when the answer
+        # was read fully. Otherwise, the answer would have been truncated.
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
+        )
+

 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
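A corresponding sketch for the chat path added to the session base class above. The `chat-bison@001` model name and project settings are assumptions for illustration; note that, per the comment in the code, the exchange lands in `message_history` only after the stream is fully drained.

```python
import asyncio

import vertexai
from vertexai.language_models import ChatModel


async def main() -> None:
    vertexai.init(project="my-project", location="us-central1")  # placeholder values
    chat = ChatModel.from_pretrained("chat-bison@001").start_chat()
    async for response in chat.send_message_streaming_async(
        "Suggest three names for a coffee shop.",
        temperature=0.2,
    ):
        print(response.text, end="")
    # Only now, with the stream fully consumed, is the Q/A pair recorded.
    print(f"\nhistory length: {len(chat.message_history)}")


asyncio.run(main())
```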
@@ -2073,6 +2190,38 @@ def send_message_streaming(
             stop_sequences=stop_sequences,
         )

+    def send_message_streaming_async(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        return super().send_message_streaming_async(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            stop_sequences=stop_sequences,
+        )
+

 class CodeGenerationModel(_LanguageModel):
     """A language model that generates code.
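One design note on the override above: it is a plain `def`, not `async def`. Calling the base class's async generator function does not require an `await`; it simply creates the async generator object, so the wrapper can narrow the keyword arguments and hand that object straight back while callers still consume it with `async for`. A standalone sketch of that delegation pattern (unrelated names, purely illustrative, not part of this diff):

```python
import asyncio
from typing import AsyncIterator


async def _numbers(limit: int) -> AsyncIterator[int]:
    # An async generator: calling it returns an async iterator immediately.
    for i in range(limit):
        yield i


def numbers_up_to_three() -> AsyncIterator[int]:
    # A plain function delegating to the async generator, mirroring the
    # send_message_streaming_async override above.
    return _numbers(3)


async def demo() -> None:
    async for n in numbers_up_to_three():
        print(n)


asyncio.run(demo())
```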
@@ -2255,6 +2404,51 @@ def predict_streaming(
             )
             yield _parse_text_generation_model_response(prediction_obj)

+    async def predict_streaming_async(
+        self,
+        prefix: str,
+        suffix: Optional[str] = None,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> AsyncIterator[TextGenerationResponse]:
+        """Asynchronously predicts the code based on previous code.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prefix: Code before the current point.
+            suffix: Code after the current point.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
+            temperature: Controls the randomness of predictions. Range: [0, 1].
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._create_prediction_request(
+            prefix=prefix,
+            suffix=suffix,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            stop_sequences=stop_sequences,
+        )
+
+        prediction_service_async_client = self._endpoint._prediction_async_client
+        async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
+            prediction_service_async_client=prediction_service_async_client,
+            endpoint_name=self._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield _parse_text_generation_model_response(prediction_obj)
+

 class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
     __name__ = "CodeGenerationModel"
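Finally, a sketch of the code-completion path. The `code-gecko@001` model name and the prefix snippet are illustrative assumptions; as with the other examples, `vertexai.init(...)` is assumed to have been called.

```python
import asyncio

from vertexai.language_models import CodeGenerationModel


async def main() -> None:
    model = CodeGenerationModel.from_pretrained("code-gecko@001")  # placeholder model name
    async for chunk in model.predict_streaming_async(
        prefix="def reverse_string(s: str) -> str:\n    return ",
        max_output_tokens=64,
        temperature=0.0,
    ):
        print(chunk.text, end="")


asyncio.run(main())
```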