Skip to content

Commit 7bf7634

Browse files
Ark-kuncopybara-github
authored andcommitted
fix: LLM - Exported the ChatMessage class
PiperOrigin-RevId: 544541329
1 parent 459ba86 commit 7bf7634

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

tests/unit/aiplatform/test_language_models.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,16 @@ def test_chat(self):
756756
output_text="Ned likes watching movies.",
757757
),
758758
],
759+
message_history=[
760+
preview_language_models.ChatMessage(
761+
author=preview_language_models.ChatSession.USER_AUTHOR,
762+
content="Question 1?",
763+
),
764+
preview_language_models.ChatMessage(
765+
author=preview_language_models.ChatSession.MODEL_AUTHOR,
766+
content="Answer 1.",
767+
),
768+
],
759769
temperature=0.0,
760770
)
761771

@@ -773,11 +783,11 @@ def test_chat(self):
773783
]
774784
response = chat.send_message(message_text1)
775785
assert response.text == expected_response1
776-
assert len(chat.message_history) == 2
777-
assert chat.message_history[0].author == chat.USER_AUTHOR
778-
assert chat.message_history[0].content == message_text1
779-
assert chat.message_history[1].author == chat.MODEL_AUTHOR
780-
assert chat.message_history[1].content == expected_response1
786+
assert len(chat.message_history) == 4
787+
assert chat.message_history[2].author == chat.USER_AUTHOR
788+
assert chat.message_history[2].content == message_text1
789+
assert chat.message_history[3].author == chat.MODEL_AUTHOR
790+
assert chat.message_history[3].content == expected_response1
781791

782792
gca_predict_response2 = gca_prediction_service.PredictResponse()
783793
gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
@@ -793,11 +803,11 @@ def test_chat(self):
793803
]
794804
response = chat.send_message(message_text2, temperature=0.1)
795805
assert response.text == expected_response2
796-
assert len(chat.message_history) == 4
797-
assert chat.message_history[2].author == chat.USER_AUTHOR
798-
assert chat.message_history[2].content == message_text2
799-
assert chat.message_history[3].author == chat.MODEL_AUTHOR
800-
assert chat.message_history[3].content == expected_response2
806+
assert len(chat.message_history) == 6
807+
assert chat.message_history[4].author == chat.USER_AUTHOR
808+
assert chat.message_history[4].content == message_text2
809+
assert chat.message_history[5].author == chat.MODEL_AUTHOR
810+
assert chat.message_history[5].content == expected_response2
801811

802812
# Validating the parameters
803813
chat_temperature = 0.1

vertexai/preview/language_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from vertexai.language_models._language_models import (
1818
_PreviewTextEmbeddingModel,
1919
_PreviewTextGenerationModel,
20+
ChatMessage,
2021
ChatModel,
2122
ChatSession,
2223
CodeChatModel,
@@ -31,6 +32,7 @@
3132
TextEmbeddingModel = _PreviewTextEmbeddingModel
3233

3334
__all__ = [
35+
"ChatMessage",
3436
"ChatModel",
3537
"ChatSession",
3638
"CodeChatModel",

0 commit comments

Comments
 (0)