Skip to content

Commit ebbd1bf

Browse files
speedstorm1 and copybara-github
authored and committed
chore: Parameterize generative model end-to-end test with prod and staging API endpoints, setup and fetch staging endpoint environment variable
PiperOrigin-RevId: 661454748
1 parent a076191 commit ebbd1bf

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

.kokoro/build.sh

+3
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,9 @@ export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json
3333
# Setup project id.
3434
export PROJECT_ID=$(cat "${KOKORO_GFILE_DIR}/project-id.json")
3535

36+
# Setup staging endpoint.
37+
export STAGING_ENDPOINT=$(cat "${KOKORO_KEYSTORE_DIR}/73713_vertexai-staging-endpoint")
38+
3639
# Remove old nox
3740
python3 -m pip uninstall --yes --quiet nox-automation
3841

.kokoro/continuous/common.cfg

+10
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,16 @@ gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/google-cloud-python"
1616
# Use the trampoline script to run in docker.
1717
build_file: "python-aiplatform/.kokoro/trampoline.sh"
1818

19+
# Fetch vertexai staging endpoint
20+
before_action {
21+
fetch_keystore {
22+
keystore_resource {
23+
keystore_config_id: 73713
24+
keyname: "vertexai-staging-endpoint"
25+
}
26+
}
27+
}
28+
1929
# Configure the docker image for kokoro-trampoline.
2030
env_vars: {
2131
key: "TRAMPOLINE_IMAGE"

tests/system/vertexai/test_generative_models.py

+26-19
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
"""System tests for generative models."""
1919

2020
import json
21+
import os
2122
import pytest
2223

2324
# Google imports
@@ -36,6 +37,9 @@
3637
GEMINI_15_MODEL_NAME = "gemini-1.5-pro-preview-0409"
3738
GEMINI_15_PRO_MODEL_NAME = "gemini-1.5-pro-001"
3839

40+
STAGING_API_ENDPOINT = os.getenv("STAGING_ENDPOINT")
41+
PROD_API_ENDPOINT = None
42+
3943

4044
# A dummy function for function calling
4145
def get_current_weather(location: str, unit: str = "centigrade"):
@@ -84,12 +88,14 @@ def get_current_weather(location: str, unit: str = "centigrade"):
8488
}
8589

8690

91+
@pytest.mark.parametrize("api_endpoint", [STAGING_API_ENDPOINT, PROD_API_ENDPOINT])
8792
class TestGenerativeModels(e2e_base.TestEndToEnd):
8893
"""System tests for generative models."""
8994

9095
_temp_prefix = "temp_generative_models_test_"
9196

92-
def setup_method(self):
97+
@pytest.fixture(scope="function", autouse=True)
98+
def setup_method(self, api_endpoint):
9399
super().setup_method()
94100
credentials, _ = auth.default(
95101
scopes=["https://www.googleapis.com/auth/cloud-platform"]
@@ -98,9 +104,10 @@ def setup_method(self):
98104
project=e2e_base._PROJECT,
99105
location=e2e_base._LOCATION,
100106
credentials=credentials,
107+
api_endpoint=api_endpoint,
101108
)
102109

