Skip to content

Commit 72fcc06

Browse files
happy-qiao authored and copybara-github committed
feat: Add support for system instruction and tools in tokenization.
PiperOrigin-RevId: 669058979
1 parent 50fca69 commit 72fcc06

File tree

3 files changed

+701
-83
lines changed

3 files changed

+701
-83
lines changed

tests/system/vertexai/test_tokenization.py

+158-4
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,19 @@
2222
from vertexai.preview.tokenization import (
2323
get_tokenizer_for_model,
2424
)
25-
from vertexai.generative_models import GenerativeModel
25+
from vertexai.generative_models import (
26+
GenerativeModel,
27+
Part,
28+
Tool,
29+
)
2630
from tests.system.aiplatform import e2e_base
2731
from google import auth
32+
from google.cloud.aiplatform_v1beta1.types import (
33+
content as gapic_content_types,
34+
tool as gapic_tool_types,
35+
openapi,
36+
)
37+
from google.protobuf import struct_pb2
2838

2939

3040
_MODELS = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"]
@@ -39,6 +49,51 @@
3949
for model_name in _MODELS
4050
for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
4151
]
52+
# Shared proto fixtures for the count-token parity tests below.
# Built from the gapic types (gapic_tool_types, openapi) and protobuf
# struct_pb2 imported at module top.
_STRUCT = struct_pb2.Struct(
    fields={
        "string_key": struct_pb2.Value(string_value="value"),
    }
)
_FUNCTION_CALL = gapic_tool_types.FunctionCall(name="test_function_call", args=_STRUCT)
_FUNCTION_RESPONSE = gapic_tool_types.FunctionResponse(
    name="function_response",
    response=_STRUCT,
)


_SCHEMA_1 = openapi.Schema(format="schema1_format", description="schema1_description")
_SCHEMA_2 = openapi.Schema(format="schema2_format", description="schema2_description")
_EXAMPLE = struct_pb2.Value(string_value="value1")

# Declaration exercising name/description plus nested schema fields
# (enum, required, items, properties, example).
_FUNCTION_DECLARATION_1 = gapic_tool_types.FunctionDeclaration(
    name="function_declaration_name",
    description="function_declaration_description",
    parameters=openapi.Schema(
        format="schema_format",
        description="schema_description",
        enum=["schema_enum1", "schema_enum2"],
        required=["schema_required1", "schema_required2"],
        items=_SCHEMA_2,
        properties={"property_key": _SCHEMA_1},
        example=_EXAMPLE,
    ),
)
# Declaration exercising numeric/constraint schema fields; the tool test
# below expects the local tokenizer to reject it with ValueError.
_FUNCTION_DECLARATION_2 = gapic_tool_types.FunctionDeclaration(
    parameters=openapi.Schema(
        nullable=True,
        default=struct_pb2.Value(string_value="value1"),
        min_items=0,
        max_items=0,
        min_properties=0,
        max_properties=0,
        minimum=0,
        maximum=0,
        min_length=0,
        max_length=0,
        pattern="pattern",
    ),
    response=_SCHEMA_1,
)
4297

