 from nltk.corpus import udhr
 from google.cloud import aiplatform
 from vertexai.preview.tokenization import (
-    get_tokenizer_for_model,
+    get_tokenizer_for_model as tokenizer_preview,
+)
+from vertexai.tokenization._tokenizers import (
+    get_tokenizer_for_model as tokenizer_ga,
 )
 from vertexai.generative_models import (
     GenerativeModel,
@@ -44,8 +47,10 @@
 _CORPUS_LIB = [
     udhr,
 ]
+_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga]
 _MODEL_CORPUS_PARAMS = [
-    (model_name, corpus_name, corpus_lib)
+    (get_tokenizer_for_model, model_name, corpus_name, corpus_lib)
+    for get_tokenizer_for_model in _VERSIONED_TOKENIZER
     for model_name in _MODELS
     for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
 ]
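For readers skimming the diff: a minimal sketch of what the new comprehension yields. The `_MODELS` and `_CORPUS` values below are hypothetical placeholders (the real lists are defined elsewhere in this file); the point is that each (model, corpus) pair is now paired with both tokenizer factories, so every parametrized test runs once against the preview tokenizer and once against the GA one.

```python
# Illustration only, not part of the diff. Placeholder callables stand in for
# tokenizer_preview / tokenizer_ga; _MODELS and _CORPUS values are assumed.
def tokenizer_preview(model_name): ...
def tokenizer_ga(model_name): ...

_MODELS = ["gemini-1.5-flash"]                # hypothetical
_CORPUS = ["udhr"]                            # hypothetical
_CORPUS_LIB = ["<nltk udhr corpus module>"]   # hypothetical stand-in

_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga]
_MODEL_CORPUS_PARAMS = [
    (get_tokenizer_for_model, model_name, corpus_name, corpus_lib)
    for get_tokenizer_for_model in _VERSIONED_TOKENIZER
    for model_name in _MODELS
    for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
]
# _MODEL_CORPUS_PARAMS ->
# [(tokenizer_preview, "gemini-1.5-flash", "udhr", "<nltk udhr corpus module>"),
#  (tokenizer_ga,      "gemini-1.5-flash", "udhr", "<nltk udhr corpus module>")]
```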
@@ -125,11 +130,16 @@ def setup_method(self, api_endpoint_env_name):
         )

     @pytest.mark.parametrize(
-        "model_name, corpus_name, corpus_lib",
+        "get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
         _MODEL_CORPUS_PARAMS,
     )
     def test_count_tokens_local(
-        self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+        self,
+        get_tokenizer_for_model,
+        model_name,
+        corpus_name,
+        corpus_lib,
+        api_endpoint_env_name,
     ):
         # The Gemini 1.5 flash model requires the model version
         # number suffix (001) in staging only
@@ -145,11 +155,16 @@ def test_count_tokens_local(
         assert service_result.total_tokens == local_result.total_tokens

     @pytest.mark.parametrize(
-        "model_name, corpus_name, corpus_lib",
+        "get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
         _MODEL_CORPUS_PARAMS,
     )
     def test_compute_tokens(
-        self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+        self,
+        get_tokenizer_for_model,
+        model_name,
+        corpus_name,
+        corpus_lib,
+        api_endpoint_env_name,
     ):
         # The Gemini 1.5 flash model requires the model version
         # number suffix (001) in staging only
@@ -171,7 +186,7 @@ def test_compute_tokens(
         _MODELS,
     )
     def test_count_tokens_system_instruction(self, model_name):
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
         model = GenerativeModel(model_name, system_instruction=["You are a chatbot."])

         assert (
@@ -188,7 +203,7 @@ def test_count_tokens_system_instruction(self, model_name):
     def test_count_tokens_system_instruction_is_function_call(self, model_name):
         part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))

-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
         model = GenerativeModel(model_name, system_instruction=[part])

         assert (
@@ -204,7 +219,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name):
         part = Part._from_gapic(
             gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
         )
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
         model = GenerativeModel(model_name, system_instruction=[part])

         assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens
@@ -218,7 +233,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name):
         _MODELS,
     )
     def test_count_tokens_tool_is_function_declaration(self, model_name):
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
         model = GenerativeModel(model_name)
         tool1 = Tool._from_gapic(
             gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1])
@@ -241,7 +256,7 @@ def test_count_tokens_tool_is_function_declaration(self, model_name):
     )
     def test_count_tokens_content_is_function_call(self, model_name):
         part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
         model = GenerativeModel(model_name)

         assert tokenizer.count_tokens(part).total_tokens
@@ -258,7 +273,7 @@ def test_count_tokens_content_is_function_response(self, model_name):
         part = Part._from_gapic(
             gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
         )
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
         model = GenerativeModel(model_name)

         assert tokenizer.count_tokens(part).total_tokens
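Taken together, the diff exercises both the preview entry point and the GA entry point for the local tokenizer. Below is a minimal usage sketch of the two factories outside the test harness, under the assumption that the installed google-cloud-aiplatform SDK ships both paths; the model name is an assumption, the GA path is imported from the private `_tokenizers` module exactly as the tests do, and no API calls are made.

```python
# Sketch only: compare local token counts from the preview and GA tokenizers.
# "gemini-1.5-flash" is an assumed model name supported by the local tokenizer.
from vertexai.preview.tokenization import (
    get_tokenizer_for_model as tokenizer_preview,
)
from vertexai.tokenization._tokenizers import (
    get_tokenizer_for_model as tokenizer_ga,
)

text = "The quick brown fox jumps over the lazy dog."
for label, factory in (("preview", tokenizer_preview), ("ga", tokenizer_ga)):
    tokenizer = factory("gemini-1.5-flash")
    # count_tokens runs entirely locally and exposes total_tokens,
    # mirroring the assertions in the tests above.
    print(label, tokenizer.count_tokens(text).total_tokens)
```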