
Commit 6f7ea84

Ark-kun authored and copybara-github committed
feat: LLM - Added support for stop_sequences in inference
PiperOrigin-RevId: 558618706
1 parent 226ab8b commit 6f7ea84
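
A minimal usage sketch of the new parameter (not part of the commit itself; it assumes the SDK has already been initialized with vertexai.init() or aiplatform.init(), and the model name and stop strings are illustrative):

from vertexai.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")

# Decoding stops as soon as the model emits any of the listed strings.
response = model.predict(
    "Write a notebook cell that loads a CSV file with pandas.",
    max_output_tokens=128,
    temperature=0,
    stop_sequences=["# %%"],
)
print(response.text)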

File tree (3 files changed: +46 / -0 lines changed)

  tests/system/aiplatform/test_language_models.py
  tests/unit/aiplatform/test_language_models.py
  vertexai/language_models/_language_models.py

tests/system/aiplatform/test_language_models.py (+2)

@@ -51,6 +51,7 @@ def test_text_generation(self):
             temperature=0,
             top_p=1,
             top_k=5,
+            stop_sequences=["# %%"],
         ).text

     def test_text_generation_streaming(self):
@@ -84,6 +85,7 @@ def test_chat_on_chat_model(self):
                 ),
             ],
             temperature=0.0,
+            stop_sequences=["# %%"],
         )

         message1 = "Are my favorite movies based on a book series?"
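
For the chat model, the same parameter can be set once for a whole session via start_chat, mirroring the system test above (model name and prompts are illustrative, not taken from this commit):

from vertexai.language_models import ChatModel

chat_model = ChatModel.from_pretrained("chat-bison@001")

# Session-level stop sequences apply to every send_message call
# that does not pass its own stop_sequences.
chat = chat_model.start_chat(
    temperature=0.0,
    stop_sequences=["# %%"],
)
response = chat.send_message("Are my favorite movies based on a book series?")
print(response.text)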

tests/unit/aiplatform/test_language_models.py (+11)

@@ -1237,13 +1237,15 @@ def test_text_generation_ga(self):
                 temperature=0,
                 top_p=1,
                 top_k=5,
+                stop_sequences=["\n"],
             )

         prediction_parameters = mock_predict.call_args[1]["parameters"]
         assert prediction_parameters["maxDecodeSteps"] == 128
         assert prediction_parameters["temperature"] == 0
         assert prediction_parameters["topP"] == 1
         assert prediction_parameters["topK"] == 5
+        assert prediction_parameters["stopSequences"] == ["\n"]
         assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]

         # Validating that unspecified parameters are not passed to the model
@@ -1798,16 +1800,19 @@ def test_chat_ga(self):
         chat_max_output_tokens = 100
         chat_top_k = 1
         chat_top_p = 0.1
+        stop_sequences = ["\n"]
         message_temperature = 0.2
         message_max_output_tokens = 200
         message_top_k = 2
         message_top_p = 0.2
+        message_stop_sequences = ["# %%"]

         chat2 = model.start_chat(
             temperature=chat_temperature,
             max_output_tokens=chat_max_output_tokens,
             top_k=chat_top_k,
             top_p=chat_top_p,
+            stop_sequences=stop_sequences,
         )

         gca_predict_response3 = gca_prediction_service.PredictResponse()
@@ -1824,19 +1829,22 @@ def test_chat_ga(self):
         assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
         assert prediction_parameters["topK"] == chat_top_k
         assert prediction_parameters["topP"] == chat_top_p
+        assert prediction_parameters["stopSequences"] == stop_sequences

         chat2.send_message(
             "Are my favorite movies based on a book series?",
             temperature=message_temperature,
             max_output_tokens=message_max_output_tokens,
             top_k=message_top_k,
             top_p=message_top_p,
+            stop_sequences=message_stop_sequences,
         )
         prediction_parameters = mock_predict3.call_args[1]["parameters"]
         assert prediction_parameters["temperature"] == message_temperature
         assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
         assert prediction_parameters["topK"] == message_top_k
         assert prediction_parameters["topP"] == message_top_p
+        assert prediction_parameters["stopSequences"] == message_stop_sequences

     def test_chat_model_send_message_streaming(self):
         """Tests the chat generation model."""
@@ -2102,6 +2110,7 @@ def test_code_generation(self):
         default_max_output_tokens = (
             language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
         )
+        stop_sequences = ["\n"]

         with mock.patch.object(
             target=prediction_service_client.PredictionServiceClient,
@@ -2112,10 +2121,12 @@ def test_code_generation(self):
                 prefix="Write a function that checks if a year is a leap year.",
                 max_output_tokens=predict_max_output_tokens,
                 temperature=predict_temperature,
+                stop_sequences=stop_sequences,
             )
             prediction_parameters = mock_predict.call_args[1]["parameters"]
             assert prediction_parameters["temperature"] == predict_temperature
             assert prediction_parameters["maxOutputTokens"] == predict_max_output_tokens
+            assert prediction_parameters["stopSequences"] == stop_sequences

             model.predict(
                 prefix="Write a function that checks if a year is a leap year.",

vertexai/language_models/_language_models.py (+33)

@@ -636,6 +636,7 @@ def predict(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> "TextGenerationResponse":
         """Gets model response for a single prompt.

@@ -645,6 +646,7 @@ def predict(
             temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -656,6 +658,7 @@ def predict(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            stop_sequences=stop_sequences,
         )[0]

     def _batch_predict(
@@ -665,6 +668,7 @@ def _batch_predict(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> List["TextGenerationResponse"]:
         """Gets model response for a single prompt.

@@ -674,6 +678,7 @@ def _batch_predict(
             temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A list of `TextGenerationResponse` objects that contain the texts produced by the model.
@@ -693,6 +698,9 @@ def _batch_predict(
         if top_k:
             prediction_parameters["topK"] = top_k

+        if stop_sequences:
+            prediction_parameters["stopSequences"] = stop_sequences
+
         prediction_response = self._endpoint.predict(
             instances=instances,
             parameters=prediction_parameters,
@@ -1165,6 +1173,7 @@ def start_chat(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> "ChatSession":
         """Starts a chat session with the model.

@@ -1178,6 +1187,7 @@ def start_chat(
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
             message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `ChatSession` object.
@@ -1191,6 +1201,7 @@ def start_chat(
             top_k=top_k,
             top_p=top_p,
             message_history=message_history,
+            stop_sequences=stop_sequences,
         )


@@ -1291,6 +1302,7 @@ def __init__(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
     ):
         self._model = model
         self._context = context
@@ -1300,6 +1312,7 @@ def __init__(
         self._top_k = top_k
         self._top_p = top_p
         self._message_history: List[ChatMessage] = message_history or []
+        self._stop_sequences = stop_sequences

     @property
     def message_history(self) -> List[ChatMessage]:
@@ -1314,6 +1327,7 @@ def _prepare_request(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> _PredictionRequest:
         """Prepares a request for the language model.

@@ -1327,6 +1341,7 @@ def _prepare_request(
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `_PredictionRequest` object.
@@ -1350,6 +1365,10 @@ def _prepare_request(
         if top_k:
             prediction_parameters["topK"] = top_k

+        stop_sequences = stop_sequences or self._stop_sequences
+        if stop_sequences:
+            prediction_parameters["stopSequences"] = stop_sequences
+
         message_structs = []
         for past_message in self._message_history:
             message_structs.append(
@@ -1426,6 +1445,7 @@ def send_message(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> "TextGenerationResponse":
         """Sends message to the language model and gets a response.

@@ -1439,6 +1459,7 @@ def send_message(
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1449,6 +1470,7 @@ def send_message(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            stop_sequences=stop_sequences,
         )

         prediction_response = self._model._endpoint.predict(
@@ -1553,6 +1575,7 @@ def __init__(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
     ):
         super().__init__(
             model=model,
@@ -1563,6 +1586,7 @@ def __init__(
             top_k=top_k,
             top_p=top_p,
             message_history=message_history,
+            stop_sequences=stop_sequences,
         )


@@ -1669,6 +1693,7 @@ def _create_prediction_request(
         *,
         max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
         temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> _PredictionRequest:
         """Creates a code generation prediction request.

@@ -1677,6 +1702,8 @@ def _create_prediction_request(
             suffix: Code after the current point.
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
+            stop_sequences: Customized stop sequences to stop the decoding process.
+

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1693,6 +1720,9 @@ def _create_prediction_request(
         if max_output_tokens:
             prediction_parameters["maxOutputTokens"] = max_output_tokens

+        if stop_sequences:
+            prediction_parameters["stopSequences"] = stop_sequences
+
         return _PredictionRequest(instance=instance, parameters=prediction_parameters)

     def predict(
@@ -1702,6 +1732,7 @@ def predict(
         *,
         max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
         temperature: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> "TextGenerationResponse":
         """Gets model response for a single prompt.

@@ -1710,6 +1741,7 @@ def predict(
             suffix: Code after the current point.
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
+            stop_sequences: Customized stop sequences to stop the decoding process.

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1719,6 +1751,7 @@ def predict(
             suffix=suffix,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            stop_sequences=stop_sequences,
         )

         prediction_response = self._endpoint.predict(
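
As the _prepare_request change above shows (stop_sequences = stop_sequences or self._stop_sequences), a stop_sequences argument passed to send_message overrides the session default for that one call, and CodeGenerationModel.predict forwards the value alongside temperature and maxOutputTokens. A rough sketch of both, with model names assumed for illustration:

from vertexai.language_models import ChatModel, CodeGenerationModel

chat = ChatModel.from_pretrained("chat-bison@001").start_chat(
    stop_sequences=["\n\n"],  # session default
)
# Overrides the session default for this single message only.
chat.send_message("List three colors.", stop_sequences=["4."])

code_model = CodeGenerationModel.from_pretrained("code-bison@001")
code_model.predict(
    prefix="def is_leap_year(year):",
    stop_sequences=["\ndef "],  # stop before a second function definition
)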

0 commit comments