4398
STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
4499
PROD_API_ENDPOINT = "PROD_ENDPOINT"
@@ -107,8 +162,107 @@ def test_compute_tokens(
107162
text = corpus_lib.raw(book)
108163
response = model.compute_tokens(text)
109164
local_result = tokenizer.compute_tokens(text)
110-
for local, service in zip(
111-
local_result.token_info_list, response.tokens_info
112-
):
165+
for local, service in zip(local_result.tokens_info, response.tokens_info):
113166
assert local.tokens == service.tokens
114167
assert local.token_ids == service.token_ids
168+
169+
@pytest.mark.parametrize(
    "model_name",
    _MODELS,
)
def test_count_tokens_system_instruction(self, model_name):
    """Local count with a system instruction matches GenerativeModel.count_tokens."""
    local_tokenizer = get_tokenizer_for_model(model_name)
    service_model = GenerativeModel(
        model_name, system_instruction=["You are a chatbot."]
    )

    local_total = local_tokenizer.count_tokens(
        "hello", system_instruction=["You are a chatbot."]
    ).total_tokens
    assert local_total == service_model.count_tokens("hello").total_tokens
183+
184+
@pytest.mark.parametrize(
    "model_name",
    _MODELS,
)
def test_count_tokens_system_instruction_is_function_call(self, model_name):
    """Parity when the system instruction is a function-call Part."""
    call_part = Part._from_gapic(
        gapic_content_types.Part(function_call=_FUNCTION_CALL)
    )

    local_tokenizer = get_tokenizer_for_model(model_name)
    service_model = GenerativeModel(model_name, system_instruction=[call_part])

    local_total = local_tokenizer.count_tokens(
        "hello", system_instruction=[call_part]
    ).total_tokens
    assert local_total == service_model.count_tokens("hello").total_tokens
198+
199+
@pytest.mark.parametrize(
    "model_name",
    _MODELS,
)
def test_count_tokens_system_instruction_is_function_response(self, model_name):
    """Parity when the system instruction is a function-response Part."""
    response_part = Part._from_gapic(
        gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
    )
    local_tokenizer = get_tokenizer_for_model(model_name)
    service_model = GenerativeModel(model_name, system_instruction=[response_part])

    # The local tokenizer yields a nonzero count even when the content
    # itself is the function-response part.
    assert local_tokenizer.count_tokens(
        response_part, system_instruction=[response_part]
    ).total_tokens
    local_total = local_tokenizer.count_tokens(
        "hello", system_instruction=[response_part]
    ).total_tokens
    assert local_total == service_model.count_tokens("hello").total_tokens
215+
216+
@pytest.mark.parametrize(
    "model_name",
    _MODELS,
)
def test_count_tokens_tool_is_function_declaration(self, model_name):
    """Supported tool declarations count locally and match the service;
    unsupported schema fields raise ValueError locally."""
    local_tokenizer = get_tokenizer_for_model(model_name)
    service_model = GenerativeModel(model_name)
    supported_tool = Tool._from_gapic(
        gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1])
    )
    unsupported_tool = Tool._from_gapic(
        gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_2])
    )

    assert local_tokenizer.count_tokens(
        "hello", tools=[supported_tool]
    ).total_tokens
    with pytest.raises(ValueError):
        local_tokenizer.count_tokens("hello", tools=[unsupported_tool]).total_tokens
    assert (
        local_tokenizer.count_tokens("hello", tools=[supported_tool]).total_tokens
        == service_model.count_tokens("hello", tools=[supported_tool]).total_tokens
    )
237+
238+
@pytest.mark.parametrize(
    "model_name",
    _MODELS,
)
def test_count_tokens_content_is_function_call(self, model_name):
    """Parity when the counted content itself is a function-call Part."""
    call_part = Part._from_gapic(
        gapic_content_types.Part(function_call=_FUNCTION_CALL)
    )
    local_tokenizer = get_tokenizer_for_model(model_name)
    service_model = GenerativeModel(model_name)

    assert local_tokenizer.count_tokens(call_part).total_tokens
    assert (
        local_tokenizer.count_tokens(call_part).total_tokens
        == service_model.count_tokens(call_part).total_tokens
    )
252+
253+
@pytest.mark.parametrize(
    "model_name",
    _MODELS,
)
def test_count_tokens_content_is_function_response(self, model_name):
    """Parity when the counted content itself is a function-response Part."""
    response_part = Part._from_gapic(
        gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
    )
    local_tokenizer = get_tokenizer_for_model(model_name)
    service_model = GenerativeModel(model_name)

    assert local_tokenizer.count_tokens(response_part).total_tokens
    assert (
        local_tokenizer.count_tokens(response_part).total_tokens
        == service_model.count_tokens(response_part).total_tokens
    )

0 commit comments

Comments
 (0)