
Commit fb527f3

Ark-kun authored and copybara-github committed
feat: LLM - Support streaming prediction for text generation models
PiperOrigin-RevId: 558068359
1 parent 8df5185 commit fb527f3

File tree

4 files changed: +313 −1 lines changed

google/cloud/aiplatform/_streaming_prediction.py
tests/system/aiplatform/test_language_models.py
tests/unit/aiplatform/test_language_models.py
vertexai/language_models/_language_models.py
google/cloud/aiplatform/_streaming_prediction.py (+166)

@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Streaming prediction functions."""
+
+from typing import Any, Dict, Iterator, List, Optional, Sequence
+
+from google.cloud.aiplatform_v1.services import prediction_service
+from google.cloud.aiplatform_v1.types import (
+    prediction_service as prediction_service_types,
+)
+from google.cloud.aiplatform_v1.types import (
+    types as aiplatform_types,
+)
+
+
+def value_to_tensor(value: Any) -> aiplatform_types.Tensor:
+    """Converts a Python value to `Tensor`.
+
+    Args:
+        value: A value to convert.
+
+    Returns:
+        A `Tensor` object.
+    """
+    if value is None:
+        return aiplatform_types.Tensor()
+    elif isinstance(value, bool):  # Must precede int: bool is a subclass of int.
+        return aiplatform_types.Tensor(bool_val=[value])
+    elif isinstance(value, int):
+        return aiplatform_types.Tensor(int_val=[value])
+    elif isinstance(value, float):
+        return aiplatform_types.Tensor(float_val=[value])
+    elif isinstance(value, str):
+        return aiplatform_types.Tensor(string_val=[value])
+    elif isinstance(value, bytes):
+        return aiplatform_types.Tensor(bytes_val=[value])
+    elif isinstance(value, list):
+        return aiplatform_types.Tensor(list_val=[value_to_tensor(x) for x in value])
+    elif isinstance(value, dict):
+        return aiplatform_types.Tensor(
+            struct_val={k: value_to_tensor(v) for k, v in value.items()}
+        )
+    raise TypeError(f"Unsupported value type {type(value)}")
+
+
+def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
+    """Converts `Tensor` to a Python value.
+
+    Args:
+        tensor_pb: A `Tensor` object.
+
+    Returns:
+        A corresponding Python object.
+    """
+    list_of_fields = tensor_pb.ListFields()
+    if not list_of_fields:
+        return None
+    descriptor, value = list_of_fields[0]
+    if descriptor.name == "list_val":
+        return [tensor_to_value(x) for x in value]
+    elif descriptor.name == "struct_val":
+        return {k: tensor_to_value(v) for k, v in value.items()}
+    if not isinstance(value, Sequence):
+        raise TypeError(f"Unexpected non-list tensor value {value}")
+    if len(value) == 1:
+        return value[0]
+    else:
+        return value
+
+
+def predict_stream_of_tensor_lists_from_single_tensor_list(
+    prediction_service_client: prediction_service.PredictionServiceClient,
+    endpoint_name: str,
+    tensor_list: List[aiplatform_types.Tensor],
+    parameters_tensor: Optional[aiplatform_types.Tensor] = None,
+) -> Iterator[List[aiplatform_types.Tensor]]:
+    """Predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
+
+    Args:
+        tensor_list: Model input as a list of `Tensor` objects.
+        parameters_tensor: Optional. Prediction parameters in `Tensor` form.
+        prediction_service_client: A PredictionServiceClient object.
+        endpoint_name: Resource name of Endpoint or PublisherModel.
+
+    Yields:
+        A generator of model prediction `Tensor` lists.
+    """
+    request = prediction_service_types.StreamingPredictRequest(
+        endpoint=endpoint_name,
+        inputs=tensor_list,
+        parameters=parameters_tensor,
+    )
+    for response in prediction_service_client.server_streaming_predict(request=request):
+        yield response.outputs
+
+
+def predict_stream_of_dict_lists_from_single_dict_list(
+    prediction_service_client: prediction_service.PredictionServiceClient,
+    endpoint_name: str,
+    dict_list: List[Dict[str, Any]],
+    parameters: Optional[Dict[str, Any]] = None,
+) -> Iterator[List[Dict[str, Any]]]:
+    """Predicts a stream of lists of dicts from a single list of dicts.
+
+    Args:
+        dict_list: Model input as a list of `dict` objects.
+        parameters: Optional. Prediction parameters in `dict` form.
+        prediction_service_client: A PredictionServiceClient object.
+        endpoint_name: Resource name of Endpoint or PublisherModel.
+
+    Yields:
+        A generator of model prediction dict lists.
+    """
+    tensor_list = [value_to_tensor(d) for d in dict_list]
+    parameters_tensor = value_to_tensor(parameters) if parameters else None
+    for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list(
+        prediction_service_client=prediction_service_client,
+        endpoint_name=endpoint_name,
+        tensor_list=tensor_list,
+        parameters_tensor=parameters_tensor,
+    ):
+        yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
+
+
+def predict_stream_of_dicts_from_single_dict(
+    prediction_service_client: prediction_service.PredictionServiceClient,
+    endpoint_name: str,
+    instance: Dict[str, Any],
+    parameters: Optional[Dict[str, Any]] = None,
+) -> Iterator[Dict[str, Any]]:
+    """Predicts a stream of dicts from a single instance dict.
+
+    Args:
+        instance: A single input instance `dict`.
+        parameters: Optional. Prediction parameters `dict`.
+        prediction_service_client: A PredictionServiceClient object.
+        endpoint_name: Resource name of Endpoint or PublisherModel.
+
+    Yields:
+        A generator of model prediction dicts.
+    """
+    for dict_list in predict_stream_of_dict_lists_from_single_dict_list(
+        prediction_service_client=prediction_service_client,
+        endpoint_name=endpoint_name,
+        dict_list=[instance],
+        parameters=parameters,
+    ):
+        if len(dict_list) > 1:
+            raise ValueError(
+                f"Expected to receive a single output, but got {dict_list}"
+            )
+        yield dict_list[0]
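
Note: `value_to_tensor` and `tensor_to_value` act as inverses for plain JSON-like values, which is what the dict-based wrappers rely on. A minimal round-trip sketch (the import path matches the unit tests below; beware that `tensor_to_value` unwraps single-element lists into scalars and `float_val` is 32-bit, so not every value round-trips exactly):

from google.cloud.aiplatform import _streaming_prediction

instance = {"prompt": "Count to 50", "maxDecodeSteps": 128, "stopSequences": [".", "END"]}
tensor = _streaming_prediction.value_to_tensor(instance)  # Tensor with a struct_val field
decoded = _streaming_prediction.tensor_to_value(tensor._pb)  # back to plain Python values
assert decoded == instance  # exact here: ints, strings, and multi-element lists round-trip losslessly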

tests/system/aiplatform/test_language_models.py (+14)

@@ -48,6 +48,20 @@ def test_text_generation(self):
             top_k=5,
         ).text

+    def test_text_generation_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextGenerationModel.from_pretrained("google/text-bison@001")
+
+        for response in model.predict_streaming(
+            "What is the best recipe for banana bread? Recipe:",
+            max_output_tokens=128,
+            temperature=0,
+            top_p=1,
+            top_k=5,
+        ):
+            assert response.text
+
     def test_chat_on_chat_model(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

tests/unit/aiplatform/test_language_models.py (+67)

@@ -28,6 +28,7 @@
 from google.cloud import storage

 from google.cloud import aiplatform
+from google.cloud.aiplatform import _streaming_prediction
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform.utils import gcs_utils

@@ -168,6 +169,34 @@
 1. Preheat oven to 350 degrees F (175 degrees C).""",
 }

+_TEST_TEXT_GENERATION_PREDICTION_STREAMING = [
+    {
+        "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.",
+    },
+    {
+        "content": " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.",
+        "safetyAttributes": {"blocked": False, "categories": None, "scores": None},
+    },
+    {
+        "content": " 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45.",
+        "citationMetadata": {
+            "citations": [
+                {
+                    "title": "THEATRUM ARITHMETICO-GEOMETRICUM",
+                    "publicationDate": "1727",
+                    "endIndex": 181,
+                    "startIndex": 12,
+                }
+            ]
+        },
+        "safetyAttributes": {
+            "blocked": True,
+            "categories": ["Finance"],
+            "scores": [0.1],
+        },
+    },
+]
+
 _TEST_CHAT_GENERATION_PREDICTION1 = {
     "safetyAttributes": [
         {

@@ -1040,6 +1069,10 @@ class TestLanguageModels:
     def setup_method(self):
         reload(initializer)
         reload(aiplatform)
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )

     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)

@@ -1165,6 +1198,40 @@ def test_text_generation_ga(self):
         assert "topP" not in prediction_parameters
         assert "topK" not in prediction_parameters

+    def test_text_generation_model_predict_streaming(self):
+        """Tests the TextGenerationModel.predict_streaming method."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        response_generator = (
+            gca_prediction_service.StreamingPredictResponse(
+                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+            )
+            for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="server_streaming_predict",
+            return_value=response_generator,
+        ):
+            for response in model.predict_streaming(
+                "Count to 50",
+                max_output_tokens=1000,
+                temperature=0,
+                top_p=1,
+                top_k=5,
+            ):
+                assert len(response.text) > 10
+
     @pytest.mark.parametrize(
         "job_spec",
         [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
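
Because the fake stream above is built with `value_to_tensor`, the unit test exercises the real decoding path end to end. The same trick works without `mock`: a minimal sketch using a hypothetical stub client in place of `PredictionServiceClient` (only the import paths and helper names come from this commit):

from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform_v1.types import prediction_service as prediction_service_types


class StubPredictionClient:
    # Hypothetical stand-in: server_streaming_predict yields canned responses.
    def server_streaming_predict(self, request):
        for text in ("Hello", " world"):
            yield prediction_service_types.StreamingPredictResponse(
                outputs=[_streaming_prediction.value_to_tensor({"content": text})]
            )


for prediction in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
    prediction_service_client=StubPredictionClient(),
    endpoint_name="projects/p/locations/l/publishers/google/models/text-bison@001",
    instance={"prompt": "Say hello"},
):
    print(prediction["content"])  # "Hello", then " world"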

vertexai/language_models/_language_models.py (+66 −1)

@@ -15,10 +15,11 @@
 """Classes for working with language models."""

 import dataclasses
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 import warnings

 from google.cloud import aiplatform
+from google.cloud.aiplatform import _streaming_prediction
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer as aiplatform_initializer
 from google.cloud.aiplatform import utils as aiplatform_utils

@@ -389,6 +390,70 @@ def _batch_predict(
         )
         return results

+    def predict_streaming(
+        self,
+        prompt: str,
+        *,
+        max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Gets a streaming model response for a single prompt.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prompt: Question to ask the model.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability threshold of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_service_client = self._endpoint._prediction_client
+        # Note: "prompt", not "content" like in the non-streaming case. b/294462691
+        instance = {"prompt": prompt}
+        prediction_parameters = {}
+
+        if max_output_tokens:
+            prediction_parameters["maxDecodeSteps"] = max_output_tokens
+
+        if temperature is not None:
+            prediction_parameters["temperature"] = temperature
+
+        if top_p:
+            prediction_parameters["topP"] = top_p
+
+        if top_k:
+            prediction_parameters["topK"] = top_k
+
+        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+            prediction_service_client=prediction_service_client,
+            endpoint_name=self._endpoint_name,
+            instance=instance,
+            parameters=prediction_parameters,
+        ):
+            safety_attributes_dict = prediction_dict.get("safetyAttributes", {})
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield TextGenerationResponse(
+                text=prediction_dict["content"],
+                _prediction_response=prediction_obj,
+                is_blocked=safety_attributes_dict.get("blocked", False),
+                safety_attributes=dict(
+                    zip(
+                        safety_attributes_dict.get("categories") or [],
+                        safety_attributes_dict.get("scores") or [],
+                    )
+                ),
+            )

 class _ModelWithBatchPredict(_LanguageModel):
     """Model that supports batch prediction."""
