
Commit d62bb1b

Ark-kun authored and copybara-github committed
feat: LLM - Added stop_sequences parameter to streaming methods and CodeChatModel
PiperOrigin-RevId: 562915062
1 parent f8d43bb commit d62bb1b
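
As a rough usage sketch of what this change enables (the stop_sequences values and the codechat-bison@001 model come from the diff below; the project settings, the text-bison@001 model name, and the prompts are illustrative assumptions):

import vertexai
from vertexai.language_models import CodeChatModel, TextGenerationModel

vertexai.init(project="my-project", location="us-central1")  # hypothetical project settings

# Streaming text generation can now stop early at custom sequences.
text_model = TextGenerationModel.from_pretrained("text-bison@001")
for response in text_model.predict_streaming(
    "Write a notebook that counts to 50",
    max_output_tokens=128,
    stop_sequences=["# %%"],  # stop before a new notebook-cell marker
):
    print(response.text, end="")

# CodeChatModel now accepts stop_sequences at start_chat and per message.
code_chat = CodeChatModel.from_pretrained("codechat-bison@001").start_chat(
    temperature=0.2,
    stop_sequences=["\n\n"],  # session-level default
)
print(code_chat.send_message("Write a min() helper", stop_sequences=["# %%"]).text)

Per the docstrings added in _language_models.py, a per-message stop_sequences overrides the session-level value set in start_chat.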

2 files changed (+36, -1):
tests/unit/aiplatform/test_language_models.py
vertexai/language_models/_language_models.py

tests/unit/aiplatform/test_language_models.py (+12, -1)
@@ -1304,6 +1304,7 @@ def test_text_generation_model_predict_streaming(self):
             temperature=0.0,
             top_p=1.0,
             top_k=5,
+            stop_sequences=["# %%"],
         ):
             assert len(response.text) > 10

@@ -1969,6 +1970,7 @@ def test_chat_model_send_message_streaming(self):
                 ),
             ],
             temperature=0.0,
+            stop_sequences=["\n"],
         )

         # Using list instead of a generator so that it can be reused.
@@ -1983,6 +1985,7 @@ def test_chat_model_send_message_streaming(self):
         message_max_output_tokens = 200
         message_top_k = 2
         message_top_p = 0.2
+        message_stop_sequences = ["# %%"]

         with mock.patch.object(
             target=prediction_service_client.PredictionServiceClient,
@@ -1998,6 +2001,7 @@ def test_chat_model_send_message_streaming(self):
                 temperature=message_temperature,
                 top_k=message_top_k,
                 top_p=message_top_p,
+                stop_sequences=message_stop_sequences,
             )
         ):
             assert len(response.text) > 10
@@ -2036,6 +2040,7 @@ def test_code_chat(self):
         code_chat = model.start_chat(
             max_output_tokens=128,
             temperature=0.2,
+            stop_sequences=["\n"],
         )

         gca_predict_response1 = gca_prediction_service.PredictResponse()
@@ -2075,12 +2080,15 @@ def test_code_chat(self):
         # Validating the parameters
         chat_temperature = 0.1
         chat_max_output_tokens = 100
+        chat_stop_sequences = ["\n"]
         message_temperature = 0.2
         message_max_output_tokens = 200
+        message_stop_sequences = ["# %%"]

         code_chat2 = model.start_chat(
             temperature=chat_temperature,
             max_output_tokens=chat_max_output_tokens,
+            stop_sequences=chat_stop_sequences,
         )

         gca_predict_response3 = gca_prediction_service.PredictResponse()
@@ -2097,15 +2105,18 @@ def test_code_chat(self):
             prediction_parameters = mock_predict.call_args[1]["parameters"]
             assert prediction_parameters["temperature"] == chat_temperature
             assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
+            assert prediction_parameters["stopSequences"] == chat_stop_sequences

             code_chat2.send_message(
                 "Please help write a function to calculate the min of two numbers",
                 temperature=message_temperature,
                 max_output_tokens=message_max_output_tokens,
+                stop_sequences=message_stop_sequences,
             )
             prediction_parameters = mock_predict.call_args[1]["parameters"]
             assert prediction_parameters["temperature"] == message_temperature
             assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
+            assert prediction_parameters["stopSequences"] == message_stop_sequences

     def test_code_chat_model_send_message_streaming(self):
         """Tests the chat generation model."""
@@ -2122,7 +2133,7 @@ def test_code_chat_model_send_message_streaming(self):
         ):
             model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")

-            chat = model.start_chat(temperature=0.0)
+            chat = model.start_chat(temperature=0.0, stop_sequences=["\n"])

             # Using list instead of a generator so that it can be reused.
             response_generator = [
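
The assertions above read the parameters that the mocked PredictionServiceClient received. A minimal, self-contained sketch of that call_args inspection pattern; FakePredictionClient and its predict signature are hypothetical stand-ins, while the parameters dict and the "stopSequences" key mirror the tests:

from unittest import mock

class FakePredictionClient:
    """Hypothetical stand-in for the mocked prediction client."""

    def predict(self, *, endpoint, instances, parameters):
        return {"predictions": []}

with mock.patch.object(FakePredictionClient, "predict") as mock_predict:
    FakePredictionClient().predict(
        endpoint="projects/p/locations/l/endpoints/e",
        instances=[{"prompt": "hi"}],
        parameters={"temperature": 0.2, "stopSequences": ["# %%"]},
    )
    # Same inspection pattern as the tests: keyword args of the last call.
    prediction_parameters = mock_predict.call_args[1]["parameters"]
    assert prediction_parameters["stopSequences"] == ["# %%"]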

vertexai/language_models/_language_models.py (+24)
@@ -734,6 +734,7 @@ def predict_streaming(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> Iterator[TextGenerationResponse]:
         """Gets a streaming model response for a single prompt.

@@ -745,6 +746,7 @@ def predict_streaming(
             temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Yields:
             A stream of `TextGenerationResponse` objects that contain partial
@@ -771,6 +773,9 @@ def predict_streaming(
         if top_k:
             prediction_parameters["topK"] = top_k

+        if stop_sequences:
+            prediction_parameters["stopSequences"] = stop_sequences
+
         for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._endpoint_name,
@@ -1299,12 +1304,14 @@ def start_chat(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> "CodeChatSession":
         """Starts a chat session with the code chat model.

         Args:
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `ChatSession` object.
@@ -1314,6 +1321,7 @@ def start_chat(
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             message_history=message_history,
+            stop_sequences=stop_sequences,
         )

@@ -1541,6 +1549,7 @@ def send_message_streaming(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> Iterator[TextGenerationResponse]:
         """Sends message to the language model and gets a streamed response.

@@ -1556,6 +1565,8 @@ def send_message_streaming(
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+                Uses the value specified when calling `ChatModel.start_chat` by default.

         Yields:
             A stream of `TextGenerationResponse` objects that contain partial
@@ -1567,6 +1578,7 @@ def send_message_streaming(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            stop_sequences=stop_sequences,
         )

         prediction_service_client = self._model._endpoint._prediction_client
@@ -1644,12 +1656,14 @@ def __init__(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
     ):
         super().__init__(
             model=model,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             message_history=message_history,
+            stop_sequences=stop_sequences,
         )

     def send_message(
@@ -1658,6 +1672,7 @@ def send_message(
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> "TextGenerationResponse":
         """Sends message to the code chat model and gets a response.

@@ -1667,6 +1682,7 @@ def send_message(
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1675,6 +1691,7 @@ def send_message(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            stop_sequences=stop_sequences,
         )

     def send_message_streaming(
@@ -1683,6 +1700,7 @@ def send_message_streaming(
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> Iterator[TextGenerationResponse]:
         """Sends message to the language model and gets a streamed response.

@@ -1694,6 +1712,8 @@ def send_message_streaming(
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+                Uses the value specified when calling `ChatModel.start_chat` by default.

         Returns:
             A stream of `TextGenerationResponse` objects that contain partial
@@ -1703,6 +1723,7 @@ def send_message_streaming(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            stop_sequences=stop_sequences,
         )

@@ -1811,6 +1832,7 @@ def predict_streaming(
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> Iterator[TextGenerationResponse]:
         """Predicts the code based on previous code.

@@ -1821,6 +1843,7 @@ def predict_streaming(
             suffix: Code after the current point.
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Yields:
             A stream of `TextGenerationResponse` objects that contain partial
@@ -1831,6 +1854,7 @@ def predict_streaming(
             suffix=suffix,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            stop_sequences=stop_sequences,
         )

         prediction_service_client = self._endpoint._prediction_client
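
The change follows the existing pattern in these methods: each optional Python argument is copied into the request's parameters dict only when set, using the REST-style keys asserted in the tests (maxDecodeSteps, topK, stopSequences). A minimal self-contained sketch of that mapping; build_prediction_parameters is an illustrative helper, not the SDK's internal function:

from typing import Dict, List, Optional, Union

def build_prediction_parameters(
    max_output_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
) -> Dict[str, Union[int, float, List[str]]]:
    # Only parameters the caller actually set are sent to the service.
    parameters: Dict[str, Union[int, float, List[str]]] = {}
    if max_output_tokens:
        parameters["maxDecodeSteps"] = max_output_tokens
    if temperature is not None:
        parameters["temperature"] = temperature
    if top_k:
        parameters["topK"] = top_k
    if stop_sequences:
        parameters["stopSequences"] = stop_sequences
    return parameters

assert build_prediction_parameters(stop_sequences=["# %%"]) == {"stopSequences": ["# %%"]}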
