Skip to content

Commit bf0e20b

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Exposed the chat history
PiperOrigin-RevId: 542386162
1 parent 8abd9e4 commit bf0e20b

File tree

3 files changed

+84
-42
lines changed

3 files changed

+84
-42
lines changed

tests/system/aiplatform/test_language_models.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,24 @@ def test_chat_on_chat_model(self):
6767
temperature=0.0,
6868
)
6969

70-
assert chat.send_message("Are my favorite movies based on a book series?").text
71-
assert len(chat._history) == 1
72-
assert chat.send_message(
73-
"When where these books published?",
70+
message1 = "Are my favorite movies based on a book series?"
71+
response1 = chat.send_message(message1)
72+
assert response1.text
73+
assert len(chat.message_history) == 2
74+
assert chat.message_history[0].author == chat.USER_AUTHOR
75+
assert chat.message_history[0].content == message1
76+
assert chat.message_history[1].author == chat.MODEL_AUTHOR
77+
78+
message2 = "When where these books published?"
79+
response2 = chat.send_message(
80+
message2,
7481
temperature=0.1,
75-
).text
76-
assert len(chat._history) == 2
82+
)
83+
assert response2.text
84+
assert len(chat.message_history) == 4
85+
assert chat.message_history[2].author == chat.USER_AUTHOR
86+
assert chat.message_history[2].content == message2
87+
assert chat.message_history[3].author == chat.MODEL_AUTHOR
7788

7889
def test_text_embedding(self):
7990
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

tests/unit/aiplatform/test_language_models.py

+24-19
Original file line numberDiff line numberDiff line change
@@ -758,14 +758,17 @@ def test_chat(self):
758758
attribute="predict",
759759
return_value=gca_predict_response1,
760760
):
761-
response = chat.send_message(
762-
"Are my favorite movies based on a book series?"
763-
)
764-
assert (
765-
response.text
766-
== _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0]["content"]
767-
)
768-
assert len(chat._history) == 1
761+
message_text1 = "Are my favorite movies based on a book series?"
762+
expected_response1 = _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0][
763+
"content"
764+
]
765+
response = chat.send_message(message_text1)
766+
assert response.text == expected_response1
767+
assert len(chat.message_history) == 2
768+
assert chat.message_history[0].author == chat.USER_AUTHOR
769+
assert chat.message_history[0].content == message_text1
770+
assert chat.message_history[1].author == chat.MODEL_AUTHOR
771+
assert chat.message_history[1].content == expected_response1
769772

770773
gca_predict_response2 = gca_prediction_service.PredictResponse()
771774
gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
@@ -775,15 +778,17 @@ def test_chat(self):
775778
attribute="predict",
776779
return_value=gca_predict_response2,
777780
):
778-
response = chat.send_message(
779-
"When where these books published?",
780-
temperature=0.1,
781-
)
782-
assert (
783-
response.text
784-
== _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0]["content"]
785-
)
786-
assert len(chat._history) == 2
781+
message_text2 = "When where these books published?"
782+
expected_response2 = _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0][
783+
"content"
784+
]
785+
response = chat.send_message(message_text2, temperature=0.1)
786+
assert response.text == expected_response2
787+
assert len(chat.message_history) == 4
788+
assert chat.message_history[2].author == chat.USER_AUTHOR
789+
assert chat.message_history[2].content == message_text2
790+
assert chat.message_history[3].author == chat.MODEL_AUTHOR
791+
assert chat.message_history[3].content == expected_response2
787792

788793
# Validating the parameters
789794
chat_temperature = 0.1
@@ -870,7 +875,7 @@ def test_code_chat(self):
870875
response.text
871876
== _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0]["content"]
872877
)
873-
assert len(code_chat._history) == 1
878+
assert len(code_chat.message_history) == 2
874879

875880
gca_predict_response2 = gca_prediction_service.PredictResponse()
876881
gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
@@ -889,7 +894,7 @@ def test_code_chat(self):
889894
response.text
890895
== _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0]["content"]
891896
)
892-
assert len(code_chat._history) == 2
897+
assert len(code_chat.message_history) == 4
893898

894899
# Validating the parameters
895900
chat_temperature = 0.1

vertexai/language_models/_language_models.py

+43-17
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,19 @@ class InputOutputTextPair:
565565
output_text: str
566566

