18
18
"""System tests for generative models."""
19
19
20
20
import json
21
+ import os
21
22
import pytest
22
23
23
24
# Google imports
36
37
GEMINI_15_MODEL_NAME = "gemini-1.5-pro-preview-0409"
37
38
GEMINI_15_PRO_MODEL_NAME = "gemini-1.5-pro-001"
38
39
40
+ STAGING_API_ENDPOINT = os .getenv ("STAGING_ENDPOINT" )
41
+ PROD_API_ENDPOINT = None
42
+
39
43
40
44
# A dummy function for function calling
41
45
def get_current_weather (location : str , unit : str = "centigrade" ):
@@ -84,12 +88,14 @@ def get_current_weather(location: str, unit: str = "centigrade"):
84
88
}
85
89
86
90
91
+ @pytest .mark .parametrize ("api_endpoint" , [STAGING_API_ENDPOINT , PROD_API_ENDPOINT ])
87
92
class TestGenerativeModels (e2e_base .TestEndToEnd ):
88
93
"""System tests for generative models."""
89
94
90
95
_temp_prefix = "temp_generative_models_test_"
91
96
92
- def setup_method (self ):
97
+ @pytest .fixture (scope = "function" , autouse = True )
98
+ def setup_method (self , api_endpoint ):
93
99
super ().setup_method ()
94
100
credentials , _ = auth .default (
95
101
scopes = ["https://www.googleapis.com/auth/cloud-platform" ]
@@ -98,9 +104,10 @@ def setup_method(self):
98
104
project = e2e_base ._PROJECT ,
99
105
location = e2e_base ._LOCATION ,
100
106
credentials = credentials ,
107
+ api_endpoint = api_endpoint ,
101
108
)
102
109
103
- def test_generate_content_with_cached_content_from_text (self ):
110
+ def test_generate_content_with_cached_content_from_text (self , api_endpoint ):
104
111
cached_content = caching .CachedContent .create (
105
112
model_name = GEMINI_15_PRO_MODEL_NAME ,
106
113
system_instruction = "Please answer all the questions like a pirate." ,
@@ -138,7 +145,7 @@ def test_generate_content_with_cached_content_from_text(self):
138
145
finally :
139
146
cached_content .delete ()
140
147
141
- def test_generate_content_from_text (self ):
148
+ def test_generate_content_from_text (self , api_endpoint ):
142
149
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
143
150
response = model .generate_content (
144
151
"Why is sky blue?" ,
@@ -147,15 +154,15 @@ def test_generate_content_from_text(self):
147
154
assert response .text
148
155
149
156
@pytest .mark .asyncio
150
- async def test_generate_content_async (self ):
157
+ async def test_generate_content_async (self , api_endpoint ):
151
158
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
152
159
response = await model .generate_content_async (
153
160
"Why is sky blue?" ,
154
161
generation_config = generative_models .GenerationConfig (temperature = 0 ),
155
162
)
156
163
assert response .text
157
164
158
- def test_generate_content_streaming (self ):
165
+ def test_generate_content_streaming (self , api_endpoint ):
159
166
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
160
167
stream = model .generate_content (
161
168
"Why is sky blue?" ,
@@ -170,7 +177,7 @@ def test_generate_content_streaming(self):
170
177
)
171
178
172
179
@pytest .mark .asyncio
173
- async def test_generate_content_streaming_async (self ):
180
+ async def test_generate_content_streaming_async (self , api_endpoint ):
174
181
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
175
182
async_stream = await model .generate_content_async (
176
183
"Why is sky blue?" ,
@@ -184,7 +191,7 @@ async def test_generate_content_streaming_async(self):
184
191
is generative_models .FinishReason .STOP
185
192
)
186
193
187
- def test_generate_content_with_parameters (self ):
194
+ def test_generate_content_with_parameters (self , api_endpoint ):
188
195
model = generative_models .GenerativeModel (
189
196
GEMINI_MODEL_NAME ,
190
197
system_instruction = [
@@ -211,7 +218,7 @@ def test_generate_content_with_parameters(self):
211
218
)
212
219
assert response .text
213
220
214
- def test_generate_content_with_gemini_15_parameters (self ):
221
+ def test_generate_content_with_gemini_15_parameters (self , api_endpoint ):
215
222
model = generative_models .GenerativeModel (GEMINI_15_MODEL_NAME )
216
223
response = model .generate_content (
217
224
contents = "Why is sky blue? Respond in JSON Format." ,
@@ -237,7 +244,7 @@ def test_generate_content_with_gemini_15_parameters(self):
237
244
assert response .text
238
245
assert json .loads (response .text )
239
246
240
- def test_generate_content_from_list_of_content_dict (self ):
247
+ def test_generate_content_from_list_of_content_dict (self , api_endpoint ):
241
248
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
242
249
response = model .generate_content (
243
250
contents = [{"role" : "user" , "parts" : [{"text" : "Why is sky blue?" }]}],
@@ -248,7 +255,7 @@ def test_generate_content_from_list_of_content_dict(self):
248
255
@pytest .mark .skip (
249
256
reason = "Breaking change in the gemini-pro-vision model. See b/315803556#comment3"
250
257
)
251
- def test_generate_content_from_remote_image (self ):
258
+ def test_generate_content_from_remote_image (self , api_endpoint ):
252
259
vision_model = generative_models .GenerativeModel (GEMINI_VISION_MODEL_NAME )
253
260
image_part = generative_models .Part .from_uri (
254
261
uri = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg" ,
@@ -261,7 +268,7 @@ def test_generate_content_from_remote_image(self):
261
268
assert response .text
262
269
assert "cat" in response .text
263
270
264
- def test_generate_content_from_text_and_remote_image (self ):
271
+ def test_generate_content_from_text_and_remote_image (self , api_endpoint ):
265
272
vision_model = generative_models .GenerativeModel (GEMINI_VISION_MODEL_NAME )
266
273
image_part = generative_models .Part .from_uri (
267
274
uri = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg" ,
@@ -274,7 +281,7 @@ def test_generate_content_from_text_and_remote_image(self):
274
281
assert response .text
275
282
assert "cat" in response .text
276
283
277
- def test_generate_content_from_text_and_remote_video (self ):
284
+ def test_generate_content_from_text_and_remote_video (self , api_endpoint ):
278
285
vision_model = generative_models .GenerativeModel (GEMINI_VISION_MODEL_NAME )
279
286
video_part = generative_models .Part .from_uri (
280
287
uri = "gs://cloud-samples-data/video/animals.mp4" ,
@@ -287,7 +294,7 @@ def test_generate_content_from_text_and_remote_video(self):
287
294
assert response .text
288
295
assert "Zootopia" in response .text
289
296
290
- def test_grounding_google_search_retriever (self ):
297
+ def test_grounding_google_search_retriever (self , api_endpoint ):
291
298
model = preview_generative_models .GenerativeModel (GEMINI_MODEL_NAME )
292
299
google_search_retriever_tool = (
293
300
preview_generative_models .Tool .from_google_search_retrieval (
@@ -309,7 +316,7 @@ def test_grounding_google_search_retriever(self):
309
316
310
317
# Chat
311
318
312
- def test_send_message_from_text (self ):
319
+ def test_send_message_from_text (self , api_endpoint ):
313
320
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
314
321
chat = model .start_chat ()
315
322
response1 = chat .send_message (
@@ -326,7 +333,7 @@ def test_send_message_from_text(self):
326
333
assert response2 .text
327
334
assert len (chat .history ) == 4
328
335
329
- def test_chat_function_calling (self ):
336
+ def test_chat_function_calling (self , api_endpoint ):
330
337
get_current_weather_func = generative_models .FunctionDeclaration (
331
338
name = "get_current_weather" ,
332
339
description = "Get the current weather in a given location" ,
@@ -360,7 +367,7 @@ def test_chat_function_calling(self):
360
367
)
361
368
assert response2 .text
362
369
363
- def test_generate_content_function_calling (self ):
370
+ def test_generate_content_function_calling (self , api_endpoint ):
364
371
get_current_weather_func = generative_models .FunctionDeclaration (
365
372
name = "get_current_weather" ,
366
373
description = "Get the current weather in a given location" ,
@@ -440,7 +447,7 @@ def test_generate_content_function_calling(self):
440
447
441
448
assert summary
442
449
443
- def test_chat_automatic_function_calling (self ):
450
+ def test_chat_automatic_function_calling (self , api_endpoint ):
444
451
get_current_weather_func = generative_models .FunctionDeclaration .from_func (
445
452
get_current_weather
446
453
)
@@ -471,7 +478,7 @@ def test_chat_automatic_function_calling(self):
471
478
assert chat .history [- 2 ].parts [0 ].function_response
472
479
assert chat .history [- 2 ].parts [0 ].function_response .name == "get_current_weather"
473
480
474
- def test_additional_request_metadata (self ):
481
+ def test_additional_request_metadata (self , api_endpoint ):
475
482
aiplatform .init (request_metadata = [("foo" , "bar" )])
476
483
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
477
484
response = model .generate_content (
@@ -480,7 +487,7 @@ def test_additional_request_metadata(self):
480
487
)
481
488
assert response
482
489
483
- def test_compute_tokens_from_text (self ):
490
+ def test_compute_tokens_from_text (self , api_endpoint ):
484
491
model = generative_models .GenerativeModel (GEMINI_MODEL_NAME )
485
492
response = model .compute_tokens (["Why is sky blue?" , "Explain it like I'm 5." ])
486
493
assert len (response .tokens_info ) == 2
0 commit comments