Skip to content

Commit c29fa5d

Browse files
happy-qiao and copybara-github
authored and committed
fix: Tokenizers - Fixed Tokenizer.compute_tokens
PiperOrigin-RevId: 671042779
1 parent 6624ebe commit c29fa5d

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

tests/system/vertexai/test_tokenization.py

+27-12
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,10 @@
2020
from nltk.corpus import udhr
2121
from google.cloud import aiplatform
2222
from vertexai.preview.tokenization import (
23-
get_tokenizer_for_model,
23+
get_tokenizer_for_model as tokenizer_preview,
24+
)
25+
from vertexai.tokenization._tokenizers import (
26+
get_tokenizer_for_model as tokenizer_ga,
2427
)
2528
from vertexai.generative_models import (
2629
GenerativeModel,
@@ -44,8 +47,10 @@
4447
_CORPUS_LIB = [
4548
udhr,
4649
]
50+
_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga]
4751
_MODEL_CORPUS_PARAMS = [
48-
(model_name, corpus_name, corpus_lib)
52+
(get_tokenizer_for_model, model_name, corpus_name, corpus_lib)
53+
for get_tokenizer_for_model in _VERSIONED_TOKENIZER
4954
for model_name in _MODELS
5055
for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
5156
]
@@ -125,11 +130,16 @@ def setup_method(self, api_endpoint_env_name):
125130
)
126131

127132
@pytest.mark.parametrize(
128-
"model_name, corpus_name, corpus_lib",
133+
"get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
129134
_MODEL_CORPUS_PARAMS,
130135
)
131136
def test_count_tokens_local(
132-
self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
137+
self,
138+
get_tokenizer_for_model,
139+
model_name,
140+
corpus_name,
141+
corpus_lib,
142+
api_endpoint_env_name,
133143
):
134144
# The Gemini 1.5 flash model requires the model version
135145
# number suffix (001) in staging only
@@ -145,11 +155,16 @@ def test_count_tokens_local(
145155
assert service_result.total_tokens == local_result.total_tokens
146156

147157
@pytest.mark.parametrize(
148-
"model_name, corpus_name, corpus_lib",
158+
"get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
149159
_MODEL_CORPUS_PARAMS,
150160
)
151161
def test_compute_tokens(
152-
self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
162+
self,
163+
get_tokenizer_for_model,
164+
model_name,
165+
corpus_name,
166+
corpus_lib,
167+
api_endpoint_env_name,
153168
):
154169
# The Gemini 1.5 flash model requires the model version
155170
# number suffix (001) in staging only
@@ -171,7 +186,7 @@ def test_compute_tokens(
171186
_MODELS,
172187
)
173188
def test_count_tokens_system_instruction(self, model_name):
174-
tokenizer = get_tokenizer_for_model(model_name)
189+
tokenizer = tokenizer_preview(model_name)
175190
model = GenerativeModel(model_name, system_instruction=["You are a chatbot."])
176191

177192
assert (
@@ -188,7 +203,7 @@ def test_count_tokens_system_instruction(self, model_name):
188203
def test_count_tokens_system_instruction_is_function_call(self, model_name):
189204
part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
190205

191-
tokenizer = get_tokenizer_for_model(model_name)
206+
tokenizer = tokenizer_preview(model_name)
192207
model = GenerativeModel(model_name, system_instruction=[part])
193208

194209
assert (
@@ -204,7 +219,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name):
204219
part = Part._from_gapic(
205220
gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
206221
)
207-
tokenizer = get_tokenizer_for_model(model_name)
222+
tokenizer = tokenizer_preview(model_name)
208223
model = GenerativeModel(model_name, system_instruction=[part])
209224

210225
assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens
@@ -218,7 +233,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name):
218233
_MODELS,
219234
)
220235
def test_count_tokens_tool_is_function_declaration(self, model_name):
221-
tokenizer = get_tokenizer_for_model(model_name)
236+
tokenizer = tokenizer_preview(model_name)
222237
model = GenerativeModel(model_name)
223238
tool1 = Tool._from_gapic(
224239
gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1])
@@ -241,7 +256,7 @@ def test_count_tokens_tool_is_function_declaration(self, model_name):
241256
)
242257
def test_count_tokens_content_is_function_call(self, model_name):
243258
part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
244-
tokenizer = get_tokenizer_for_model(model_name)
259+
tokenizer = tokenizer_preview(model_name)
245260
model = GenerativeModel(model_name)
246261

247262
assert tokenizer.count_tokens(part).total_tokens
@@ -258,7 +273,7 @@ def test_count_tokens_content_is_function_response(self, model_name):
258273
part = Part._from_gapic(
259274
gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
260275
)
261-
tokenizer = get_tokenizer_for_model(model_name)
276+
tokenizer = tokenizer_preview(model_name)
262277
model = GenerativeModel(model_name)
263278

264279
assert tokenizer.count_tokens(part).total_tokens

vertexai/tokenization/_tokenizers.py

+13-18
Original file line number | Diff line number | Diff line change
@@ -53,20 +53,6 @@ class TokensInfo:
5353
role: str = None
5454

5555

56-
@dataclasses.dataclass(frozen=True)
57-
class ComputeTokensResult:
58-
tokens_info: Sequence[TokensInfo]
59-
60-
61-
class PreviewComputeTokensResult(ComputeTokensResult):
62-
def token_info_list(self) -> Sequence[TokensInfo]:
63-
import warnings
64-
65-
message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead."
66-
warnings.warn(message, DeprecationWarning, stacklevel=2)
67-
return self.tokens_info
68-
69-
7056
@dataclasses.dataclass(frozen=True)
7157
class ComputeTokensResult:
7258
"""Represents token string pieces and ids output in compute_tokens function.
@@ -78,11 +64,18 @@ class ComputeTokensResult:
7864
item represents each string instance. Each token
7965
info consists tokens list, token_ids list and
8066
a role.
81-
token_info_list: the value in this field equal to tokens_info.
8267
"""
8368

8469
tokens_info: Sequence[TokensInfo]
85-
token_info_list: Sequence[TokensInfo]
70+
71+
72+
class PreviewComputeTokensResult(ComputeTokensResult):
73+
def token_info_list(self) -> Sequence[TokensInfo]:
74+
import warnings
75+
76+
message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead."
77+
warnings.warn(message, DeprecationWarning, stacklevel=2)
78+
return self.tokens_info
8679

8780

8881
@dataclasses.dataclass(frozen=True)
@@ -169,7 +162,7 @@ def compute_tokens(
169162
role=role,
170163
)
171164
)
172-
return ComputeTokensResult(token_info_list=token_infos, tokens_info=token_infos)
165+
return ComputeTokensResult(tokens_info=token_infos)
173166

174167

175168
def _to_gapic_contents(
@@ -539,7 +532,9 @@ def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
539532

540533
class PreviewTokenizer(Tokenizer):
541534
def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult:
542-
return PreviewComputeTokensResult(tokens_info=super().compute_tokens(contents))
535+
return PreviewComputeTokensResult(
536+
tokens_info=super().compute_tokens(contents).tokens_info
537+
)
543538

544539

545540
def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer:

0 commit comments

Comments
 (0)