Skip to content

Commit bf33fb3

Browse files
Ark-kun and copybara-github
authored and committed
fix: GenAI - Fixed the GenerativeModel's handling of tuned models from different region
PiperOrigin-RevId: 622336161
1 parent 0f0b677 commit bf33fb3

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

tests/unit/vertexai/test_generative_models.py

+46
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,9 @@
3636

3737

3838
_TEST_PROJECT = "test-project"
39+
_TEST_PROJECT2 = "test-project2"
3940
_TEST_LOCATION = "us-central1"
41+
_TEST_LOCATION2 = "europe-west4"
4042

4143

4244
_RESPONSE_TEXT_PART_STRUCT = {
@@ -283,6 +285,50 @@ def setup_method(self):
283285
def teardown_method(self):
284286
initializer.global_pool.shutdown(wait=True)
285287

288+
@mock.patch.object(
    target=prediction_service.PredictionServiceClient,
    attribute="generate_content",
    new=mock_generate_content,
)
@pytest.mark.parametrize(
    "generative_models",
    [generative_models, preview_generative_models],
)
def test_generative_model_constructor_model_name(
    self, generative_models: generative_models
):
    """Verify how ``GenerativeModel`` turns ``model_name`` into a resource name.

    Covers four accepted spellings (bare model ID, ``models/...``,
    ``publishers/...`` and a full ``projects/...`` resource name) and one
    rejected spelling that must raise ``ValueError``.
    """
    default_parent = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"

    # (model_name, path segment expected between the parent and the name).
    # All three are expected to resolve under the _TEST_PROJECT /
    # _TEST_LOCATION parent.
    expandable_names = (
        ("gemini-pro", "publishers/google/models/"),
        ("models/gemini-pro", "publishers/google/"),
        ("publishers/some_publisher/models/some_model", ""),
    )
    for name, infix in expandable_names:
        model = generative_models.GenerativeModel(name)
        assert model._prediction_resource_name == default_parent + infix + name

    # A full resource name in another project/location is kept verbatim, and
    # the prediction client's endpoint must mention that other location.
    foreign_endpoint = (
        f"projects/{_TEST_PROJECT2}/locations/{_TEST_LOCATION2}/endpoints/endpoint1"
    )
    foreign_model = generative_models.GenerativeModel(foreign_endpoint)
    assert foreign_model._prediction_resource_name == foreign_endpoint
    assert _TEST_LOCATION2 in foreign_model._prediction_client._api_endpoint

    # A name matching none of the accepted forms is rejected.
    with pytest.raises(ValueError):
        generative_models.GenerativeModel("foo/bar/models/gemini-pro")
331+
286332
@mock.patch.object(
287333
target=prediction_service.PredictionServiceClient,
288334
attribute="generate_content",

vertexai/generative_models/_generative_models.py

+13 -1
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
)
3333

3434
from google.cloud.aiplatform import initializer as aiplatform_initializer
35+
from google.cloud.aiplatform import utils as aiplatform_utils
3536
from google.cloud.aiplatform_v1beta1 import types as aiplatform_types
3637
from google.cloud.aiplatform_v1beta1.services import prediction_service
3738
from google.cloud.aiplatform_v1beta1.types import (
@@ -169,11 +170,20 @@ def __init__(
169170
prediction_resource_name = (
170171
f"projects/{project}/locations/{location}/{model_name}"
171172
)
172-
else:
173+
elif model_name.startswith("projects/"):
173174
prediction_resource_name = model_name
175+
else:
176+
raise ValueError(
177+
"model_name must be either a Model Garden model ID or a full resource name."
178+
)
179+
180+
location = aiplatform_utils.extract_project_and_location_from_parent(
181+
prediction_resource_name
182+
)["location"]
174183

175184
self._model_name = model_name
176185
self._prediction_resource_name = prediction_resource_name
186+
self._location = location
177187
self._generation_config = generation_config
178188
self._safety_settings = safety_settings
179189
self._tools = tools
@@ -197,6 +207,7 @@ def _prediction_client(self) -> prediction_service.PredictionServiceClient:
197207
self._prediction_client_value = (
198208
aiplatform_initializer.global_config.create_client(
199209
client_class=prediction_service.PredictionServiceClient,
210+
location_override=self._location,
200211
prediction_client=True,
201212
)
202213
)
@@ -211,6 +222,7 @@ def _prediction_async_client(
211222
self._prediction_async_client_value = (
212223
aiplatform_initializer.global_config.create_client(
213224
client_class=prediction_service.PredictionServiceAsyncClient,
225+
location_override=self._location,
214226
prediction_client=True,
215227
)
216228
)

0 commit comments

Comments
 (0)