567567

568+
@dataclasses.dataclass
569+
class ChatMessage:
570+
"""A chat message.
571+
572+
Attributes:
573+
content: Content of the message.
574+
author: Author of the message.
575+
"""
576+
577+
content: str
578+
author: str
579+
580+
568581
class _ChatModelBase(_LanguageModel):
569582
"""_ChatModelBase is a base class for chat models."""
570583

@@ -579,6 +592,7 @@ def start_chat(
579592
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
580593
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
581594
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
595+
message_history: Optional[List[ChatMessage]] = None,
582596
) -> "ChatSession":
583597
"""Starts a chat session with the model.
584598
@@ -591,6 +605,7 @@ def start_chat(
591605
temperature: Controls the randomness of predictions. Range: [0, 1].
592606
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]
593607
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
608+
message_history: A list of previously sent and received messages.
594609
595610
Returns:
596611
A `ChatSession` object.
@@ -603,6 +618,7 @@ def start_chat(
603618
temperature=temperature,
604619
top_k=top_k,
605620
top_p=top_p,
621+
message_history=message_history,
606622
)
607623

608624

@@ -678,6 +694,9 @@ def start_chat(
678694
class _ChatSessionBase:
679695
"""_ChatSessionBase is a base class for all chat sessions."""
680696

697+
USER_AUTHOR = "user"
698+
MODEL_AUTHOR = "bot"
699+
681700
def __init__(
682701
self,
683702
model: _ChatModelBase,
@@ -688,16 +707,22 @@ def __init__(
688707
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
689708
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
690709
is_code_chat_session: bool = False,
710+
message_history: Optional[List[ChatMessage]] = None,
691711
):
692712
self._model = model
693713
self._context = context
694714
self._examples = examples
695-
self._history = []
696715
self._max_output_tokens = max_output_tokens
697716
self._temperature = temperature
698717
self._top_k = top_k
699718
self._top_p = top_p
700719
self._is_code_chat_session = is_code_chat_session
720+
self._message_history: List[ChatMessage] = message_history or []
721+
722+
@property
723+
def message_history(self) -> List[ChatMessage]:
724+
"""List of previous messages."""
725+
return self._message_history
701726

702727
def send_message(
703728
self,
@@ -737,29 +762,22 @@ def send_message(
737762
prediction_parameters["topP"] = top_p if top_p is not None else self._top_p
738763
prediction_parameters["topK"] = top_k if top_k is not None else self._top_k
739764

740-
messages = []
741-
for input_text, output_text in self._history:
742-
messages.append(
765+
message_structs = []
766+
for past_message in self._message_history:
767+
message_structs.append(
743768
{
744-
"author": "user",
745-
"content": input_text,
769+
"author": past_message.author,
770+
"content": past_message.content,
746771
}
747772
)
748-
messages.append(
749-
{
750-
"author": "bot",
751-
"content": output_text,
752-
}
753-
)
754-
755-
messages.append(
773+
message_structs.append(
756774
{
757-
"author": "user",
775+
"author": self.USER_AUTHOR,
758776
"content": message,
759777
}
760778
)
761779

762-
prediction_instance = {"messages": messages}
780+
prediction_instance = {"messages": message_structs}
763781
if not self._is_code_chat_session and self._context:
764782
prediction_instance["context"] = self._context
765783
if not self._is_code_chat_session and self._examples:
@@ -793,7 +811,13 @@ def send_message(
793811
)
794812
response_text = response_obj.text
795813

796-
self._history.append((message, response_text))
814+
self._message_history.append(
815+
ChatMessage(content=message, author=self.USER_AUTHOR)
816+
)
817+
self._message_history.append(
818+
ChatMessage(content=response_text, author=self.MODEL_AUTHOR)
819+
)
820+
797821
return response_obj
798822

799823

@@ -812,6 +836,7 @@ def __init__(
812836
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
813837
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
814838
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
839+
message_history: Optional[List[ChatMessage]] = None,
815840
):
816841
super().__init__(
817842
model=model,
@@ -821,6 +846,7 @@ def __init__(
821846
temperature=temperature,
822847
top_k=top_k,
823848
top_p=top_p,
849+
message_history=message_history,
824850
)
825851

826852

0 commit comments

Comments (0)