
Commit 22aa26d

Ark-kun authored and copybara-github committed
feat: LLM - Released the Chat models to GA
PiperOrigin-RevId: 546475152
1 parent 52d0267 commit 22aa26d

File tree

4 files changed (+147, -2 lines)

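This commit promotes the chat models from the preview namespace to GA: ChatModel, ChatSession, and ChatMessage become importable from vertexai.language_models, while vertexai.preview.language_models keeps working. A minimal usage sketch under that assumption; the project and location values are placeholders and are not part of this commit:

from google.cloud import aiplatform
from vertexai.language_models import ChatModel  # stable namespace as of this commit

aiplatform.init(project="my-project", location="us-central1")  # placeholder values

chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat(context="You are my personal assistant.")
response = chat.send_message("Are my favorite movies based on a book series?")
print(response.text)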

tests/unit/aiplatform/test_language_models.py

+134 -1
@@ -88,7 +88,7 @@
     "name": "publishers/google/models/chat-bison",
     "version_id": "001",
     "open_source_category": "PROPRIETARY",
-    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
+    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
     "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
     "predict_schemata": {
         "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml",
@@ -792,6 +792,139 @@ def test_chat(self):
         gca_predict_response2 = gca_prediction_service.PredictResponse()
         gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)

+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response2,
+        ):
+            message_text2 = "When were these books published?"
+            expected_response2 = _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0][
+                "content"
+            ]
+            response = chat.send_message(message_text2, temperature=0.1)
+            assert response.text == expected_response2
+            assert len(chat.message_history) == 6
+            assert chat.message_history[4].author == chat.USER_AUTHOR
+            assert chat.message_history[4].content == message_text2
+            assert chat.message_history[5].author == chat.MODEL_AUTHOR
+            assert chat.message_history[5].content == expected_response2
+
+        # Validating the parameters
+        chat_temperature = 0.1
+        chat_max_output_tokens = 100
+        chat_top_k = 1
+        chat_top_p = 0.1
+        message_temperature = 0.2
+        message_max_output_tokens = 200
+        message_top_k = 2
+        message_top_p = 0.2
+
+        chat2 = model.start_chat(
+            temperature=chat_temperature,
+            max_output_tokens=chat_max_output_tokens,
+            top_k=chat_top_k,
+            top_p=chat_top_p,
+        )
+
+        gca_predict_response3 = gca_prediction_service.PredictResponse()
+        gca_predict_response3.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response3,
+        ) as mock_predict3:
+            chat2.send_message("Are my favorite movies based on a book series?")
+            prediction_parameters = mock_predict3.call_args[1]["parameters"]
+            assert prediction_parameters["temperature"] == chat_temperature
+            assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
+            assert prediction_parameters["topK"] == chat_top_k
+            assert prediction_parameters["topP"] == chat_top_p
+
+            chat2.send_message(
+                "Are my favorite movies based on a book series?",
+                temperature=message_temperature,
+                max_output_tokens=message_max_output_tokens,
+                top_k=message_top_k,
+                top_p=message_top_p,
+            )
+            prediction_parameters = mock_predict3.call_args[1]["parameters"]
+            assert prediction_parameters["temperature"] == message_temperature
+            assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
+            assert prediction_parameters["topK"] == message_top_k
+            assert prediction_parameters["topP"] == message_top_p
+
+    def test_chat_ga(self):
+        """Tests the chat generation model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ) as mock_get_publisher_model:
+            model = language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        mock_get_publisher_model.assert_called_once_with(
+            name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
+        )
+
+        chat = model.start_chat(
+            context="""
+            My name is Ned.
+            You are my personal assistant.
+            My favorite movies are Lord of the Rings and Hobbit.
+            """,
+            examples=[
+                language_models.InputOutputTextPair(
+                    input_text="Who do you work for?",
+                    output_text="I work for Ned.",
+                ),
+                language_models.InputOutputTextPair(
+                    input_text="What do I like?",
+                    output_text="Ned likes watching movies.",
+                ),
+            ],
+            message_history=[
+                language_models.ChatMessage(
+                    author=preview_language_models.ChatSession.USER_AUTHOR,
+                    content="Question 1?",
+                ),
+                language_models.ChatMessage(
+                    author=preview_language_models.ChatSession.MODEL_AUTHOR,
+                    content="Answer 1.",
+                ),
+            ],
+            temperature=0.0,
+        )
+
+        gca_predict_response1 = gca_prediction_service.PredictResponse()
+        gca_predict_response1.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response1,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+            expected_response1 = _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0][
+                "content"
+            ]
+            response = chat.send_message(message_text1)
+            assert response.text == expected_response1
+            assert len(chat.message_history) == 4
+            assert chat.message_history[2].author == chat.USER_AUTHOR
+            assert chat.message_history[2].content == message_text1
+            assert chat.message_history[3].author == chat.MODEL_AUTHOR
+            assert chat.message_history[3].content == expected_response1
+
+        gca_predict_response2 = gca_prediction_service.PredictResponse()
+        gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
+
         with mock.patch.object(
             target=prediction_service_client.PredictionServiceClient,
             attribute="predict",
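The additions to test_chat pin down parameter precedence: values passed to start_chat become per-session defaults sent with every request, and arguments passed to send_message override them for that single call; the asserted wire names are "temperature", "maxDecodeSteps", "topK", and "topP". A short sketch of that behaviour, assuming an already-initialized chat_model (the test itself mocks the prediction service):

chat = chat_model.start_chat(
    temperature=0.1,        # session defaults; max_output_tokens is sent as "maxDecodeSteps",
    max_output_tokens=100,  # top_k as "topK", and top_p as "topP"
    top_k=1,
    top_p=0.1,
)

# Uses the session defaults above.
chat.send_message("Are my favorite movies based on a book series?")

# Per-message arguments take precedence over the session defaults for this call.
chat.send_message(
    "Are my favorite movies based on a book series?",
    temperature=0.2,
    max_output_tokens=200,
    top_k=2,
    top_p=0.2,
)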

vertexai/language_models/__init__.py

+6
@@ -15,6 +15,9 @@
 """Classes for working with language models."""

 from vertexai.language_models._language_models import (
+    ChatMessage,
+    ChatModel,
+    ChatSession,
     CodeChatModel,
     CodeChatSession,
     CodeGenerationModel,
@@ -26,6 +29,9 @@
 )

 __all__ = [
+    "ChatMessage",
+    "ChatModel",
+    "ChatSession",
     "CodeChatModel",
     "CodeChatSession",
     "CodeGenerationModel",

vertexai/language_models/_language_models.py

+5 -1
@@ -584,7 +584,7 @@ class ChatMessage:
 class _ChatModelBase(_LanguageModel):
     """_ChatModelBase is a base class for chat models."""

-    _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
+    _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE

     def start_chat(
         self,
@@ -653,6 +653,10 @@ class ChatModel(_ChatModelBase):
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"


+class _PreviewChatModel(ChatModel):
+    _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
+
+
 class CodeChatModel(_ChatModelBase):
     """CodeChatModel represents a model that is capable of completing code.
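The base class now defaults to the GA launch stage, and the preview behaviour is restored with a one-line class-attribute override in _PreviewChatModel. The contents of _SDK_GA_LAUNCH_STAGE and _SDK_PUBLIC_PREVIEW_LAUNCH_STAGE are not shown in this diff, so the sketch below is purely illustrative of the override pattern, using hypothetical names rather than SDK internals:

# Illustrative only: hypothetical stand-ins for the SDK's launch-stage gating.
GA = "GA"
PUBLIC_PREVIEW = "PUBLIC_PREVIEW"


class _BaseChat:
    # Launch stages this class is willing to load; subclasses may widen the set.
    _ALLOWED_STAGES = frozenset([GA])

    @classmethod
    def from_stage(cls, model_name: str, stage: str) -> "_BaseChat":
        if stage not in cls._ALLOWED_STAGES:
            raise ValueError(f"{model_name} ({stage}) is not available via {cls.__name__}")
        return cls()


class GAChat(_BaseChat):
    pass


class PreviewChat(GAChat):
    # Same behaviour as GAChat, but also accepts preview-stage models.
    _ALLOWED_STAGES = frozenset([GA, PUBLIC_PREVIEW])


GAChat.from_stage("chat-bison@001", GA)       # accepted
PreviewChat.from_stage("chat-bison@001", GA)  # accepted
# GAChat.from_stage("some-preview-model", PUBLIC_PREVIEW) would raise ValueError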

vertexai/preview/language_models.py

+2
@@ -15,6 +15,7 @@
 """Classes for working with language models."""

 from vertexai.language_models._language_models import (
+    _PreviewChatModel,
     _PreviewTextEmbeddingModel,
     _PreviewTextGenerationModel,
     ChatMessage,
@@ -28,6 +29,7 @@
     TextGenerationResponse,
 )

+ChatModel = _PreviewChatModel
 TextGenerationModel = _PreviewTextGenerationModel
 TextEmbeddingModel = _PreviewTextEmbeddingModel
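Re-pointing the preview module's ChatModel at _PreviewChatModel keeps vertexai.preview.language_models.ChatModel working unchanged while the stable namespace exposes the GA class. A small check of that relationship, assuming both namespaces are importable:

from vertexai import language_models
from vertexai.preview import language_models as preview_language_models

# The preview symbol is a subclass of the GA class: it only relaxes the set of
# launch stages it will load, so existing preview code keeps working.
assert issubclass(preview_language_models.ChatModel, language_models.ChatModel)
assert preview_language_models.ChatModel is not language_models.ChatModel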
