Commit 6a2f2aa

sararob authored and copybara-github committed
feat: LLM - Added the count_tokens method to the preview TextGenerationModel and TextEmbeddingModel classes
PiperOrigin-RevId: 570108703
1 parent 69a67f2 commit 6a2f2aa
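
Based on the tests added in this commit, usage of the new method looks roughly like this (a minimal sketch; the project and location values are placeholders):

from google.cloud import aiplatform
from vertexai.preview import language_models as preview_language_models

# Placeholder project/location; substitute your own.
aiplatform.init(project="my-project", location="us-central1")

model = preview_language_models.TextGenerationModel.from_pretrained(
    "text-bison@001"
)

# Counts tokens and billable characters for the prompts without
# making a prediction request to the model.
response = model.count_tokens(["How are you doing?"])
print(response.total_tokens)
print(response.total_billable_characters)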

4 files changed: +192 -6 lines changed

tests/system/aiplatform/test_language_models.py (+12)

@@ -60,6 +60,18 @@ def test_text_generation(self):
             stop_sequences=["# %%"],
         ).text
 
+    def test_text_generation_preview_count_tokens(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = preview_language_models.TextGenerationModel.from_pretrained(
+            "google/text-bison@001"
+        )
+
+        response = model.count_tokens(["How are you doing?"])
+
+        assert response.total_tokens
+        assert response.total_billable_characters
+
     @pytest.mark.asyncio
     async def test_text_generation_model_predict_async(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

tests/unit/aiplatform/test_language_models.py (+90)

@@ -58,6 +58,13 @@
     model as gca_model,
 )
 
+from google.cloud.aiplatform_v1beta1.services.prediction_service import (
+    client as prediction_service_client_v1beta1,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    prediction_service as gca_prediction_service_v1beta1,
+)
+
 import vertexai
 from vertexai.preview import (
     language_models as preview_language_models,

@@ -306,6 +313,11 @@ def reverse_string_2(s):""",
     }
 }
 
+_TEST_COUNT_TOKENS_RESPONSE = {
+    "total_tokens": 5,
+    "total_billable_characters": 25,
+}
+
 
 _TEST_TEXT_BISON_TRAINING_DF = pd.DataFrame(
     {

@@ -1206,6 +1218,43 @@ def test_text_generation(self):
             == _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
         )
 
+    def test_text_generation_preview_count_tokens(self):
+        """Tests the text generation model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = model.count_tokens(["What is the best recipe for banana bread?"])
+
+            assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+            assert (
+                response.total_billable_characters
+                == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+            )
+
     def test_text_generation_ga(self):
         """Tests the text generation model."""
         aiplatform.init(

@@ -2469,6 +2518,47 @@ def test_text_embedding(self):
             == expected_embedding["statistics"]["truncated"]
         )
 
+    def test_text_embedding_preview_count_tokens(self):
+        """Tests the text embedding model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.TextEmbeddingModel.from_pretrained(
+                "textembedding-gecko@001"
+            )
+
+            gca_count_tokens_response = (
+                gca_prediction_service_v1beta1.CountTokensResponse(
+                    total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+                    total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                        "total_billable_characters"
+                    ],
+                )
+            )
+
+            with mock.patch.object(
+                target=prediction_service_client_v1beta1.PredictionServiceClient,
+                attribute="count_tokens",
+                return_value=gca_count_tokens_response,
+            ):
+                response = model.count_tokens(["What is life?"])
+
+                assert (
+                    response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+                )
+                assert (
+                    response.total_billable_characters
+                    == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+                )
+
     def test_text_embedding_ga(self):
         """Tests the text embedding model."""
         aiplatform.init(
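
The preview TextEmbeddingModel gains the same method. A minimal usage sketch mirroring the unit test above (project and location values are placeholders):

from google.cloud import aiplatform
from vertexai.preview import language_models as preview_language_models

# Placeholder project/location; substitute your own.
aiplatform.init(project="my-project", location="us-central1")

model = preview_language_models.TextEmbeddingModel.from_pretrained(
    "textembedding-gecko@001"
)
response = model.count_tokens(["What is life?"])

# Both fields on the returned CountTokensResponse are plain ints.
print(response.total_tokens, response.total_billable_characters)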

vertexai/language_models/_language_models.py (+88 -6)

@@ -92,13 +92,15 @@ def _model_resource_name(self) -> str:
 @dataclasses.dataclass
 class _PredictionRequest:
     """A single-instance prediction request."""
+
     instance: Dict[str, Any]
     parameters: Optional[Dict[str, Any]] = None
 
 
 @dataclasses.dataclass
 class _MultiInstancePredictionRequest:
     """A multi-instance prediction request."""
+
     instances: List[Dict[str, Any]]
     parameters: Optional[Dict[str, Any]] = None
 

@@ -573,6 +575,62 @@ def tune_model(
         return job
 
 
+@dataclasses.dataclass
+class CountTokensResponse:
+    """The response from a count_tokens request.
+    Attributes:
+        total_tokens (int):
+            The total number of tokens counted across all
+            instances passed to the request.
+        total_billable_characters (int):
+            The total number of billable characters
+            counted across all instances from the request.
+    """
+
+    total_tokens: int
+    total_billable_characters: int
+    _count_tokens_response: Any
+
+
+class _CountTokensMixin(_LanguageModel):
+    """Mixin for models that support the CountTokens API."""
+
+    def count_tokens(
+        self,
+        prompts: List[str],
+    ) -> CountTokensResponse:
+        """Counts the tokens and billable characters for a given prompt.
+
+        Note: this does not make a request to the model; it only counts the tokens
+        in the request.
+
+        Args:
+            prompts (List[str]):
+                Required. A list of prompts to ask the model. For example: ["What should I do today?", "How's it going?"]
+
+        Returns:
+            A `CountTokensResponse` object that contains the number of tokens
+            in the text and the number of billable characters.
+        """
+        instances = []
+
+        for prompt in prompts:
+            instances.append({"content": prompt})
+
+        count_tokens_response = self._endpoint._prediction_client.select_version(
+            "v1beta1"
+        ).count_tokens(
+            endpoint=self._endpoint_name,
+            instances=instances,
+        )
+
+        return CountTokensResponse(
+            total_tokens=count_tokens_response.total_tokens,
+            total_billable_characters=count_tokens_response.total_billable_characters,
+            _count_tokens_response=count_tokens_response,
+        )
+
+
 @dataclasses.dataclass
 class TuningEvaluationSpec:
     """Specification for model evaluation to perform during tuning.
@@ -587,6 +645,7 @@ class TuningEvaluationSpec:
         tensorboard: Vertex Tensorboard where to write the evaluation metrics.
             The Tensorboard must be in the same location as the tuning job.
     """
+
     __module__ = "vertexai.language_models"
 
     evaluation_data: str

@@ -605,6 +664,7 @@ class TextGenerationResponse:
         Learn more about the safety attributes here:
         https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
     """
+
     __module__ = "vertexai.language_models"
 
     text: str

@@ -761,7 +821,9 @@ def predict_streaming(
         )
 
         prediction_service_client = self._endpoint._prediction_client
-        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+        for (
+            prediction_dict
+        ) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._endpoint_name,
             instance=prediction_request.instance,

@@ -955,6 +1017,7 @@ class _PreviewTextGenerationModel(
     _PreviewTunableTextModelMixin,
     _PreviewModelWithBatchPredict,
     _evaluatable_language_models._EvaluatableLanguageModel,
+    _CountTokensMixin,
 ):
     # Do not add docstring so that it's inherited from the base class.
     __name__ = "TextGenerationModel"

@@ -1094,6 +1157,7 @@ class TextEmbeddingInput:
             Specifies that the embeddings will be used for clustering.
         title: Optional identifier of the text content.
     """
+
     __module__ = "vertexai.language_models"
 
     text: str

@@ -1113,6 +1177,7 @@ class TextEmbeddingModel(_LanguageModel):
         vector = embedding.values
         print(len(vector))
     """
+
     __module__ = "vertexai.language_models"
 
     _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE

@@ -1173,7 +1238,8 @@ def _parse_text_embedding_response(
             _prediction_response=prediction_response,
         )
 
-    def get_embeddings(self,
+    def get_embeddings(
+        self,
         texts: List[Union[str, TextEmbeddingInput]],
         *,
         auto_truncate: bool = True,

@@ -1207,7 +1273,8 @@ def get_embeddings(self,
 
         return results
 
-    async def get_embeddings_async(self,
+    async def get_embeddings_async(
+        self,
         texts: List[Union[str, TextEmbeddingInput]],
         *,
         auto_truncate: bool = True,

@@ -1242,7 +1309,9 @@ async def get_embeddings_async(self,
         return results
 
 
-class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
+class _PreviewTextEmbeddingModel(
+    TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
+):
     __name__ = "TextEmbeddingModel"
     __module__ = "vertexai.preview.language_models"

@@ -1252,6 +1321,7 @@ class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
 @dataclasses.dataclass
 class TextEmbeddingStatistics:
     """Text embedding statistics."""
+
     __module__ = "vertexai.language_models"
 
     token_count: int

@@ -1261,6 +1331,7 @@ class TextEmbeddingStatistics:
 @dataclasses.dataclass
 class TextEmbedding:
     """Text embedding vector and statistics."""
+
     __module__ = "vertexai.language_models"
 
     values: List[float]

@@ -1271,6 +1342,7 @@ class TextEmbedding:
 @dataclasses.dataclass
 class InputOutputTextPair:
     """InputOutputTextPair represents a pair of input and output texts."""
+
     __module__ = "vertexai.language_models"
 
     input_text: str

@@ -1285,6 +1357,7 @@ class ChatMessage:
         content: Content of the message.
         author: Author of the message.
     """
+
     __module__ = "vertexai.language_models"
 
     content: str

@@ -1362,6 +1435,7 @@ class ChatModel(_ChatModelBase, _TunableChatModelMixin):
 
         chat.send_message("Do you know any cool events this weekend?")
     """
+
     __module__ = "vertexai.language_models"
 
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"

@@ -1388,6 +1462,7 @@ class CodeChatModel(_ChatModelBase):
 
         code_chat.send_message("Please help write a function to calculate the min of two numbers")
     """
+
     __module__ = "vertexai.language_models"
 
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml"

@@ -1739,7 +1814,9 @@ def send_message_streaming(
 
         full_response_text = ""
 
-        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+        for (
+            prediction_dict
+        ) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._model._endpoint_name,
             instance=prediction_request.instance,

@@ -1770,6 +1847,7 @@ class ChatSession(_ChatSessionBase):
 
     Within a chat session, the model keeps context and remembers the previous conversation.
     """
+
     __module__ = "vertexai.language_models"
 
     def __init__(

@@ -1802,6 +1880,7 @@ class CodeChatSession(_ChatSessionBase):
 
     Within a code chat session, the model keeps context and remembers the previous conversation.
     """
+
     __module__ = "vertexai.language_models"
 
     def __init__(

@@ -1924,6 +2003,7 @@ class CodeGenerationModel(_LanguageModel):
             prefix="def reverse_string(s):",
         ))
     """
+
     __module__ = "vertexai.language_models"
 
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"

@@ -2074,7 +2154,9 @@ def predict_streaming(
         )
 
         prediction_service_client = self._endpoint._prediction_client
-        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+        for (
+            prediction_dict
+        ) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._endpoint_name,
             instance=prediction_request.instance,
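
The class wiring in this file layers preview-only capabilities onto the GA classes through multiple inheritance. A toy sketch of that pattern, with hypothetical names rather than SDK code:

# Illustrative only: hypothetical names, mirroring how _CountTokensMixin
# is layered onto the preview model classes.
class BaseModel:
    def predict(self, prompt: str) -> str:
        return f"prediction for {prompt!r}"


class CountTokensMixin:
    def count_tokens(self, prompts: list) -> int:
        # Stand-in for the v1beta1 CountTokens RPC.
        return sum(len(p.split()) for p in prompts)


# GA class: prediction only.
class TextModel(BaseModel):
    pass


# Preview class: gains count_tokens without changing the GA class.
class PreviewTextModel(TextModel, CountTokensMixin):
    pass


model = PreviewTextModel()
print(model.count_tokens(["How are you doing?"]))  # 4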
