 from vertexai.preview.tokenization import (
     get_tokenizer_for_model,
 )
-from vertexai.generative_models import GenerativeModel
+from vertexai.generative_models import (
+    GenerativeModel,
+    Part,
+    Tool,
+)
 from tests.system.aiplatform import e2e_base
 from google import auth
+from google.cloud.aiplatform_v1beta1.types import (
+    content as gapic_content_types,
+    tool as gapic_tool_types,
+    openapi,
+)
+from google.protobuf import struct_pb2
 
 
 _MODELS = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"]
@@ -39,6 +49,51 @@
     for model_name in _MODELS
     for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
 ]
+_STRUCT = struct_pb2.Struct(
+    fields={
+        "string_key": struct_pb2.Value(string_value="value"),
+    }
+)
+_FUNCTION_CALL = gapic_tool_types.FunctionCall(name="test_function_call", args=_STRUCT)
+_FUNCTION_RESPONSE = gapic_tool_types.FunctionResponse(
+    name="function_response",
+    response=_STRUCT,
+)
+
+
+_SCHEMA_1 = openapi.Schema(format="schema1_format", description="schema1_description")
+_SCHEMA_2 = openapi.Schema(format="schema2_format", description="schema2_description")
+_EXAMPLE = struct_pb2.Value(string_value="value1")
+
+_FUNCTION_DECLARATION_1 = gapic_tool_types.FunctionDeclaration(
+    name="function_declaration_name",
+    description="function_declaration_description",
+    parameters=openapi.Schema(
+        format="schema_format",
+        description="schema_description",
+        enum=["schema_enum1", "schema_enum2"],
+        required=["schema_required1", "schema_required2"],
+        items=_SCHEMA_2,
+        properties={"property_key": _SCHEMA_1},
+        example=_EXAMPLE,
+    ),
+)
+_FUNCTION_DECLARATION_2 = gapic_tool_types.FunctionDeclaration(
+    parameters=openapi.Schema(
+        nullable=True,
+        default=struct_pb2.Value(string_value="value1"),
+        min_items=0,
+        max_items=0,
+        min_properties=0,
+        max_properties=0,
+        minimum=0,
+        maximum=0,
+        min_length=0,
+        max_length=0,
+        pattern="pattern",
+    ),
+    response=_SCHEMA_1,
+)
 
 STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
 PROD_API_ENDPOINT = "PROD_ENDPOINT"
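The fixtures above exercise function calls, function responses, and tool declarations through the raw GAPIC types, which the tests below wrap with the private Part._from_gapic / Tool._from_gapic helpers. For orientation only, a roughly equivalent tool built with the public SDK surface might look like the sketch below; it is not part of this diff, and the function name and parameter schema are made-up placeholders.

# Illustrative sketch (not part of this change): a user-facing tool built via
# FunctionDeclaration / Tool, fed through the same count_tokens(tools=...) path
# that the tests below exercise with GAPIC fixtures.
from vertexai.generative_models import FunctionDeclaration, Tool
from vertexai.preview.tokenization import get_tokenizer_for_model

get_weather = FunctionDeclaration(  # placeholder function schema
    name="get_weather",
    description="Returns the current weather for a location.",
    parameters={
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
)
weather_tool = Tool.from_function_declarations([get_weather])

tokenizer = get_tokenizer_for_model("gemini-1.5-flash")
print(tokenizer.count_tokens("hello", tools=[weather_tool]).total_tokens)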
@@ -107,8 +162,107 @@ def test_compute_tokens(
         text = corpus_lib.raw(book)
         response = model.compute_tokens(text)
         local_result = tokenizer.compute_tokens(text)
-        for local, service in zip(
-            local_result.token_info_list, response.tokens_info
-        ):
+        for local, service in zip(local_result.tokens_info, response.tokens_info):
             assert local.tokens == service.tokens
             assert local.token_ids == service.token_ids
+
+    @pytest.mark.parametrize(
+        "model_name",
+        _MODELS,
+    )
+    def test_count_tokens_system_instruction(self, model_name):
+        tokenizer = get_tokenizer_for_model(model_name)
+        model = GenerativeModel(model_name, system_instruction=["You are a chatbot."])
+
+        assert (
+            tokenizer.count_tokens(
+                "hello", system_instruction=["You are a chatbot."]
+            ).total_tokens
+            == model.count_tokens("hello").total_tokens
+        )
+
+    @pytest.mark.parametrize(
+        "model_name",
+        _MODELS,
+    )
+    def test_count_tokens_system_instruction_is_function_call(self, model_name):
+        part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
+
+        tokenizer = get_tokenizer_for_model(model_name)
+        model = GenerativeModel(model_name, system_instruction=[part])
+
+        assert (
+            tokenizer.count_tokens("hello", system_instruction=[part]).total_tokens
+            == model.count_tokens("hello").total_tokens
+        )
+
+    @pytest.mark.parametrize(
+        "model_name",
+        _MODELS,
+    )
+    def test_count_tokens_system_instruction_is_function_response(self, model_name):
+        part = Part._from_gapic(
+            gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
+        )
+        tokenizer = get_tokenizer_for_model(model_name)
+        model = GenerativeModel(model_name, system_instruction=[part])
+
+        assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens
+        assert (
+            tokenizer.count_tokens("hello", system_instruction=[part]).total_tokens
+            == model.count_tokens("hello").total_tokens
+        )
+
+    @pytest.mark.parametrize(
+        "model_name",
+        _MODELS,
+    )
+    def test_count_tokens_tool_is_function_declaration(self, model_name):
+        tokenizer = get_tokenizer_for_model(model_name)
+        model = GenerativeModel(model_name)
+        tool1 = Tool._from_gapic(
+            gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1])
+        )
+        tool2 = Tool._from_gapic(
+            gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_2])
+        )
+
+        assert tokenizer.count_tokens("hello", tools=[tool1]).total_tokens
+        with pytest.raises(ValueError):
+            tokenizer.count_tokens("hello", tools=[tool2]).total_tokens
+        assert (
+            tokenizer.count_tokens("hello", tools=[tool1]).total_tokens
+            == model.count_tokens("hello", tools=[tool1]).total_tokens
+        )
+
+    @pytest.mark.parametrize(
+        "model_name",
+        _MODELS,
+    )
+    def test_count_tokens_content_is_function_call(self, model_name):
+        part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
+        tokenizer = get_tokenizer_for_model(model_name)
+        model = GenerativeModel(model_name)
+
+        assert tokenizer.count_tokens(part).total_tokens
+        assert (
+            tokenizer.count_tokens(part).total_tokens
+            == model.count_tokens(part).total_tokens
+        )
+
+    @pytest.mark.parametrize(
+        "model_name",
+        _MODELS,
+    )
+    def test_count_tokens_content_is_function_response(self, model_name):
+        part = Part._from_gapic(
+            gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
+        )
+        tokenizer = get_tokenizer_for_model(model_name)
+        model = GenerativeModel(model_name)
+
+        assert tokenizer.count_tokens(part).total_tokens
+        assert (
+            tokenizer.count_tokens(part).total_tokens
+            == model.count_tokens(part).total_tokens
+        )
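Taken together, these tests assert that the local tokenizer and the CountTokens service agree for plain text, system instructions, function calls and responses, and tools. A minimal standalone version of that parity check, assuming an initialized project (the project, location, model name, and prompt below are placeholders), would look roughly like:

# Minimal sketch of the local-vs-service agreement the tests above assert.
import vertexai
from vertexai.generative_models import GenerativeModel
from vertexai.preview.tokenization import get_tokenizer_for_model

vertexai.init(project="my-project", location="us-central1")  # placeholders

model_name = "gemini-1.5-flash"
prompt = "hello"

tokenizer = get_tokenizer_for_model(model_name)  # runs locally, no API call
model = GenerativeModel(model_name)

local_total = tokenizer.count_tokens(prompt).total_tokens
service_total = model.count_tokens(prompt).total_tokens  # CountTokens API call
assert local_total == service_total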