
Commit a0d815d

Ark-kun authored and copybara-github committed
fix: LLM - Fixed parameters set in ChatModel.start_chat being ignored
PiperOrigin-RevId: 537204205
1 parent ed1f747 commit a0d815d

2 files changed (+75 -16 lines)
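For context, a minimal usage sketch of the behavior this commit addresses, assuming the `vertexai.preview.language_models` import path and the `chat-bison@001` model name (both illustrative, not taken from this diff):

# Illustrative only: parameters passed to start_chat were previously ignored by
# send_message; after this fix they become the per-session defaults.
from vertexai.preview.language_models import ChatModel

chat_model = ChatModel.from_pretrained("chat-bison@001")

chat = chat_model.start_chat(
    temperature=0.1, max_output_tokens=100, top_k=1, top_p=0.1
)

# Uses the start_chat values above.
chat.send_message("Do my favorite movies have sequels?")

# Per-message arguments still override the session defaults.
chat.send_message("And prequels?", temperature=0.2)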

tests/unit/aiplatform/test_language_models.py (+45)

@@ -638,6 +638,51 @@ def test_chat(self):
         )
         assert len(chat._history) == 2

+        # Validating the parameters
+        chat_temperature = 0.1
+        chat_max_output_tokens = 100
+        chat_top_k = 1
+        chat_top_p = 0.1
+        message_temperature = 0.2
+        message_max_output_tokens = 200
+        message_top_k = 2
+        message_top_p = 0.2
+
+        chat2 = model.start_chat(
+            temperature=chat_temperature,
+            max_output_tokens=chat_max_output_tokens,
+            top_k=chat_top_k,
+            top_p=chat_top_p,
+        )
+
+        gca_predict_response3 = gca_prediction_service.PredictResponse()
+        gca_predict_response3.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response3,
+        ) as mock_predict3:
+            chat2.send_message("Are my favorite movies based on a book series?")
+            prediction_parameters = mock_predict3.call_args[1]["parameters"]
+            assert prediction_parameters["temperature"] == chat_temperature
+            assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
+            assert prediction_parameters["topK"] == chat_top_k
+            assert prediction_parameters["topP"] == chat_top_p
+
+            chat2.send_message(
+                "Are my favorite movies based on a book series?",
+                temperature=message_temperature,
+                max_output_tokens=message_max_output_tokens,
+                top_k=message_top_k,
+                top_p=message_top_p,
+            )
+            prediction_parameters = mock_predict3.call_args[1]["parameters"]
+            assert prediction_parameters["temperature"] == message_temperature
+            assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
+            assert prediction_parameters["topK"] == message_top_k
+            assert prediction_parameters["topP"] == message_top_p
+
     def test_text_embedding(self):
         """Tests the text embedding model."""
         aiplatform.init(
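The new assertions read the keyword arguments of the most recent mocked `predict` call through `mock_predict3.call_args[1]["parameters"]`. A self-contained sketch of that `unittest.mock` pattern, using a hypothetical `Client` class that is not part of the SDK:

from unittest import mock


class Client:
    """Hypothetical stand-in for the patched prediction client."""

    def predict(self, endpoint=None, instances=None, parameters=None):
        return {"predictions": []}


with mock.patch.object(
    target=Client, attribute="predict", return_value={"predictions": []}
) as mock_predict:
    Client().predict(endpoint="e", instances=[], parameters={"temperature": 0.1})
    # call_args[1] is the kwargs dict of the most recent call to the mock.
    assert mock_predict.call_args[1]["parameters"] == {"temperature": 0.1}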

vertexai/language_models/_language_models.py (+30 -16)

@@ -460,19 +460,23 @@ def send_message(
         self,
         message: str,
         *,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
     ) -> "TextGenerationResponse":
         """Sends message to the language model and gets a response.

         Args:
             message: Message to send to the model
             max_output_tokens: Max length of the output text in tokens.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -484,10 +488,12 @@ def send_message(

         response_obj = self._model.predict(
             prompt=new_history_text,
-            max_output_tokens=max_output_tokens or self._max_output_tokens,
-            temperature=temperature or self._temperature,
-            top_k=top_k or self._top_k,
-            top_p=top_p or self._top_p,
+            max_output_tokens=max_output_tokens
+            if max_output_tokens is not None
+            else self._max_output_tokens,
+            temperature=temperature if temperature is not None else self._temperature,
+            top_k=top_k if top_k is not None else self._top_k,
+            top_p=top_p if top_p is not None else self._top_p,
         )
         response_text = response_obj.text

@@ -636,28 +642,36 @@ def send_message(
         self,
         message: str,
         *,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
     ) -> "TextGenerationResponse":
         """Sends message to the language model and gets a response.

         Args:
             message: Message to send to the model
             max_output_tokens: Max length of the output text in tokens.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.

         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
         """
         prediction_parameters = {
-            "temperature": temperature,
-            "maxDecodeSteps": max_output_tokens,
-            "topP": top_p,
-            "topK": top_k,
+            "temperature": temperature
+            if temperature is not None
+            else self._temperature,
+            "maxDecodeSteps": max_output_tokens
+            if max_output_tokens is not None
+            else self._max_output_tokens,
+            "topP": top_p if top_p is not None else self._top_p,
+            "topK": top_k if top_k is not None else self._top_k,
         }
         messages = []
         for input_text, output_text in self._history:
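One likely reason the replacement spells out `is not None` instead of keeping the earlier `or` fallback is that `or` also discards valid falsy overrides such as `temperature=0.0`. A small illustration of the difference, not taken from the SDK:

session_temperature = 0.5


def pick_with_or(message_temperature):
    # Falls back whenever the override is falsy, including a deliberate 0.0.
    return message_temperature or session_temperature


def pick_with_is_not_none(message_temperature):
    # Falls back only when the override was omitted entirely.
    return (
        message_temperature if message_temperature is not None else session_temperature
    )


assert pick_with_or(0.0) == 0.5             # explicit 0.0 silently replaced
assert pick_with_is_not_none(0.0) == 0.0    # explicit 0.0 preserved
assert pick_with_is_not_none(None) == 0.5   # omitted value falls back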
