@@ -565,6 +565,19 @@ class InputOutputTextPair:
     output_text: str
 
 
+@dataclasses.dataclass
+class ChatMessage:
+    """A chat message.
+
+    Attributes:
+        content: Content of the message.
+        author: Author of the message.
+    """
+
+    content: str
+    author: str
+
+
 class _ChatModelBase(_LanguageModel):
     """_ChatModelBase is a base class for chat models."""
 
@@ -579,6 +592,7 @@ def start_chat(
         temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
         top_k: int = TextGenerationModel._DEFAULT_TOP_K,
         top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        message_history: Optional[List[ChatMessage]] = None,
     ) -> "ChatSession":
         """Starts a chat session with the model.
 
@@ -591,6 +605,7 @@ def start_chat(
             temperature: Controls the randomness of predictions. Range: [0, 1].
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
+            message_history: A list of previously sent and received messages.
 
         Returns:
             A `ChatSession` object.
@@ -603,6 +618,7 @@ def start_chat(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            message_history=message_history,
         )
 
 
@@ -678,6 +694,9 @@ def start_chat(
 class _ChatSessionBase:
     """_ChatSessionBase is a base class for all chat sessions."""
 
+    USER_AUTHOR = "user"
+    MODEL_AUTHOR = "bot"
+
     def __init__(
         self,
         model: _ChatModelBase,
@@ -688,16 +707,22 @@ def __init__(
         top_k: int = TextGenerationModel._DEFAULT_TOP_K,
         top_p: float = TextGenerationModel._DEFAULT_TOP_P,
         is_code_chat_session: bool = False,
+        message_history: Optional[List[ChatMessage]] = None,
     ):
         self._model = model
         self._context = context
         self._examples = examples
-        self._history = []
         self._max_output_tokens = max_output_tokens
         self._temperature = temperature
         self._top_k = top_k
         self._top_p = top_p
         self._is_code_chat_session = is_code_chat_session
+        self._message_history: List[ChatMessage] = message_history or []
+
+    @property
+    def message_history(self) -> List[ChatMessage]:
+        """List of previous messages."""
+        return self._message_history
 
     def send_message(
         self,
@@ -737,29 +762,22 @@ def send_message(
         prediction_parameters["topP"] = top_p if top_p is not None else self._top_p
         prediction_parameters["topK"] = top_k if top_k is not None else self._top_k
 
-        messages = []
-        for input_text, output_text in self._history:
-            messages.append(
+        message_structs = []
+        for past_message in self._message_history:
+            message_structs.append(
                 {
-                    "author": "user",
-                    "content": input_text,
+                    "author": past_message.author,
+                    "content": past_message.content,
                 }
             )
-            messages.append(
-                {
-                    "author": "bot",
-                    "content": output_text,
-                }
-            )
-
-        messages.append(
+        message_structs.append(
             {
-                "author": "user",
+                "author": self.USER_AUTHOR,
                 "content": message,
             }
         )
 
-        prediction_instance = {"messages": messages}
+        prediction_instance = {"messages": message_structs}
         if not self._is_code_chat_session and self._context:
             prediction_instance["context"] = self._context
         if not self._is_code_chat_session and self._examples:
@@ -793,7 +811,13 @@ def send_message(
         )
         response_text = response_obj.text
 
-        self._history.append((message, response_text))
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=response_text, author=self.MODEL_AUTHOR)
+        )
+
         return response_obj
 
 
@@ -812,6 +836,7 @@ def __init__(
         temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
         top_k: int = TextGenerationModel._DEFAULT_TOP_K,
         top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        message_history: Optional[List[ChatMessage]] = None,
     ):
         super().__init__(
             model=model,
@@ -821,6 +846,7 @@ def __init__(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            message_history=message_history,
         )
 
 
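
For reference, a minimal usage sketch of the new ChatMessage dataclass and message_history parameter that this diff threads through start_chat, ChatSession, and send_message. The import path (vertexai.preview.language_models), project/location values, and the "chat-bison@001" model version are assumptions for illustration and are not defined by the diff itself.

# Usage sketch: pre-seed a chat session with prior turns, then let the session
# keep extending message_history on each send_message() call.
# NOTE: import path, project/location, and "chat-bison@001" are assumptions.
import vertexai
from vertexai.preview.language_models import ChatMessage, ChatModel

vertexai.init(project="my-project", location="us-central1")  # hypothetical project

chat_model = ChatModel.from_pretrained("chat-bison@001")

# Previously exchanged turns, using the same author labels as
# _ChatSessionBase.USER_AUTHOR ("user") and MODEL_AUTHOR ("bot").
history = [
    ChatMessage(content="Hello, what can you help me with?", author="user"),
    ChatMessage(content="I can answer questions about your documents.", author="bot"),
]

chat = chat_model.start_chat(message_history=history)
response = chat.send_message("Please summarize our conversation so far.")
print(response.text)

# send_message() appends both the user message and the model reply,
# so the history now holds four ChatMessage entries.
print(len(chat.message_history))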