|
36 | 36 |
|
37 | 37 |
|
38 | 38 | _TEST_PROJECT = "test-project"
|
| 39 | +_TEST_PROJECT2 = "test-project2" |
39 | 40 | _TEST_LOCATION = "us-central1"
|
| 41 | +_TEST_LOCATION2 = "europe-west4" |
40 | 42 |
|
41 | 43 |
|
42 | 44 | _RESPONSE_TEXT_PART_STRUCT = {
|
@@ -283,6 +285,50 @@ def setup_method(self):
|
283 | 285 | def teardown_method(self):
|
284 | 286 | initializer.global_pool.shutdown(wait=True)
|
285 | 287 |
|
| 288 | + @mock.patch.object( |
| 289 | + target=prediction_service.PredictionServiceClient, |
| 290 | + attribute="generate_content", |
| 291 | + new=mock_generate_content, |
| 292 | + ) |
| 293 | + @pytest.mark.parametrize( |
| 294 | + "generative_models", |
| 295 | + [generative_models, preview_generative_models], |
| 296 | + ) |
| 297 | + def test_generative_model_constructor_model_name( |
| 298 | + self, generative_models: generative_models |
| 299 | + ): |
| 300 | + project_location_prefix = ( |
| 301 | + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/" |
| 302 | + ) |
| 303 | + |
| 304 | + model_name1 = "gemini-pro" |
| 305 | + model1 = generative_models.GenerativeModel(model_name1) |
| 306 | + assert ( |
| 307 | + model1._prediction_resource_name |
| 308 | + == project_location_prefix + "publishers/google/models/" + model_name1 |
| 309 | + ) |
| 310 | + |
| 311 | + model_name2 = "models/gemini-pro" |
| 312 | + model2 = generative_models.GenerativeModel(model_name2) |
| 313 | + assert ( |
| 314 | + model2._prediction_resource_name |
| 315 | + == project_location_prefix + "publishers/google/" + model_name2 |
| 316 | + ) |
| 317 | + |
| 318 | + model_name3 = "publishers/some_publisher/models/some_model" |
| 319 | + model3 = generative_models.GenerativeModel(model_name3) |
| 320 | + assert model3._prediction_resource_name == project_location_prefix + model_name3 |
| 321 | + |
| 322 | + model_name4 = ( |
| 323 | + f"projects/{_TEST_PROJECT2}/locations/{_TEST_LOCATION2}/endpoints/endpoint1" |
| 324 | + ) |
| 325 | + model4 = generative_models.GenerativeModel(model_name4) |
| 326 | + assert model4._prediction_resource_name == model_name4 |
| 327 | + assert _TEST_LOCATION2 in model4._prediction_client._api_endpoint |
| 328 | + |
| 329 | + with pytest.raises(ValueError): |
| 330 | + generative_models.GenerativeModel("foo/bar/models/gemini-pro") |
| 331 | + |
286 | 332 | @mock.patch.object(
|
287 | 333 | target=prediction_service.PredictionServiceClient,
|
288 | 334 | attribute="generate_content",
|
|
0 commit comments