
Commit 0c371a4

Ark-kun authored and copybara-github committed
feat: LLM - Added support for multiple response candidates in code generation models
PiperOrigin-RevId: 573357986
1 parent 760a025 commit 0c371a4
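
From a caller's perspective, the change adds an optional `candidate_count` argument to `CodeGenerationModel.predict()` / `predict_async()` and switches the return value to a multi-candidate response. A minimal usage sketch, based on the new unit test below; the project and location values are placeholders, and the service may return fewer candidates than requested:

import vertexai
from vertexai.language_models import CodeGenerationModel

# Placeholder project/location; substitute your own values.
vertexai.init(project="my-project", location="us-central1")

model = CodeGenerationModel.from_pretrained("code-bison@001")
response = model.predict(
    prefix="Write a function that checks if a year is a leap year.",
    # candidate_count is an upper bound, not an exact count.
    candidate_count=4,
)

# `.text` still returns the first candidate; `.candidates` holds one entry
# per prediction returned by the service.
print(response.text)
for candidate in response.candidates:
    print(candidate.text)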

File tree

2 files changed, +57 -5 lines changed

tests/unit/aiplatform/test_language_models.py

+40
@@ -2540,6 +2540,46 @@ def test_code_generation(self):
         assert "temperature" not in prediction_parameters
         assert "maxOutputTokens" not in prediction_parameters
 
+    def test_code_generation_multiple_candidates(self):
+        """Tests the code generation model with multiple candidates."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
+            ),
+            autospec=True,
+        ):
+            model = language_models.CodeGenerationModel.from_pretrained(
+                "code-bison@001"
+            )
+
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        # Discrepancy between the number of `instances` and the number of `predictions`
+        # is a violation of the prediction service invariant, but the service does this.
+        gca_predict_response.predictions.append(_TEST_CODE_GENERATION_PREDICTION)
+        gca_predict_response.predictions.append(_TEST_CODE_GENERATION_PREDICTION)
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+            autospec=True,
+        ) as mock_predict:
+            response = model.predict(
+                prefix="Write a function that checks if a year is a leap year.",
+                # candidate_count acts as a maximum number, not exact number.
+                candidate_count=7,
+            )
+            prediction_parameters = mock_predict.call_args[1]["parameters"]
+            assert prediction_parameters["candidateCount"] == 7
+
+            assert response.text == _TEST_CODE_GENERATION_PREDICTION["content"]
+            # The service can return a different number of candidates.
+            assert len(response.candidates) == 2
+            assert (
+                response.candidates[0].text == _TEST_CODE_GENERATION_PREDICTION["content"]
+            )
+
     def test_code_completion(self):
         """Tests code completion with the code generation model."""
         aiplatform.init(
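
The `_TEST_CODE_GENERATION_PREDICTION` constant used above is defined elsewhere in test_language_models.py and is not part of this diff; the assertions only rely on it being a mapping with a "content" key. A hypothetical stand-in, just to make the test readable in isolation:

# Hypothetical stand-in for the fixture referenced in the test above.
# The real constant lives elsewhere in test_language_models.py and may
# carry additional fields beyond "content".
_TEST_CODE_GENERATION_PREDICTION = {
    "content": (
        "def is_leap_year(year: int) -> bool:\n"
        "    return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)"
    ),
}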

vertexai/language_models/_language_models.py

+17 -5
@@ -2254,6 +2254,7 @@ def _create_prediction_request(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> _PredictionRequest:
         """Creates a code generation prediction request.
 
@@ -2263,7 +2264,7 @@ def _create_prediction_request(
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
-
+            candidate_count: Number of response candidates to return.
 
         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -2285,6 +2286,9 @@ def _create_prediction_request(
         if stop_sequences:
             prediction_parameters["stopSequences"] = stop_sequences
 
+        if candidate_count is not None:
+            prediction_parameters["candidateCount"] = candidate_count
+
         return _PredictionRequest(instance=instance, parameters=prediction_parameters)
 
     def predict(
@@ -2295,6 +2299,7 @@ def predict(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> "TextGenerationResponse":
         """Gets model response for a single prompt.
 
@@ -2304,23 +2309,26 @@ def predict(
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         prediction_request = self._create_prediction_request(
             prefix=prefix,
             suffix=suffix,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
         prediction_response = self._endpoint.predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
 
     async def predict_async(
         self,
@@ -2330,6 +2338,7 @@ async def predict_async(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> "TextGenerationResponse":
         """Asynchronously gets model response for a single prompt.
 
@@ -2339,23 +2348,26 @@ async def predict_async(
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         prediction_request = self._create_prediction_request(
             prefix=prefix,
             suffix=suffix,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
        )
 
         prediction_response = await self._endpoint.predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
 
     def predict_streaming(
         self,
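
The new `_parse_text_generation_model_multi_candidate_response` helper is defined elsewhere in _language_models.py and is not shown in this diff. Below is a rough, self-contained sketch of the behavior the unit test implies, using hypothetical stand-in classes rather than the SDK's real `TextGenerationResponse` / `MultiCandidateTextGenerationResponse` types: one candidate per prediction in the service response, with the first candidate's content also exposed as `.text`.

from dataclasses import dataclass, field
from typing import Dict, List

# Hypothetical stand-ins; the real classes in _language_models.py carry
# more fields (safety attributes, the raw prediction response, etc.).
@dataclass
class CandidateSketch:
    text: str

@dataclass
class MultiCandidateResponseSketch:
    text: str
    candidates: List[CandidateSketch] = field(default_factory=list)

def parse_multi_candidate_response_sketch(
    predictions: List[Dict[str, str]],
) -> MultiCandidateResponseSketch:
    """Mirrors what the test asserts: one candidate per prediction,
    with the first candidate's content doubling as `.text`."""
    candidates = [CandidateSketch(text=p["content"]) for p in predictions]
    return MultiCandidateResponseSketch(
        text=candidates[0].text if candidates else "",
        candidates=candidates,
    )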
