Skip to content

Commit f978200

Browse files
sararob authored and copybara-github committed
feat: add Model Garden support to vertexai.preview.from_pretrained
PiperOrigin-RevId: 569576160
1 parent 1aab6fd commit f978200

File tree

8 files changed

+581
-70
lines changed

8 files changed

+581
-70
lines changed

tests/system/aiplatform/test_language_models.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
from google.cloud.aiplatform.compat.types import (
2525
job_state as gca_job_state,
2626
)
27+
import vertexai
2728
from tests.system.aiplatform import e2e_base
2829
from google.cloud.aiplatform.utils import gcs_utils
2930
from vertexai import language_models
30-
from vertexai.preview import language_models as preview_language_models
31+
from vertexai.preview import (
32+
language_models as preview_language_models,
33+
)
3134
from vertexai.preview.language_models import (
3235
ChatModel,
3336
InputOutputTextPair,
@@ -87,6 +90,24 @@ def test_text_generation_streaming(self):
8790
):
8891
assert response.text
8992

93+
def test_preview_text_embedding_top_level_from_pretrained(self):
94+
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
95+
96+
model = vertexai.preview.from_pretrained(
97+
foundation_model_name="google/text-bison@001"
98+
)
99+
100+
assert model.predict(
101+
"What is the best recipe for banana bread? Recipe:",
102+
max_output_tokens=128,
103+
temperature=0.0,
104+
top_p=1.0,
105+
top_k=5,
106+
stop_sequences=["# %%"],
107+
).text
108+
109+
assert isinstance(model, preview_language_models.TextEmbeddingModel)
110+
90111
def test_chat_on_chat_model(self):
91112
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
92113

tests/unit/aiplatform/test_language_models.py

+88
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
model as gca_model,
5959
)
6060

61+
import vertexai
6162
from vertexai.preview import (
6263
language_models as preview_language_models,
6364
)
@@ -2598,6 +2599,93 @@ def test_batch_prediction_for_text_embedding(self):
25982599
model_parameters={},
25992600
)
26002601

2602+
def test_text_generation_top_level_from_pretrained_preview(self):
2603+
"""Tests the text generation model."""
2604+
aiplatform.init(
2605+
project=_TEST_PROJECT,
2606+
location=_TEST_LOCATION,
2607+
)
2608+
with mock.patch.object(
2609+
target=model_garden_service_client.ModelGardenServiceClient,
2610+
attribute="get_publisher_model",
2611+
return_value=gca_publisher_model.PublisherModel(
2612+
_TEXT_BISON_PUBLISHER_MODEL_DICT
2613+
),
2614+
) as mock_get_publisher_model:
2615+
model = vertexai.preview.from_pretrained(
2616+
foundation_model_name="text-bison@001"
2617+
)
2618+
2619+
assert isinstance(model, preview_language_models.TextGenerationModel)
2620+
2621+
mock_get_publisher_model.assert_called_with(
2622+
name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY
2623+
)
2624+
assert mock_get_publisher_model.call_count == 1
2625+
2626+
assert (
2627+
model._model_resource_name
2628+
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001"
2629+
)
2630+
2631+
# Test that methods on TextGenerationModel still work as expected
2632+
gca_predict_response = gca_prediction_service.PredictResponse()
2633+
gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)
2634+
2635+
with mock.patch.object(
2636+
target=prediction_service_client.PredictionServiceClient,
2637+
attribute="predict",
2638+
return_value=gca_predict_response,
2639+
):
2640+
response = model.predict(
2641+
"What is the best recipe for banana bread? Recipe:",
2642+
max_output_tokens=128,
2643+
temperature=0.0,
2644+
top_p=1.0,
2645+
top_k=5,
2646+
)
2647+
2648+
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
2649+
assert (
2650+
response.raw_prediction_response.predictions[0]
2651+
== _TEST_TEXT_GENERATION_PREDICTION
2652+
)
2653+
assert (
2654+
response.safety_attributes["Violent"]
2655+
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
2656+
)
2657+
2658+
def test_text_embedding_top_level_from_pretrained_preview(self):
2659+
"""Tests the text embedding model."""
2660+
aiplatform.init(
2661+
project=_TEST_PROJECT,
2662+
location=_TEST_LOCATION,
2663+
)
2664+
with mock.patch.object(
2665+
target=model_garden_service_client.ModelGardenServiceClient,
2666+
attribute="get_publisher_model",
2667+
return_value=gca_publisher_model.PublisherModel(
2668+
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
2669+
),
2670+
) as mock_get_publisher_model:
2671+
model = vertexai.preview.from_pretrained(
2672+
foundation_model_name="textembedding-gecko@001"
2673+
)
2674+
2675+
assert isinstance(model, preview_language_models.TextEmbeddingModel)
2676+
2677+
assert (
2678+
model._endpoint_name
2679+
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001"
2680+
)
2681+
2682+
mock_get_publisher_model.assert_called_with(
2683+
name="publishers/google/models/textembedding-gecko@001",
2684+
retry=base._DEFAULT_RETRY,
2685+
)
2686+
2687+
assert mock_get_publisher_model.call_count == 1
2688+
26012689

