Skip to content

Commit d689331

Browse files
Zhenyi Qicopybara-github
Zhenyi Qi
authored andcommitted
fix: ensure model starts with publishers/ when users provide resource path from models/
PiperOrigin-RevId: 640914707
1 parent bd4c09c commit d689331

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

tests/unit/vertexai/test_generative_models.py

+4
Original file line numberDiff line numberDiff line change
@@ -422,24 +422,28 @@ def test_generative_model_constructor_model_name(
422422
model1._prediction_resource_name
423423
== project_location_prefix + "publishers/google/models/" + model_name1
424424
)
425+
assert model1._model_name == "publishers/google/models/gemini-pro"
425426

426427
model_name2 = "models/gemini-pro"
427428
model2 = generative_models.GenerativeModel(model_name2)
428429
assert (
429430
model2._prediction_resource_name
430431
== project_location_prefix + "publishers/google/" + model_name2
431432
)
433+
assert model2._model_name == "publishers/google/models/gemini-pro"
432434

433435
model_name3 = "publishers/some_publisher/models/some_model"
434436
model3 = generative_models.GenerativeModel(model_name3)
435437
assert model3._prediction_resource_name == project_location_prefix + model_name3
438+
assert model3._model_name == "publishers/some_publisher/models/some_model"
436439

437440
model_name4 = (
438441
f"projects/{_TEST_PROJECT2}/locations/{_TEST_LOCATION2}/endpoints/endpoint1"
439442
)
440443
model4 = generative_models.GenerativeModel(model_name4)
441444
assert model4._prediction_resource_name == model_name4
442445
assert _TEST_LOCATION2 in model4._prediction_client._api_endpoint
446+
assert model4._model_name == model_name4
443447

444448
with pytest.raises(ValueError):
445449
generative_models.GenerativeModel("foo/bar/models/gemini-pro")

vertexai/generative_models/_generative_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _reconcile_model_name(model_name: str, project: str, location: str) -> str:
110110
if "/" not in model_name:
111111
return f"publishers/google/models/{model_name}"
112112
elif model_name.startswith("models/"):
113-
return f"projects/{project}/locations/{location}/publishers/google/{model_name}"
113+
return f"publishers/google/{model_name}"
114114
elif model_name.startswith("publishers/") or model_name.startswith("projects/"):
115115
return model_name
116116
else:

0 commit comments

Comments
 (0)