Skip to content

Commit de55173

Browse files
speedstorm1 authored and copybara-github committed
chore: Parameterize tokenization end-to-end test with prod and staging API endpoints
PiperOrigin-RevId: 667657240
1 parent 71464e7 commit de55173

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

tests/system/vertexai/test_tokenization.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414
#
1515

16+
import os
1617
import pytest
1718
import nltk
19+
1820
from nltk.corpus import udhr
1921
from google.cloud import aiplatform
2022
from vertexai.preview.tokenization import (
@@ -38,27 +40,46 @@
3840
for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
3941
]
4042

# Names of the environment variables that select which Vertex AI API
# endpoint the parameterized test suite targets.  The staging name is
# resolved via os.getenv at setup time; the prod name selects the SDK's
# default (production) endpoint.
STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
PROD_API_ENDPOINT = "PROD_ENDPOINT"
45+
4146

47+
@pytest.mark.parametrize(
48+
"api_endpoint_env_name", [STAGING_API_ENDPOINT, PROD_API_ENDPOINT]
49+
)
4250
class TestTokenization(e2e_base.TestEndToEnd):
51+
"""System tests for tokenization."""
4352

4453
_temp_prefix = "temp_tokenization_test_"
4554

46-
def setup_method(self):
55+
@pytest.fixture(scope="function", autouse=True)
56+
def setup_method(self, api_endpoint_env_name):
4757
super().setup_method()
4858
credentials, _ = auth.default(
4959
scopes=["https://www.googleapis.com/auth/cloud-platform"]
5060
)
61+
if api_endpoint_env_name == STAGING_API_ENDPOINT:
62+
api_endpoint = os.getenv(api_endpoint_env_name)
63+
else:
64+
api_endpoint = None
5165
aiplatform.init(
5266
project=e2e_base._PROJECT,
5367
location=e2e_base._LOCATION,
5468
credentials=credentials,
69+
api_endpoint=api_endpoint,
5570
)
5671

5772
@pytest.mark.parametrize(
5873
"model_name, corpus_name, corpus_lib",
5974
_MODEL_CORPUS_PARAMS,
6075
)
61-
def test_count_tokens_local(self, model_name, corpus_name, corpus_lib):
76+
def test_count_tokens_local(
77+
self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
78+
):
79+
# The Gemini 1.5 flash model requires the model version
80+
# number suffix (001) in staging only
81+
if api_endpoint_env_name == STAGING_API_ENDPOINT:
82+
model_name = model_name + "-001"
6283
tokenizer = get_tokenizer_for_model(model_name)
6384
model = GenerativeModel(model_name)
6485
nltk.download(corpus_name, quiet=True)
@@ -72,7 +93,13 @@ def test_count_tokens_local(self, model_name, corpus_name, corpus_lib):
7293
"model_name, corpus_name, corpus_lib",
7394
_MODEL_CORPUS_PARAMS,
7495
)
75-
def test_compute_tokens(self, model_name, corpus_name, corpus_lib):
96+
def test_compute_tokens(
97+
self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
98+
):
99+
# The Gemini 1.5 flash model requires the model version
100+
# number suffix (001) in staging only
101+
if api_endpoint_env_name == STAGING_API_ENDPOINT:
102+
model_name = model_name + "-001"
76103
tokenizer = get_tokenizer_for_model(model_name)
77104
model = GenerativeModel(model_name)
78105
nltk.download(corpus_name, quiet=True)

0 commit comments

Comments (0)