
Commit 3a8348b

Ark-kun authored and copybara-github committed
feat: LLM - Support streaming prediction for code generation models
PiperOrigin-RevId: 558312759
1 parent 705e1ea commit 3a8348b

File tree

3 files changed

+132 −5 lines changed
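
For context, here is a minimal usage sketch of the streaming API this commit adds, mirroring the new system test below. The project and location values are placeholders; the model name and sampling parameters are taken from the tests.

from google.cloud import aiplatform
from vertexai import language_models

# Placeholder project/location; substitute your own.
aiplatform.init(project="my-project", location="us-central1")

model = language_models.CodeGenerationModel.from_pretrained("code-bison@001")

# predict_streaming() yields partial TextGenerationResponse objects as the
# model produces them, instead of returning one final response.
for response in model.predict_streaming(
    prefix="def reverse_string(s):",
    suffix=" return s",
    max_output_tokens=128,
    temperature=0,
):
    print(response.text, end="")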

tests/system/aiplatform/test_language_models.py

+14
@@ -22,6 +22,7 @@
     job_state as gca_job_state,
 )
 from tests.system.aiplatform import e2e_base
+from vertexai import language_models
 from vertexai.preview.language_models import (
     ChatModel,
     InputOutputTextPair,
@@ -251,3 +252,16 @@ def test_batch_prediction_for_textembedding(self):
         job.delete()

         assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
+
+    def test_code_generation_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = language_models.CodeGenerationModel.from_pretrained("code-bison@001")
+
+        for response in model.predict_streaming(
+            prefix="def reverse_string(s):",
+            suffix=" return s",
+            max_output_tokens=128,
+            temperature=0,
+        ):
+            assert response.text

tests/unit/aiplatform/test_language_models.py

+33
@@ -2068,6 +2068,39 @@ def test_code_completion(self):
         assert "temperature" not in prediction_parameters
         assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens

+    def test_code_generation_model_predict_streaming(self):
+        """Tests the TextGenerationModel.predict_streaming method."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.CodeGenerationModel.from_pretrained(
+                "code-bison@001"
+            )
+
+        response_generator = (
+            gca_prediction_service.StreamingPredictResponse(
+                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+            )
+            for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="server_streaming_predict",
+            return_value=response_generator,
+        ):
+            for response in model.predict_streaming(
+                prefix="def reverse_string(s):",
+                suffix=" return s",
+                max_output_tokens=1000,
+                temperature=0,
+            ):
+                assert len(response.text) > 10
+
     def test_text_embedding(self):
         """Tests the text embedding model."""
         aiplatform.init(

vertexai/language_models/_language_models.py

+85 −5

@@ -59,6 +59,13 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
     return f"publishers/google/models/{model_name}@{version}"


+@dataclasses.dataclass
+class _PredictionRequest:
+    """A single-instance prediction request."""
+    instance: Dict[str, Any]
+    parameters: Optional[Dict[str, Any]] = None
+
+
 class _LanguageModel(_model_garden_models._ModelGardenModel):
     """_LanguageModel is a base class for all language models."""

@@ -1250,15 +1257,15 @@ class CodeGenerationModel(_LanguageModel):
     _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
     _DEFAULT_MAX_OUTPUT_TOKENS = 128

-    def predict(
+    def _create_prediction_request(
         self,
         prefix: str,
         suffix: Optional[str] = None,
         *,
         max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
         temperature: Optional[float] = None,
-    ) -> "TextGenerationResponse":
-        """Gets model response for a single prompt.
+    ) -> _PredictionRequest:
+        """Creates a code generation prediction request.

         Args:
             prefix: Code before the current point.
@@ -1281,16 +1288,89 @@ def predict(
         if max_output_tokens:
             prediction_parameters["maxOutputTokens"] = max_output_tokens

+        return _PredictionRequest(instance=instance, parameters=prediction_parameters)
+
+    def predict(
+        self,
+        prefix: str,
+        suffix: Optional[str] = None,
+        *,
+        max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: Optional[float] = None,
+    ) -> "TextGenerationResponse":
+        """Gets model response for a single prompt.
+
+        Args:
+            prefix: Code before the current point.
+            suffix: Code after the current point.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
+            temperature: Controls the randomness of predictions. Range: [0, 1].
+
+        Returns:
+            A `TextGenerationResponse` object that contains the text produced by the model.
+        """
+        prediction_request = self._create_prediction_request(
+            prefix=prefix,
+            suffix=suffix,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+        )
+
         prediction_response = self._endpoint.predict(
-            instances=[instance],
-            parameters=prediction_parameters,
+            instances=[prediction_request.instance],
+            parameters=prediction_request.parameters,
         )

         return TextGenerationResponse(
             text=prediction_response.predictions[0]["content"],
             _prediction_response=prediction_response,
         )

+    def predict_streaming(
+        self,
+        prefix: str,
+        suffix: Optional[str] = None,
+        *,
+        max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Predicts the code based on previous code.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prefix: Code before the current point.
+            suffix: Code after the current point.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
+            temperature: Controls the randomness of predictions. Range: [0, 1].
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._create_prediction_request(
+            prefix=prefix,
+            suffix=suffix,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+        )
+
+        prediction_service_client = self._endpoint._prediction_client
+        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+            prediction_service_client=prediction_service_client,
+            endpoint_name=self._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield TextGenerationResponse(
+                text=prediction_dict["content"],
+                _prediction_response=prediction_obj,
+            )
+

 class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
0 commit comments
