@@ -13,8 +13,10 @@
 # limitations under the License.
 #
 
+import os
 import pytest
 import nltk
+
 from nltk.corpus import udhr
 from google.cloud import aiplatform
 from vertexai.preview.tokenization import (
@@ -38,27 +40,46 @@
     for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
 ]
 
+STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
+PROD_API_ENDPOINT = "PROD_ENDPOINT"
+
 
+@pytest.mark.parametrize(
+    "api_endpoint_env_name", [STAGING_API_ENDPOINT, PROD_API_ENDPOINT]
+)
 class TestTokenization(e2e_base.TestEndToEnd):
+    """System tests for tokenization."""
 
     _temp_prefix = "temp_tokenization_test_"
 
-    def setup_method(self):
+    @pytest.fixture(scope="function", autouse=True)
+    def setup_method(self, api_endpoint_env_name):
         super().setup_method()
         credentials, _ = auth.default(
             scopes=["https://www.googleapis.com/auth/cloud-platform"]
         )
+        if api_endpoint_env_name == STAGING_API_ENDPOINT:
+            api_endpoint = os.getenv(api_endpoint_env_name)
+        else:
+            api_endpoint = None
         aiplatform.init(
             project=e2e_base._PROJECT,
             location=e2e_base._LOCATION,
             credentials=credentials,
+            api_endpoint=api_endpoint,
         )
 
     @pytest.mark.parametrize(
         "model_name, corpus_name, corpus_lib",
         _MODEL_CORPUS_PARAMS,
     )
-    def test_count_tokens_local(self, model_name, corpus_name, corpus_lib):
+    def test_count_tokens_local(
+        self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+    ):
+        # The Gemini 1.5 flash model requires the model version
+        # number suffix (001) in staging only
+        if api_endpoint_env_name == STAGING_API_ENDPOINT:
+            model_name = model_name + "-001"
         tokenizer = get_tokenizer_for_model(model_name)
         model = GenerativeModel(model_name)
         nltk.download(corpus_name, quiet=True)
@@ -72,7 +93,13 @@ def test_count_tokens_local(self, model_name, corpus_name, corpus_lib):
         "model_name, corpus_name, corpus_lib",
         _MODEL_CORPUS_PARAMS,
     )
-    def test_compute_tokens(self, model_name, corpus_name, corpus_lib):
+    def test_compute_tokens(
+        self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+    ):
+        # The Gemini 1.5 flash model requires the model version
+        # number suffix (001) in staging only
+        if api_endpoint_env_name == STAGING_API_ENDPOINT:
+            model_name = model_name + "-001"
         tokenizer = get_tokenizer_for_model(model_name)
         model = GenerativeModel(model_name)
         nltk.download(corpus_name, quiet=True)
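The change hinges on a class-level `pytest.mark.parametrize` feeding an autouse fixture, so every test in the class runs once against production (default endpoint) and once against the staging endpoint read from the `STAGING_ENDPOINT` environment variable. Below is a minimal, self-contained sketch of that pattern; the class and fixture names are illustrative only and not part of the change.

```python
# Illustrative sketch only: mirrors the parametrize + autouse-fixture pattern
# used in the diff, without the Vertex AI / e2e_base dependencies.
import os
import pytest

STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
PROD_API_ENDPOINT = "PROD_ENDPOINT"


@pytest.mark.parametrize(
    "api_endpoint_env_name", [STAGING_API_ENDPOINT, PROD_API_ENDPOINT]
)
class TestEndpointSelection:
    @pytest.fixture(scope="function", autouse=True)
    def _select_endpoint(self, api_endpoint_env_name):
        # Staging resolves its endpoint from the STAGING_ENDPOINT env var;
        # production falls back to the SDK default (None).
        self.api_endpoint = (
            os.getenv(api_endpoint_env_name)
            if api_endpoint_env_name == STAGING_API_ENDPOINT
            else None
        )

    def test_endpoint_resolved(self, api_endpoint_env_name):
        # Each test runs twice, once per parametrized endpoint name.
        if api_endpoint_env_name == PROD_API_ENDPOINT:
            assert self.api_endpoint is None
```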