103-
def test_generate_content_with_cached_content_from_text(self):
110+
def test_generate_content_with_cached_content_from_text(self, api_endpoint):
104111
cached_content = caching.CachedContent.create(
105112
model_name=GEMINI_15_PRO_MODEL_NAME,
106113
system_instruction="Please answer all the questions like a pirate.",
@@ -138,7 +145,7 @@ def test_generate_content_with_cached_content_from_text(self):
138145
finally:
139146
cached_content.delete()
140147

141-
def test_generate_content_from_text(self):
148+
def test_generate_content_from_text(self, api_endpoint):
142149
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
143150
response = model.generate_content(
144151
"Why is sky blue?",
@@ -147,15 +154,15 @@ def test_generate_content_from_text(self):
147154
assert response.text
148155

149156
@pytest.mark.asyncio
150-
async def test_generate_content_async(self):
157+
async def test_generate_content_async(self, api_endpoint):
151158
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
152159
response = await model.generate_content_async(
153160
"Why is sky blue?",
154161
generation_config=generative_models.GenerationConfig(temperature=0),
155162
)
156163
assert response.text
157164

158-
def test_generate_content_streaming(self):
165+
def test_generate_content_streaming(self, api_endpoint):
159166
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
160167
stream = model.generate_content(
161168
"Why is sky blue?",
@@ -170,7 +177,7 @@ def test_generate_content_streaming(self):
170177
)
171178

172179
@pytest.mark.asyncio
173-
async def test_generate_content_streaming_async(self):
180+
async def test_generate_content_streaming_async(self, api_endpoint):
174181
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
175182
async_stream = await model.generate_content_async(
176183
"Why is sky blue?",
@@ -184,7 +191,7 @@ async def test_generate_content_streaming_async(self):
184191
is generative_models.FinishReason.STOP
185192
)
186193

187-
def test_generate_content_with_parameters(self):
194+
def test_generate_content_with_parameters(self, api_endpoint):
188195
model = generative_models.GenerativeModel(
189196
GEMINI_MODEL_NAME,
190197
system_instruction=[
@@ -211,7 +218,7 @@ def test_generate_content_with_parameters(self):
211218
)
212219
assert response.text
213220

214-
def test_generate_content_with_gemini_15_parameters(self):
221+
def test_generate_content_with_gemini_15_parameters(self, api_endpoint):
215222
model = generative_models.GenerativeModel(GEMINI_15_MODEL_NAME)
216223
response = model.generate_content(
217224
contents="Why is sky blue? Respond in JSON Format.",
@@ -237,7 +244,7 @@ def test_generate_content_with_gemini_15_parameters(self):
237244
assert response.text
238245
assert json.loads(response.text)
239246

240-
def test_generate_content_from_list_of_content_dict(self):
247+
def test_generate_content_from_list_of_content_dict(self, api_endpoint):
241248
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
242249
response = model.generate_content(
243250
contents=[{"role": "user", "parts": [{"text": "Why is sky blue?"}]}],
@@ -248,7 +255,7 @@ def test_generate_content_from_list_of_content_dict(self):
248255
@pytest.mark.skip(
249256
reason="Breaking change in the gemini-pro-vision model. See b/315803556#comment3"
250257
)
251-
def test_generate_content_from_remote_image(self):
258+
def test_generate_content_from_remote_image(self, api_endpoint):
252259
vision_model = generative_models.GenerativeModel(GEMINI_VISION_MODEL_NAME)
253260
image_part = generative_models.Part.from_uri(
254261
uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
@@ -261,7 +268,7 @@ def test_generate_content_from_remote_image(self):
261268
assert response.text
262269
assert "cat" in response.text
263270

264-
def test_generate_content_from_text_and_remote_image(self):
271+
def test_generate_content_from_text_and_remote_image(self, api_endpoint):
265272
vision_model = generative_models.GenerativeModel(GEMINI_VISION_MODEL_NAME)
266273
image_part = generative_models.Part.from_uri(
267274
uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
@@ -274,7 +281,7 @@ def test_generate_content_from_text_and_remote_image(self):
274281
assert response.text
275282
assert "cat" in response.text
276283

277-
def test_generate_content_from_text_and_remote_video(self):
284+
def test_generate_content_from_text_and_remote_video(self, api_endpoint):
278285
vision_model = generative_models.GenerativeModel(GEMINI_VISION_MODEL_NAME)
279286
video_part = generative_models.Part.from_uri(
280287
uri="gs://cloud-samples-data/video/animals.mp4",
@@ -287,7 +294,7 @@ def test_generate_content_from_text_and_remote_video(self):
287294
assert response.text
288295
assert "Zootopia" in response.text
289296

290-
def test_grounding_google_search_retriever(self):
297+
def test_grounding_google_search_retriever(self, api_endpoint):
291298
model = preview_generative_models.GenerativeModel(GEMINI_MODEL_NAME)
292299
google_search_retriever_tool = (
293300
preview_generative_models.Tool.from_google_search_retrieval(
@@ -309,7 +316,7 @@ def test_grounding_google_search_retriever(self):
309316

310317
# Chat
311318

312-
def test_send_message_from_text(self):
319+
def test_send_message_from_text(self, api_endpoint):
313320
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
314321
chat = model.start_chat()
315322
response1 = chat.send_message(
@@ -326,7 +333,7 @@ def test_send_message_from_text(self):
326333
assert response2.text
327334
assert len(chat.history) == 4
328335

329-
def test_chat_function_calling(self):
336+
def test_chat_function_calling(self, api_endpoint):
330337
get_current_weather_func = generative_models.FunctionDeclaration(
331338
name="get_current_weather",
332339
description="Get the current weather in a given location",
@@ -360,7 +367,7 @@ def test_chat_function_calling(self):
360367
)
361368
assert response2.text
362369

363-
def test_generate_content_function_calling(self):
370+
def test_generate_content_function_calling(self, api_endpoint):
364371
get_current_weather_func = generative_models.FunctionDeclaration(
365372
name="get_current_weather",
366373
description="Get the current weather in a given location",
@@ -440,7 +447,7 @@ def test_generate_content_function_calling(self):
440447

441448
assert summary
442449

443-
def test_chat_automatic_function_calling(self):
450+
def test_chat_automatic_function_calling(self, api_endpoint):
444451
get_current_weather_func = generative_models.FunctionDeclaration.from_func(
445452
get_current_weather
446453
)
@@ -471,7 +478,7 @@ def test_chat_automatic_function_calling(self):
471478
assert chat.history[-2].parts[0].function_response
472479
assert chat.history[-2].parts[0].function_response.name == "get_current_weather"
473480

474-
def test_additional_request_metadata(self):
481+
def test_additional_request_metadata(self, api_endpoint):
475482
aiplatform.init(request_metadata=[("foo", "bar")])
476483
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
477484
response = model.generate_content(
@@ -480,7 +487,7 @@ def test_additional_request_metadata(self):
480487
)
481488
assert response
482489

483-
def test_compute_tokens_from_text(self):
490+
def test_compute_tokens_from_text(self, api_endpoint):
484491
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
485492
response = model.compute_tokens(["Why is sky blue?", "Explain it like I'm 5."])
486493
assert len(response.tokens_info) == 2

0 commit comments

Comments (0)