
Commit 587df74

Ark-kun authored and copybara-github committed

feat: LLM - Added support for multiple chat response candidates

PiperOrigin-RevId: 572100735

1 parent e76abd3 · commit 587df74

File tree: 2 files changed, +113 −22 lines
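Taken together, the diff below adds a `candidate_count` parameter to chat `send_message` calls and a multi-candidate response type. A minimal usage sketch: the project and location values are placeholder assumptions, while the model name and prompt come from the new unit test.

```python
import vertexai
from vertexai.language_models import ChatModel

# Placeholder project/location; substitute your own.
vertexai.init(project="my-project", location="us-central1")

chat = ChatModel.from_pretrained("chat-bison@001").start_chat()

# New in this commit: request multiple response candidates per message.
response = chat.send_message(
    "Are my favorite movies based on a book series?", candidate_count=2
)

print(response.text)  # Convenience accessor: text of the first candidate.
for candidate in response.candidates:
    print(candidate.is_blocked, candidate.text)
```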

tests/unit/aiplatform/test_language_models.py (+71 lines)

@@ -238,6 +238,30 @@
         }
     ],
 }
+_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION = {
+    "safetyAttributes": [
+        {
+            "scores": [],
+            "categories": [],
+            "blocked": False,
+        },
+        {
+            "scores": [0.1],
+            "categories": ["Finance"],
+            "blocked": True,
+        },
+    ],
+    "candidates": [
+        {
+            "author": "1",
+            "content": "Chat response 2",
+        },
+        {
+            "author": "1",
+            "content": "",
+        },
+    ],
+}
 
 _TEST_CHAT_PREDICTION_STREAMING = [
     {
@@ -2076,6 +2100,53 @@ def test_chat_ga(self):
         assert prediction_parameters["topP"] == message_top_p
         assert prediction_parameters["stopSequences"] == message_stop_sequences
 
+    def test_chat_model_send_message_with_multiple_candidates(self):
+        """Tests the chat generation model with multiple candidates."""
+
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ) as mock_get_publisher_model:
+            model = language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        mock_get_publisher_model.assert_called_once_with(
+            name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
+        )
+
+        chat = model.start_chat()
+
+        gca_predict_response1 = gca_prediction_service.PredictResponse()
+        gca_predict_response1.predictions.append(
+            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response1,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+            expected_response_candidates = (
+                _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION["candidates"]
+            )
+            expected_candidate_0 = expected_response_candidates[0]["content"]
+            expected_candidate_1 = expected_response_candidates[1]["content"]
+
+            response = chat.send_message(message_text1, candidate_count=2)
+            assert response.text == expected_candidate_0
+            assert len(response.candidates) == 2
+            assert response.candidates[0].text == expected_candidate_0
+            assert response.candidates[1].text == expected_candidate_1
+
+            assert len(chat.message_history) == 2
+            assert chat.message_history[0].author == chat.USER_AUTHOR
+            assert chat.message_history[0].content == message_text1
+            assert chat.message_history[1].author == chat.MODEL_AUTHOR
+            assert chat.message_history[1].content == expected_candidate_0
+
     def test_chat_model_send_message_streaming(self):
         """Tests the chat generation model."""
         with mock.patch.object(
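For readers skimming the fixture above: `candidates` and `safetyAttributes` are parallel, index-aligned lists, which is exactly how the new parsing loop pairs them. A small standalone sketch of that mapping (plain Python, fixture inlined, no SDK required):

```python
# Index-aligned pairing of candidates with their safety attributes,
# mirroring the loop added to _parse_chat_prediction_response below.
prediction = {
    "safetyAttributes": [
        {"scores": [], "categories": [], "blocked": False},
        {"scores": [0.1], "categories": ["Finance"], "blocked": True},
    ],
    "candidates": [
        {"author": "1", "content": "Chat response 2"},
        {"author": "1", "content": ""},
    ],
}

for candidate, safety in zip(prediction["candidates"], prediction["safetyAttributes"]):
    scores = dict(zip(safety.get("categories") or [], safety.get("scores") or []))
    print(repr(candidate["content"]), safety.get("blocked", False), scores)

# Prints:
# 'Chat response 2' False {}
# '' True {'Finance': 0.1}
```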

vertexai/language_models/_language_models.py (+42 −22 lines)

@@ -1615,6 +1615,7 @@ def _prepare_request(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> _PredictionRequest:
         """Prepares a request for the language model.
 
@@ -1629,6 +1630,7 @@ def _prepare_request(
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of candidates to return.
 
         Returns:
             A `_PredictionRequest` object.
@@ -1660,6 +1662,9 @@ def _prepare_request(
         if stop_sequences:
             prediction_parameters["stopSequences"] = stop_sequences
 
+        if candidate_count is not None:
+            prediction_parameters["candidateCount"] = candidate_count
+
         message_structs = []
         for past_message in self._message_history:
             message_structs.append(
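One detail worth flagging in the hunk above: the snake_case `candidate_count` argument maps to the service's camelCase `candidateCount` key, and is omitted from the request when left as `None`. A hypothetical illustration of the resulting parameters dict (the other keys and values are arbitrary examples, not taken from this diff):

```python
# Hypothetical prediction_parameters built for candidate_count=2;
# temperature, topK, and topP values are illustrative placeholders.
prediction_parameters = {
    "temperature": 0.2,
    "topK": 40,
    "topP": 0.95,
    "candidateCount": 2,  # dropped entirely when candidate_count is None
}
```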
@@ -1697,8 +1702,7 @@ def _parse_chat_prediction_response(
         cls,
         prediction_response: aiplatform.models.Prediction,
         prediction_idx: int = 0,
-        candidate_idx: int = 0,
-    ) -> TextGenerationResponse:
+    ) -> MultiCandidateTextGenerationResponse:
         """Parses prediction response for chat models.
 
         Args:
@@ -1707,25 +1711,33 @@ def _parse_chat_prediction_response(
             candidate_idx: Index of the candidate to parse.
 
         Returns:
-            A `TextGenerationResponse` object.
+            A `MultiCandidateTextGenerationResponse` object.
         """
         prediction = prediction_response.predictions[prediction_idx]
-        # ! Note: For chat models, the safetyAttributes is a list.
-        safety_attributes = prediction["safetyAttributes"][candidate_idx]
-        return TextGenerationResponse(
-            text=prediction["candidates"][candidate_idx]["content"]
-            if prediction.get("candidates")
-            else None,
+        candidate_count = len(prediction["candidates"])
+        candidates = []
+        for candidate_idx in range(candidate_count):
+            safety_attributes = prediction["safetyAttributes"][candidate_idx]
+            candidate_response = TextGenerationResponse(
+                text=prediction["candidates"][candidate_idx]["content"],
+                _prediction_response=prediction_response,
+                is_blocked=safety_attributes.get("blocked", False),
+                safety_attributes=dict(
+                    zip(
+                        # Unlike with normal prediction, in streaming prediction
+                        # categories and scores can be None
+                        safety_attributes.get("categories") or [],
+                        safety_attributes.get("scores") or [],
+                    )
+                ),
+            )
+            candidates.append(candidate_response)
+        return MultiCandidateTextGenerationResponse(
+            text=candidates[0].text,
             _prediction_response=prediction_response,
-            is_blocked=safety_attributes.get("blocked", False),
-            safety_attributes=dict(
-                zip(
-                    # Unlike with normal prediction, in streaming prediction
-                    # categories and scores can be None
-                    safety_attributes.get("categories") or [],
-                    safety_attributes.get("scores") or [],
-                )
-            ),
+            is_blocked=candidates[0].is_blocked,
+            safety_attributes=candidates[0].safety_attributes,
+            candidates=candidates,
         )
 
     def send_message(
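A note on the return shape in the hunk above: `MultiCandidateTextGenerationResponse` mirrors `candidates[0]` into its top-level `text`, `is_blocked`, and `safety_attributes` fields, so existing callers that treated the result as a single-candidate `TextGenerationResponse` keep working unchanged; the full list is exposed through the new `candidates` attribute.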
@@ -1737,7 +1749,8 @@ def send_message(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Sends message to the language model and gets a response.
 
         Args:
@@ -1751,9 +1764,11 @@ def send_message(
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         prediction_request = self._prepare_request(
             message=message,
@@ -1762,6 +1777,7 @@ def send_message(
             top_k=top_k,
             top_p=top_p,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
         prediction_response = self._model._endpoint.predict(
@@ -1791,7 +1807,8 @@ async def send_message_async(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Asynchronously sends message to the language model and gets a response.
 
         Args:
@@ -1805,9 +1822,11 @@ async def send_message_async(
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains
+            the text produced by the model.
         """
         prediction_request = self._prepare_request(
             message=message,
@@ -1816,6 +1835,7 @@ async def send_message_async(
             top_k=top_k,
             top_p=top_p,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
         prediction_response = await self._model._endpoint.predict_async(
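The asynchronous path mirrors the synchronous one, per the `send_message_async` hunks above. A minimal sketch, assuming `vertexai.init(...)` has already been called and using an arbitrary prompt:

```python
import asyncio

from vertexai.language_models import ChatModel


async def main() -> None:
    chat = ChatModel.from_pretrained("chat-bison@001").start_chat()
    # candidate_count flows through _prepare_request into
    # prediction_parameters["candidateCount"], same as the sync path.
    response = await chat.send_message_async(
        "Suggest a name for a flower shop.", candidate_count=2
    )
    for candidate in response.candidates:
        print(candidate.text)


asyncio.run(main())
```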