26022690
# TODO (b/285946649): add more test coverage before public preview release
26032691
@pytest.mark.usefixtures("google_auth_mock")

tests/unit/aiplatform/test_vision_models.py

+50-12
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@
3939
from google.cloud.aiplatform.compat.types import (
4040
publisher_model as gca_publisher_model,
4141
)
42+
import vertexai
4243
from vertexai import vision_models as ga_vision_models
43-
from vertexai.preview import vision_models
44+
from vertexai.preview import (
45+
vision_models as preview_vision_models,
46+
)
4447

4548
from PIL import Image as PIL_Image
4649
import pytest
@@ -121,12 +124,12 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
121124

122125
def generate_image_from_file(
123126
width: int = 100, height: int = 100
124-
) -> vision_models.Image:
127+
) -> ga_vision_models.Image:
125128
with tempfile.TemporaryDirectory() as temp_dir:
126129
image_path = os.path.join(temp_dir, "image.png")
127130
pil_image = PIL_Image.new(mode="RGB", size=(width, height))
128131
pil_image.save(image_path, format="PNG")
129-
return vision_models.Image.load_from_file(image_path)
132+
return ga_vision_models.Image.load_from_file(image_path)
130133

131134

132135
@pytest.mark.usefixtures("google_auth_mock")
@@ -140,7 +143,7 @@ def setup_method(self):
140143
def teardown_method(self):
141144
initializer.global_pool.shutdown(wait=True)
142145

