
Commit 01989b1

sararob authored and copybara-github committed

feat: LLM - Added count_tokens support to ChatModel (preview)

PiperOrigin-RevId: 575006811
1 parent eb6071f commit 01989b1

File tree: 4 files changed, +264 -6 lines

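Before the per-file diffs, a minimal usage sketch of what this commit adds, mirroring the example in the new count_tokens docstring. The project and location values are placeholders, and the snippet assumes an authenticated Vertex AI environment.

    # Sketch: count tokens for a chat message without running a prediction.
    from google.cloud import aiplatform
    from vertexai.preview.language_models import ChatModel

    aiplatform.init(project="my-project", location="us-central1")  # placeholder values

    chat_model = ChatModel.from_pretrained("chat-bison@001")
    chat = chat_model.start_chat()

    # Counts tokens for this message plus any history, context, and examples
    # set on the session; no prediction request is made to the model.
    response = chat.count_tokens("How's it going?")
    print(response.total_tokens, response.total_billable_characters)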

tests/system/aiplatform/test_language_models.py (+23)

@@ -159,6 +159,29 @@ def test_chat_on_chat_model(self):
         assert chat.message_history[2].content == message2
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_preview_count_tokens(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+
+        chat = chat_model.start_chat()
+
+        chat.send_message("What should I do today?")
+
+        response_with_history = chat.count_tokens("Any ideas?")
+
+        response_without_history = chat_model.start_chat().count_tokens(
+            "What should I do today?"
+        )
+
+        assert (
+            response_with_history.total_tokens > response_without_history.total_tokens
+        )
+        assert (
+            response_with_history.total_billable_characters
+            > response_without_history.total_billable_characters
+        )
+
     @pytest.mark.asyncio
     async def test_chat_model_async(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

tests/unit/aiplatform/test_language_models.py (+78)

@@ -2377,6 +2377,44 @@ def test_chat_model_send_message_streaming(self):
         assert chat.message_history[2].content == message_text1
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_preview_count_tokens(self):
+        """Tests count_tokens on the preview chat model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        chat = model.start_chat()
+        assert isinstance(chat, preview_language_models.ChatSession)
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = chat.count_tokens("What is the best recipe for banana bread?")
+
+        assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+        assert (
+            response.total_billable_characters
+            == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+        )
+
     def test_code_chat(self):
         """Tests the code chat model."""
         aiplatform.init(
@@ -2577,6 +2615,46 @@ def test_code_chat_model_send_message_streaming(self):
         assert chat.message_history[0].content == message_text1
         assert chat.message_history[1].author == chat.MODEL_AUTHOR
 
+    def test_code_chat_model_preview_count_tokens(self):
+        """Tests count_tokens on the preview code chat model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.CodeChatModel.from_pretrained(
+                "codechat-bison@001"
+            )
+
+        chat = model.start_chat()
+        assert isinstance(chat, preview_language_models.CodeChatSession)
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = chat.count_tokens("What is the best recipe for banana bread?")
+
+        assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+        assert (
+            response.total_billable_characters
+            == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+        )
+
     def test_code_generation(self):
         """Tests code generation with the code generation model."""
         aiplatform.init(

vertexai/language_models/_language_models.py (+159 -6)

@@ -222,7 +222,9 @@ def tune_model(
         if eval_spec.evaluation_data:
             if isinstance(eval_spec.evaluation_data, str):
                 if eval_spec.evaluation_data.startswith("gs://"):
-                    tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
+                    tuning_parameters[
+                        "evaluation_data_uri"
+                    ] = eval_spec.evaluation_data
                 else:
                     raise ValueError("evaluation_data should be a GCS URI")
             else:
@@ -627,7 +629,7 @@ def count_tokens(
     ) -> CountTokensResponse:
         """Counts the tokens and billable characters for a given prompt.
 
-        Note: this does not make a request to the model, it only counts the tokens
+        Note: this does not make a prediction request to the model, it only counts the tokens
        in the request.
 
        Args:
@@ -802,7 +804,9 @@ def predict(
            parameters=prediction_request.parameters,
        )
 
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
    async def predict_async(
        self,
@@ -844,7 +848,9 @@ async def predict_async(
            parameters=prediction_request.parameters,
        )
 
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
    def predict_streaming(
        self,
@@ -1587,6 +1593,47 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style.
+            examples: A list of `InputOutputTextPair` objects: structured messages that show the model how to respond to the conversation.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `ChatSession` object.
+        """
+        return _PreviewChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+
 
 class CodeChatModel(_ChatModelBase):
     """CodeChatModel represents a model that is capable of completing code.
@@ -1646,6 +1693,47 @@ class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewCodeChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style.
+            examples: A list of `InputOutputTextPair` objects: structured messages that show the model how to respond to the conversation.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `ChatSession` object.
+        """
+        return _PreviewCodeChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+
 
 class _ChatSessionBase:
     """_ChatSessionBase is a base class for all chat sessions."""
@@ -2071,6 +2159,67 @@ async def send_message_streaming_async(
         )
 
 
+class _ChatSessionBaseWithCountTokensMixin(_ChatSessionBase):
+    """A mixin class for adding count_tokens to ChatSession."""
+
+    def count_tokens(
+        self,
+        message: str,
+    ) -> CountTokensResponse:
+        """Counts the tokens and billable characters for the provided chat message and any message history,
+        context, or examples set on the chat session.
+
+        If you've called `send_message()` in the current chat session before calling `count_tokens()`, the
+        response will include the total tokens and characters for the previously sent message and the one in the
+        `count_tokens()` request. To count the tokens for a single message, call `count_tokens()` right after
+        calling `start_chat()` before calling `send_message()`.
+
+        Note: this does not make a prediction request to the model, it only counts the tokens
+        in the request.
+
+        Examples::
+
+            model = ChatModel.from_pretrained("chat-bison@001")
+            chat_session = model.start_chat()
+            count_tokens_response = chat_session.count_tokens("How's it going?")
+
+            count_tokens_response.total_tokens
+            count_tokens_response.total_billable_characters
+
+        Args:
+            message (str):
+                Required. A chat message to count tokens for. For example: "How's it going?"
+        Returns:
+            A `CountTokensResponse` object that contains the number of tokens
+            in the text and the number of billable characters.
+        """
+
+        count_tokens_request = self._prepare_request(message=message)
+
+        count_tokens_response = self._model._endpoint._prediction_client.select_version(
+            "v1beta1"
+        ).count_tokens(
+            endpoint=self._model._endpoint_name,
+            instances=[count_tokens_request.instance],
+        )
+
+        return CountTokensResponse(
+            total_tokens=count_tokens_response.total_tokens,
+            total_billable_characters=count_tokens_response.total_billable_characters,
+            _count_tokens_response=count_tokens_response,
+        )
+
+
+class _PreviewChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
+class _PreviewCodeChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
 
@@ -2361,7 +2510,9 @@ def predict(
            instances=[prediction_request.instance],
            parameters=prediction_request.parameters,
        )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
    async def predict_async(
        self,
@@ -2400,7 +2551,9 @@ async def predict_async(
            instances=[prediction_request.instance],
            parameters=prediction_request.parameters,
        )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
    def predict_streaming(
        self,
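The new _ChatSessionBaseWithCountTokensMixin above uses a standard Python pattern: a mixin layered on the shared session base, so both preview session classes gain count_tokens without duplicating it. Below is a simplified, self-contained sketch of that pattern; the class names and the local token counting are stand-ins for illustration, not the SDK's real implementation.

    from dataclasses import dataclass


    @dataclass
    class CountTokensResult:
        total_tokens: int
        total_billable_characters: int


    class SessionBase:
        """Stand-in for the SDK's _ChatSessionBase."""

        def _prepare_request(self, message: str) -> str:
            # The real session builds a full request including history,
            # context, and examples; this sketch passes the text through.
            return message


    class CountTokensMixin(SessionBase):
        """Stand-in for _ChatSessionBaseWithCountTokensMixin."""

        def count_tokens(self, message: str) -> CountTokensResult:
            request = self._prepare_request(message)
            # The real mixin sends the request to the v1beta1 CountTokens
            # RPC; this sketch fakes both counts locally.
            return CountTokensResult(
                total_tokens=len(request.split()),
                total_billable_characters=len(request),
            )


    class PreviewChatSession(CountTokensMixin):
        pass


    class PreviewCodeChatSession(CountTokensMixin):
        pass


    print(PreviewChatSession().count_tokens("How's it going?"))

Because the mixin subclasses the session base, both preview session types keep their send_message behavior unchanged and pick up count_tokens from a single place.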

vertexai/preview/language_models.py (+4)

@@ -16,7 +16,9 @@
 
 from vertexai.language_models._language_models import (
     _PreviewChatModel,
+    _PreviewChatSession,
     _PreviewCodeChatModel,
+    _PreviewCodeChatSession,
     _PreviewCodeGenerationModel,
     _PreviewTextEmbeddingModel,
     _PreviewTextGenerationModel,
@@ -43,7 +45,9 @@
 
 
 ChatModel = _PreviewChatModel
+ChatSession = _PreviewChatSession
 CodeChatModel = _PreviewCodeChatModel
+CodeChatSession = _PreviewCodeChatSession
 CodeGenerationModel = _PreviewCodeGenerationModel
 TextGenerationModel = _PreviewTextGenerationModel
 TextEmbeddingModel = _PreviewTextEmbeddingModel
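With these re-exports, the preview module exposes the session classes directly, so callers can type-check sessions the same way the unit tests above do. A short sketch, again assuming an authenticated Vertex AI environment:

    from vertexai.preview import language_models

    chat = language_models.ChatModel.from_pretrained("chat-bison@001").start_chat()
    assert isinstance(chat, language_models.ChatSession)

    code_chat = language_models.CodeChatModel.from_pretrained(
        "codechat-bison@001"
    ).start_chat()
    assert isinstance(code_chat, language_models.CodeChatSession)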
