@@ -460,19 +460,23 @@ def send_message(
         self,
         message: str,
         *,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
     ) -> "TextGenerationResponse":
         """Sends message to the language model and gets a response.
 
         Args:
             message: Message to send to the model
             max_output_tokens: Max length of the output text in tokens.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
 
         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
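Note: with the `Optional[...] = None` defaults above, callers no longer need to repeat generation parameters on every call; anything left unset falls back to the session's values. A minimal usage sketch, assuming the `vertexai.preview.language_models` API of this SDK generation (model name and prompts are placeholders):

```python
from vertexai.preview.language_models import ChatModel

chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat(temperature=0.2, max_output_tokens=256)

# No parameters passed: both values fall back to those given to start_chat().
response = chat.send_message("Summarize the release notes.")

# Per-call override: temperature applies to this call only;
# max_output_tokens still falls back to 256 from the session.
response = chat.send_message("Now rewrite it as a limerick.", temperature=0.9)
```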
@@ -484,10 +488,12 @@ def send_message(
         response_obj = self._model.predict(
             prompt=new_history_text,
-            max_output_tokens=max_output_tokens or self._max_output_tokens,
-            temperature=temperature or self._temperature,
-            top_k=top_k or self._top_k,
-            top_p=top_p or self._top_p,
+            max_output_tokens=max_output_tokens
+            if max_output_tokens is not None
+            else self._max_output_tokens,
+            temperature=temperature if temperature is not None else self._temperature,
+            top_k=top_k if top_k is not None else self._top_k,
+            top_p=top_p if top_p is not None else self._top_p,
         )
         response_text = response_obj.text
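The switch from `or` to an explicit `is not None` check is the substantive fix in this hunk: `or` treats every falsy value as unset, so a caller asking for `temperature=0.0` (fully deterministic sampling) was silently given the session default instead. A standalone sketch of the difference:

```python
# `or` conflates "not provided" with "provided but falsy".
session_temperature = 0.8
temperature = 0.0  # caller explicitly requests deterministic output

old_behavior = temperature or session_temperature  # 0.0 is falsy -> 0.8 (wrong)
new_behavior = temperature if temperature is not None else session_temperature  # -> 0.0

assert old_behavior == 0.8
assert new_behavior == 0.0
```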
@@ -636,28 +642,36 @@ def send_message(
         self,
         message: str,
         *,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
     ) -> "TextGenerationResponse":
         """Sends message to the language model and gets a response.
 
         Args:
             message: Message to send to the model
             max_output_tokens: Max length of the output text in tokens.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
 
         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
         """
         prediction_parameters = {
-            "temperature": temperature,
-            "maxDecodeSteps": max_output_tokens,
-            "topP": top_p,
-            "topK": top_k,
+            "temperature": temperature
+            if temperature is not None
+            else self._temperature,
+            "maxDecodeSteps": max_output_tokens
+            if max_output_tokens is not None
+            else self._max_output_tokens,
+            "topP": top_p if top_p is not None else self._top_p,
+            "topK": top_k if top_k is not None else self._top_k,
         }
         messages = []
         for input_text, output_text in self._history:
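The same four-way fallback now appears in both `send_message` implementations. A hypothetical refactor (not part of this commit) could name the pattern once; `coalesce` and the inline `session` dict below are illustrative only:

```python
from typing import Optional, TypeVar

T = TypeVar("T")

def coalesce(override: Optional[T], session_value: T) -> T:
    """Return the per-call override unless it is None, else the session value."""
    return override if override is not None else session_value

# Stand-in session defaults so the sketch runs without a ChatSession instance.
session = {"temperature": 0.0, "max_output_tokens": 128, "top_p": 0.95, "top_k": 40}

# Mirrors the prediction_parameters dict built in the hunk above.
prediction_parameters = {
    "temperature": coalesce(None, session["temperature"]),         # unset -> 0.0
    "maxDecodeSteps": coalesce(256, session["max_output_tokens"]),  # explicit -> 256
    "topP": coalesce(None, session["top_p"]),
    "topK": coalesce(None, session["top_k"]),
}

assert prediction_parameters == {
    "temperature": 0.0,
    "maxDecodeSteps": 256,
    "topP": 0.95,
    "topK": 40,
}
```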