143-
def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
146+
def _get_image_generation_model(self) -> preview_vision_models.ImageGenerationModel:
144147
"""Gets the image generation model."""
145148
aiplatform.init(
146149
project=_TEST_PROJECT,
@@ -153,7 +156,7 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
153156
_IMAGE_GENERATION_PUBLISHER_MODEL_DICT
154157
),
155158
) as mock_get_publisher_model:
156-
model = vision_models.ImageGenerationModel.from_pretrained(
159+
model = preview_vision_models.ImageGenerationModel.from_pretrained(
157160
"imagegeneration@002"
158161
)
159162

@@ -164,13 +167,48 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
164167

165168
return model
166169

170+
def _get_preview_image_generation_model_top_level_from_pretrained(
171+
self,
172+
) -> preview_vision_models.ImageGenerationModel:
173+
"""Gets the image generation model from the top-level vertexai.preview.from_pretrained method."""
174+
aiplatform.init(
175+
project=_TEST_PROJECT,
176+
location=_TEST_LOCATION,
177+
)
178+
with mock.patch.object(
179+
target=model_garden_service_client.ModelGardenServiceClient,
180+
attribute="get_publisher_model",
181+
return_value=gca_publisher_model.PublisherModel(
182+
_IMAGE_GENERATION_PUBLISHER_MODEL_DICT
183+
),
184+
) as mock_get_publisher_model:
185+
model = vertexai.preview.from_pretrained(
186+
foundation_model_name="imagegeneration@002"
187+
)
188+
189+
mock_get_publisher_model.assert_called_with(
190+
name="publishers/google/models/imagegeneration@002",
191+
retry=base._DEFAULT_RETRY,
192+
)
193+
194+
assert mock_get_publisher_model.call_count == 1
195+
196+
return model
197+
167198
def test_from_pretrained(self):
168199
model = self._get_image_generation_model()
169200
assert (
170201
model._endpoint_name
171202
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002"
172203
)
173204

205+
def test_top_level_from_pretrained_preview(self):
206+
model = self._get_preview_image_generation_model_top_level_from_pretrained()
207+
assert (
208+
model._endpoint_name
209+
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002"
210+
)
211+
174212
def test_generate_images(self):
175213
"""Tests the image generation model."""
176214
model = self._get_image_generation_model()
@@ -238,7 +276,7 @@ def test_generate_images(self):
238276
with tempfile.TemporaryDirectory() as temp_dir:
239277
image_path = os.path.join(temp_dir, "image.png")
240278
image_response[0].save(location=image_path)
241-
image1 = vision_models.GeneratedImage.load_from_file(image_path)
279+
image1 = preview_vision_models.GeneratedImage.load_from_file(image_path)
242280
# assert image1._pil_image.size == (width, height)
243281
assert image1.generation_parameters
244282
assert image1.generation_parameters["prompt"] == prompt1
@@ -247,7 +285,7 @@ def test_generate_images(self):
247285
mask_path = os.path.join(temp_dir, "mask.png")
248286
mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
249287
mask_pil_image.save(mask_path, format="PNG")
250-
mask_image = vision_models.Image.load_from_file(mask_path)
288+
mask_image = preview_vision_models.Image.load_from_file(mask_path)
251289

252290
# Test generating image from base image
253291
with mock.patch.object(
@@ -408,7 +446,7 @@ def test_upscale_image_on_provided_image(self):
408446
assert image_upscale_parameters["mode"] == "upscale"
409447

410448
assert upscaled_image._image_bytes
411-
assert isinstance(upscaled_image, vision_models.GeneratedImage)
449+
assert isinstance(upscaled_image, preview_vision_models.GeneratedImage)
412450

413451
def test_upscale_image_raises_if_not_1024x1024(self):
414452
"""Tests image upscaling on generated images."""
@@ -457,7 +495,7 @@ def test_get_captions(self):
457495
image_path = os.path.join(temp_dir, "image.png")
458496
pil_image = PIL_Image.new(mode="RGB", size=(100, 100))
459497
pil_image.save(image_path, format="PNG")
460-
image = vision_models.Image.load_from_file(image_path)
498+
image = preview_vision_models.Image.load_from_file(image_path)
461499

462500
with mock.patch.object(
463501
target=prediction_service_client.PredictionServiceClient,
@@ -544,7 +582,7 @@ def test_image_embedding_model_with_only_image(self):
544582
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
545583
),
546584
) as mock_get_publisher_model:
547-
model = vision_models.MultiModalEmbeddingModel.from_pretrained(
585+
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
548586
"multimodalembedding@001"
549587
)
550588

@@ -583,7 +621,7 @@ def test_image_embedding_model_with_image_and_text(self):
583621
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
584622
),
585623
):
586-
model = vision_models.MultiModalEmbeddingModel.from_pretrained(
624+
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
587625
"multimodalembedding@001"
588626
)
589627

@@ -715,7 +753,7 @@ def test_get_captions(self):
715753
image_path = os.path.join(temp_dir, "image.png")
716754
pil_image = PIL_Image.new(mode="RGB", size=(100, 100))
717755
pil_image.save(image_path, format="PNG")
718-
image = vision_models.Image.load_from_file(image_path)
756+
image = preview_vision_models.Image.load_from_file(image_path)
719757

720758
with mock.patch.object(
721759
target=prediction_service_client.PredictionServiceClient,

0 commit comments

Comments (0)