Commit 598d57d

Ark-kun authored and copybara-github committed

feat: LLM - Added support for multiple response candidates in code chat models

PiperOrigin-RevId: 573371030

1 parent 0c371a4 commit 598d57d

2 files changed: +63 -4

tests/unit/aiplatform/test_language_models.py (+51)
```diff
@@ -2419,6 +2419,57 @@ def test_code_chat(self):
         assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
         assert prediction_parameters["stopSequences"] == message_stop_sequences
 
+    def test_code_chat_model_send_message_with_multiple_candidates(self):
+        """Tests the code chat model with multiple candidates."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+            autospec=True,
+        ):
+            model = language_models.CodeChatModel.from_pretrained(
+                "google/codechat-bison@001"
+            )
+
+        chat = model.start_chat()
+
+        gca_predict_response1 = gca_prediction_service.PredictResponse()
+        gca_predict_response1.predictions.append(
+            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response1,
+            autospec=True,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+            expected_response_candidates = (
+                _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION["candidates"]
+            )
+            expected_candidate_0 = expected_response_candidates[0]["content"]
+            expected_candidate_1 = expected_response_candidates[1]["content"]
+
+            response = chat.send_message(
+                message=message_text1,
+                # candidate_count acts as a maximum number, not exact number.
+                candidate_count=7,
+            )
+            # The service can return a different number of candidates.
+            assert response.text == expected_candidate_0
+            assert len(response.candidates) == 2
+            assert response.candidates[0].text == expected_candidate_0
+            assert response.candidates[1].text == expected_candidate_1
+
+            assert len(chat.message_history) == 2
+            assert chat.message_history[0].author == chat.USER_AUTHOR
+            assert chat.message_history[0].content == message_text1
+            assert chat.message_history[1].author == chat.MODEL_AUTHOR
+            assert chat.message_history[1].content == expected_candidate_0
+
     def test_code_chat_model_send_message_streaming(self):
         """Tests the chat generation model."""
         aiplatform.init(
```
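For context, here is a minimal usage sketch of the behavior this test exercises — requesting multiple candidates from a code chat session. The project, region, and prompt are illustrative, not from the commit:

```python
# Hypothetical usage sketch (not part of this commit).
import vertexai
from vertexai.language_models import CodeChatModel

vertexai.init(project="my-project", location="us-central1")  # assumed project/region

model = CodeChatModel.from_pretrained("codechat-bison@001")
chat = model.start_chat()

response = chat.send_message(
    "Write a function that checks if a year is a leap year.",
    candidate_count=2,  # an upper bound; the service may return fewer
)

print(response.text)  # the first candidate's text
for candidate in response.candidates:
    print(candidate.text)
```

As the test's comments note, `candidate_count` is a maximum, so callers should iterate `response.candidates` rather than assume an exact length; only the first candidate is appended to `chat.message_history`.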

vertexai/language_models/_language_models.py (+12 -4)
```diff
@@ -2112,7 +2112,8 @@ def send_message(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Sends message to the code chat model and gets a response.
 
         Args:
@@ -2122,15 +2123,18 @@ def send_message(
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+                text produced by the model.
         """
         return super().send_message(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
     async def send_message_async(
@@ -2139,7 +2143,8 @@ async def send_message_async(
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Asynchronously sends message to the code chat model and gets a response.
 
         Args:
@@ -2148,14 +2153,17 @@ async def send_message_async(
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
+            candidate_count: Number of candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+                text produced by the model.
         """
         return super().send_message_async(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            candidate_count=candidate_count,
         )
 
     def send_message_streaming(
```
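The async variant gains the same parameter and return type. A rough sketch of how it might be called, assuming an already-initialized `vertexai` environment; the prompt and model name are illustrative:

```python
# Hypothetical async usage sketch (not part of this commit).
import asyncio

from vertexai.language_models import CodeChatModel


async def main() -> None:
    model = CodeChatModel.from_pretrained("codechat-bison@001")
    chat = model.start_chat()
    # send_message_async now also accepts candidate_count and returns a
    # MultiCandidateTextGenerationResponse.
    response = await chat.send_message_async(
        "Suggest a docstring for this function.",
        candidate_count=2,
    )
    for i, candidate in enumerate(response.candidates):
        print(f"Candidate {i}: {candidate.text}")


asyncio.run(main())
```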
