
Commit a8a4c30

happy-qiao authored and copybara-github committed
feat: Tokenization - Added count_tokens support for local tokenization
Usage:

```python
tokenizer = get_tokenizer_for_model("gemini-1.0-pro-001")
print(tokenizer.count_tokens("Hello world!"))
```

PiperOrigin-RevId: 646753548
1 parent c46f3e9 commit a8a4c30
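
A slightly fuller usage sketch, based only on the commit message above and the files added below: it assumes the `vertexai.preview.tokenization` re-export introduced in this commit and the `gemini-1.5-pro` model name used in the new unit tests; the fields of the returned `CountTokensResult` are not shown in this diff.

```python
# Illustrative sketch; the import path and model name come from this commit's
# new files, not from separate documentation.
from vertexai.preview.tokenization import get_tokenizer_for_model

tokenizer = get_tokenizer_for_model("gemini-1.5-pro")

# Tokenization runs locally against a downloaded SentencePiece model,
# so counting tokens does not call the Vertex AI API.
result = tokenizer.count_tokens("Hello world!")
print(result)  # a CountTokensResult
```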

File tree: 5 files changed, +737 −0 lines


setup.py (+4 lines)

```diff
@@ -169,6 +169,8 @@
     )
 )
 
+tokenization_extra_require = ["sentencepiece >= 0.2.0"]
+
 full_extra_require = list(
     set(
         tensorboard_extra_require
@@ -191,6 +193,7 @@
 testing_extra_require = (
     full_extra_require
     + profiler_extra_require
+    + tokenization_extra_require
     + [
         "bigframes; python_version>='3.10'",
         # google-api-core 2.x is required since kfp requires protobuf > 4
@@ -273,6 +276,7 @@
         "rapid_evaluation": rapid_evaluation_extra_require,
         "langchain": langchain_extra_require,
         "langchain_testing": langchain_testing_extra_require,
+        "tokenization": tokenization_extra_require,
     },
     python_requires=">=3.8",
     classifiers=[
```
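
The diff above registers a new `tokenization` extra that pulls in `sentencepiece >= 0.2.0` and adds it to the testing dependency set. Assuming the distribution name `google-cloud-aiplatform` (which this diff does not show), the extra would typically be installed with `pip install "google-cloud-aiplatform[tokenization]"`.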
New file (+310 lines): unit tests for the local tokenizers and tokenizer model loading.

```python
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import hashlib
import io
import os
import tempfile
import shutil
from typing import List
from unittest import mock
from vertexai.generative_models import Content, Image, Part
from vertexai.tokenization import _tokenizer_loading
from vertexai.tokenization._tokenizers import (
    CountTokensResult,
    get_tokenizer_for_model,
)
import pytest
from sentencepiece import sentencepiece_model_pb2
import sentencepiece as spm

_TOKENIZER_NAME = "google/gemma"
_MODEL_NAME = "gemini-1.5-pro"

_SENTENCE_1 = "hello world"
_SENTENCE_2 = "what's the weather today"
_SENTENCE_3 = "It's 70 degrees."
_EMPTY_SENTENCE = ""

_TOKENS_MAP = {
    _EMPTY_SENTENCE: {"ids": []},
    _SENTENCE_1: {"ids": [1, 2]},
    _SENTENCE_2: {"ids": [4, 5, 6, 7, 8, 9]},
    _SENTENCE_3: {"ids": [7, 8, 9, 10, 11, 12, 13]},
}


_VALID_CONTENTS_TYPE = [
    (_EMPTY_SENTENCE, [_EMPTY_SENTENCE], []),
    (_SENTENCE_1, [_SENTENCE_1], [_TOKENS_MAP[_SENTENCE_1]["ids"]]),
    (
        [_SENTENCE_1, _SENTENCE_2],
        [_SENTENCE_1, _SENTENCE_2],
        [_TOKENS_MAP[_SENTENCE_1]["ids"], _TOKENS_MAP[_SENTENCE_2]["ids"]],
    ),
    (
        Part.from_text(_SENTENCE_1),
        [_SENTENCE_1],
        [_TOKENS_MAP[_SENTENCE_1]["ids"]],
    ),
    (
        [
            Part.from_text(_SENTENCE_1),
            Part.from_text(_SENTENCE_2),
        ],
        [_SENTENCE_1, _SENTENCE_2],
        [_TOKENS_MAP[_SENTENCE_1]["ids"], _TOKENS_MAP[_SENTENCE_2]["ids"]],
    ),
    (
        Content(role="user", parts=[Part.from_text(_SENTENCE_1)]),
        [_SENTENCE_1],
        [_TOKENS_MAP[_SENTENCE_1]["ids"]],
    ),
    (
        Content(
            role="user",
            parts=[
                Part.from_text(_SENTENCE_1),
                Part.from_text(_SENTENCE_2),
            ],
        ),
        [_SENTENCE_1, _SENTENCE_2],
        [_TOKENS_MAP[_SENTENCE_1]["ids"], _TOKENS_MAP[_SENTENCE_2]["ids"]],
    ),
    (
        [
            Content(
                role="user",
                parts=[
                    Part.from_text(_SENTENCE_1),
                    Part.from_text(_SENTENCE_2),
                ],
            ),
            Content(
                role="model",
                parts=[
                    Part.from_text(_SENTENCE_3),
                ],
            ),
        ],
        [_SENTENCE_1, _SENTENCE_2, _SENTENCE_3],
        [
            _TOKENS_MAP[_SENTENCE_1]["ids"],
            _TOKENS_MAP[_SENTENCE_2]["ids"],
            _TOKENS_MAP[_SENTENCE_3]["ids"],
        ],
    ),
    (
        [
            {
                "role": "user",
                "parts": [
                    {"text": _SENTENCE_1},
                    {"text": _SENTENCE_2},
                ],
            },
            {"role": "model", "parts": [{"text": _SENTENCE_3}]},
        ],
        [_SENTENCE_1, _SENTENCE_2, _SENTENCE_3],
        [
            _TOKENS_MAP[_SENTENCE_1]["ids"],
            _TOKENS_MAP[_SENTENCE_2]["ids"],
            _TOKENS_MAP[_SENTENCE_3]["ids"],
        ],
    ),
]


_LIST_OF_UNSUPPORTED_CONTENTS = [
    Part.from_uri("gs://bucket/object", mime_type="mime_type"),
    Part.from_data(b"inline_data_bytes", mime_type="mime_type"),
    Part.from_dict({"function_call": {"name": "test_function_call"}}),
    Part.from_dict({"function_response": {"name": "test_function_response"}}),
    Part.from_dict({"video_metadata": {"start_offset": "10s"}}),
    Content(
        role="user",
        parts=[Part.from_uri("gs://bucket/object", mime_type="mime_type")],
    ),
    Content(
        role="user",
        parts=[Part.from_data(b"inline_data_bytes", mime_type="mime_type")],
    ),
    Content(
        role="user",
        parts=[Part.from_dict({"function_call": {"name": "test_function_call"}})],
    ),
    Content(
        role="user",
        parts=[
            Part.from_dict({"function_response": {"name": "test_function_response"}})
        ],
    ),
    Content(
        role="user",
        parts=[Part.from_dict({"video_metadata": {"start_offset": "10s"}})],
    ),
]


@pytest.fixture
def mock_sp_processor():
    with mock.patch.object(
        spm,
        "SentencePieceProcessor",
    ) as sp_mock:
        sp_mock.return_value.LoadFromSerializedProto.return_value = True
        sp_mock.return_value.encode.side_effect = _encode_as_ids
        yield sp_mock


def _encode_as_ids(contents: List[str]):
    return [_TOKENS_MAP[content]["ids"] for content in contents]


@pytest.fixture
def mock_requests_get():
    with mock.patch("requests.get") as requests_get_mock:
        model = sentencepiece_model_pb2.ModelProto()
        requests_get_mock.return_value.content = model.SerializeToString()
        yield requests_get_mock


@pytest.fixture
def mock_hashlib_sha256():
    with mock.patch("hashlib.sha256") as sha256_mock:
        sha256_mock.return_value.hexdigest.return_value = (
            "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
        )
        yield sha256_mock


@pytest.mark.usefixtures("mock_requests_get", "mock_hashlib_sha256")
class TestTokenizers:
    """Unit tests for the tokenizers."""

    @pytest.mark.parametrize(
        "contents, encode_input, encode_output",
        _VALID_CONTENTS_TYPE,
    )
    def test_count_tokens_valid_contents_type(
        self, mock_sp_processor, contents, encode_input, encode_output
    ):
        _tokenizer_loading.get_sentencepiece.cache_clear()
        expected_count = CountTokensResult(
            sum(
                1 if isinstance(output, int) else len(output)
                for output in encode_output
            )
        )
        assert (
            get_tokenizer_for_model(_MODEL_NAME).count_tokens(contents)
            == expected_count
        )
        mock_sp_processor.return_value.encode.assert_called_once_with(encode_input)

    @pytest.mark.parametrize(
        "contents",
        _LIST_OF_UNSUPPORTED_CONTENTS,
    )
    def test_count_tokens_unsupported_contents_type(
        self,
        mock_sp_processor,
        contents,
    ):
        _tokenizer_loading.get_sentencepiece.cache_clear()
        with pytest.raises(ValueError) as e:
            get_tokenizer_for_model(_MODEL_NAME).count_tokens(contents)
        e.match("Tokenizers do not support non-text content types.")

    def test_image_mime_types(self, mock_sp_processor):
        # Importing external library lazily to reduce the scope of import errors.
        from PIL import Image as PIL_Image  # pylint: disable=g-import-not-at-top

        pil_image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(200, 200))
        image_bytes_io = io.BytesIO()
        pil_image.save(image_bytes_io, format="PNG")
        _tokenizer_loading.get_sentencepiece.cache_clear()
        with pytest.raises(ValueError) as e:
            get_tokenizer_for_model(_MODEL_NAME).count_tokens(
                Image.from_bytes(image_bytes_io.getvalue())
            )
        e.match("Tokenizers do not support Image content type.")


class TestModelLoad:
    def setup_method(self):
        model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)

    def get_cache_path(self, file_url: str):
        model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
        filename = hashlib.sha1(file_url.encode()).hexdigest()
        return os.path.join(model_dir, filename)

    def test_download_and_save_to_cache(self, mock_hashlib_sha256, mock_requests_get):
        _tokenizer_loading._load_model_proto(_TOKENIZER_NAME)
        cache_path = self.get_cache_path(
            _tokenizer_loading._TOKENIZERS[_TOKENIZER_NAME].model_url
        )
        assert os.path.exists(cache_path)
        mock_requests_get.assert_called_once()
        with open(cache_path, "rb") as f:
            assert f.read() == sentencepiece_model_pb2.ModelProto().SerializeToString()

    @mock.patch("hashlib.sha256", autospec=True)
    def test_download_file_is_corrupted(self, hash_mock, mock_requests_get):
        hash_mock.return_value.hexdigest.return_value = "inconsistent_hash"
        with pytest.raises(ValueError) as e:
            _tokenizer_loading._load_model_proto(_TOKENIZER_NAME)
        e.match(regexp=r"Downloaded model file is corrupted.*")

        mock_requests_get.assert_called_once()

    def test_load_model_proto_from_cache(self, mock_hashlib_sha256, mock_requests_get):
        cache_path = self.get_cache_path(
            _tokenizer_loading._TOKENIZERS[_TOKENIZER_NAME].model_url
        )
        model_contents = sentencepiece_model_pb2.ModelProto(
            pieces=[sentencepiece_model_pb2.ModelProto.SentencePiece(piece="a")]
        ).SerializeToString()
        with open(cache_path, "wb") as f:
            f.write(model_contents)
        assert _tokenizer_loading._load_model_proto(_TOKENIZER_NAME) == model_contents
        assert os.path.exists(cache_path)
        mock_requests_get.assert_not_called()

    @mock.patch("hashlib.sha256", autospec=True)
    def test_load_model_proto_from_corrupted_cache(self, hash_mock, mock_requests_get):
        cache_path = self.get_cache_path(
            _tokenizer_loading._TOKENIZERS[_TOKENIZER_NAME].model_url
        )
        model_contents = sentencepiece_model_pb2.ModelProto(
            pieces=[sentencepiece_model_pb2.ModelProto.SentencePiece(piece="a")]
        ).SerializeToString()
        with open(cache_path, "wb") as f:
            f.write(model_contents)
        hash_mock.return_value.hexdigest.side_effect = [
            "inconsistent_hash",  # first read from cache
            _tokenizer_loading._TOKENIZERS[
                _TOKENIZER_NAME
            ].model_hash,  # then read from network
        ]
        _tokenizer_loading._load_model_proto(_TOKENIZER_NAME)
        mock_requests_get.assert_called_once()
        with open(cache_path, "rb") as f:
            assert f.read() == sentencepiece_model_pb2.ModelProto().SerializeToString()
```
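
The `TestModelLoad` cases above pin down the expected download/verify/cache behavior: the tokenizer model proto is fetched over HTTP, its SHA-256 digest is checked, and the bytes are cached under `<tempdir>/vertexai_tokenizer_model/<sha1 of the URL>`; a corrupted cache entry triggers a re-download. The sketch below is inferred from those tests, not copied from `_tokenizer_loading`; `MODEL_URL` and `EXPECTED_SHA256` are placeholders.

```python
# Rough sketch of the caching flow exercised by TestModelLoad above.
# Not the SDK's actual implementation; the URL and digest are placeholders.
import hashlib
import os
import tempfile

import requests

MODEL_URL = "https://example.com/tokenizer.model"  # placeholder
EXPECTED_SHA256 = "0" * 64  # placeholder digest


def _cache_path(url: str) -> str:
    model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
    os.makedirs(model_dir, exist_ok=True)
    return os.path.join(model_dir, hashlib.sha1(url.encode()).hexdigest())


def _is_valid(data: bytes) -> bool:
    return hashlib.sha256(data).hexdigest() == EXPECTED_SHA256


def load_model_proto(url: str = MODEL_URL) -> bytes:
    path = _cache_path(url)
    # Reuse the cached copy only if it still matches the expected digest.
    if os.path.exists(path):
        with open(path, "rb") as f:
            cached = f.read()
        if _is_valid(cached):
            return cached
    # Otherwise download, verify, and overwrite the cache.
    data = requests.get(url).content
    if not _is_valid(data):
        raise ValueError("Downloaded model file is corrupted.")
    with open(path, "wb") as f:
        f.write(data)
    return data
```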

vertexai/preview/tokenization.py (new file, +23 lines)

```python
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tokenization._tokenizers import (
    get_tokenizer_for_model,
)


__all__ = ["get_tokenizer_for_model"]
```
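
The preview module only re-exports `get_tokenizer_for_model`. The unit tests above show that the returned tokenizer also accepts `Part`, `Content`, and dict-shaped chat history, and raises `ValueError` for non-text parts. A hedged example reflecting those test cases (the model name and error text are taken from the tests, not from separate documentation):

```python
# Illustrative only; mirrors inputs exercised by the unit tests in this commit.
from vertexai.generative_models import Content, Part
from vertexai.preview.tokenization import get_tokenizer_for_model

tokenizer = get_tokenizer_for_model("gemini-1.5-pro")

# Multi-turn history as Content objects is counted locally.
history = [
    Content(role="user", parts=[Part.from_text("what's the weather today")]),
    Content(role="model", parts=[Part.from_text("It's 70 degrees.")]),
]
print(tokenizer.count_tokens(history))

# Non-text parts (file URIs, inline data, function calls/responses) are rejected.
try:
    tokenizer.count_tokens(
        Part.from_uri("gs://bucket/object", mime_type="mime_type")
    )
except ValueError as err:
    print(err)  # "Tokenizers do not support non-text content types."
```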
