39
39
from google .cloud .aiplatform .compat .types import (
40
40
publisher_model as gca_publisher_model ,
41
41
)
42
+ import vertexai
42
43
from vertexai import vision_models as ga_vision_models
43
- from vertexai .preview import vision_models
44
+ from vertexai .preview import (
45
+ vision_models as preview_vision_models ,
46
+ )
44
47
45
48
from PIL import Image as PIL_Image
46
49
import pytest
@@ -121,12 +124,12 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
121
124
122
125
def generate_image_from_file (
123
126
width : int = 100 , height : int = 100
124
- ) -> vision_models .Image :
127
+ ) -> ga_vision_models .Image :
125
128
with tempfile .TemporaryDirectory () as temp_dir :
126
129
image_path = os .path .join (temp_dir , "image.png" )
127
130
pil_image = PIL_Image .new (mode = "RGB" , size = (width , height ))
128
131
pil_image .save (image_path , format = "PNG" )
129
- return vision_models .Image .load_from_file (image_path )
132
+ return ga_vision_models .Image .load_from_file (image_path )
130
133
131
134
132
135
@pytest .mark .usefixtures ("google_auth_mock" )
@@ -140,7 +143,7 @@ def setup_method(self):
140
143
def teardown_method (self ):
141
144
initializer .global_pool .shutdown (wait = True )
142
145
143
- def _get_image_generation_model (self ) -> vision_models .ImageGenerationModel :
146
+ def _get_image_generation_model (self ) -> preview_vision_models .ImageGenerationModel :
144
147
"""Gets the image generation model."""
145
148
aiplatform .init (
146
149
project = _TEST_PROJECT ,
@@ -153,7 +156,7 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
153
156
_IMAGE_GENERATION_PUBLISHER_MODEL_DICT
154
157
),
155
158
) as mock_get_publisher_model :
156
- model = vision_models .ImageGenerationModel .from_pretrained (
159
+ model = preview_vision_models .ImageGenerationModel .from_pretrained (
157
160
"imagegeneration@002"
158
161
)
159
162
@@ -164,13 +167,48 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
164
167
165
168
return model
166
169
170
+ def _get_preview_image_generation_model_top_level_from_pretrained (
171
+ self ,
172
+ ) -> preview_vision_models .ImageGenerationModel :
173
+ """Gets the image generation model from the top-level vertexai.preview.from_pretrained method."""
174
+ aiplatform .init (
175
+ project = _TEST_PROJECT ,
176
+ location = _TEST_LOCATION ,
177
+ )
178
+ with mock .patch .object (
179
+ target = model_garden_service_client .ModelGardenServiceClient ,
180
+ attribute = "get_publisher_model" ,
181
+ return_value = gca_publisher_model .PublisherModel (
182
+ _IMAGE_GENERATION_PUBLISHER_MODEL_DICT
183
+ ),
184
+ ) as mock_get_publisher_model :
185
+ model = vertexai .preview .from_pretrained (
186
+ foundation_model_name = "imagegeneration@002"
187
+ )
188
+
189
+ mock_get_publisher_model .assert_called_with (
190
+ name = "publishers/google/models/imagegeneration@002" ,
191
+ retry = base ._DEFAULT_RETRY ,
192
+ )
193
+
194
+ assert mock_get_publisher_model .call_count == 1
195
+
196
+ return model
197
+
167
198
def test_from_pretrained (self ):
168
199
model = self ._get_image_generation_model ()
169
200
assert (
170
201
model ._endpoint_name
171
202
== f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/imagegeneration@002"
172
203
)
173
204
205
+ def test_top_level_from_pretrained_preview (self ):
206
+ model = self ._get_preview_image_generation_model_top_level_from_pretrained ()
207
+ assert (
208
+ model ._endpoint_name
209
+ == f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/imagegeneration@002"
210
+ )
211
+
174
212
def test_generate_images (self ):
175
213
"""Tests the image generation model."""
176
214
model = self ._get_image_generation_model ()
@@ -238,7 +276,7 @@ def test_generate_images(self):
238
276
with tempfile .TemporaryDirectory () as temp_dir :
239
277
image_path = os .path .join (temp_dir , "image.png" )
240
278
image_response [0 ].save (location = image_path )
241
- image1 = vision_models .GeneratedImage .load_from_file (image_path )
279
+ image1 = preview_vision_models .GeneratedImage .load_from_file (image_path )
242
280
# assert image1._pil_image.size == (width, height)
243
281
assert image1 .generation_parameters
244
282
assert image1 .generation_parameters ["prompt" ] == prompt1
@@ -247,7 +285,7 @@ def test_generate_images(self):
247
285
mask_path = os .path .join (temp_dir , "mask.png" )
248
286
mask_pil_image = PIL_Image .new (mode = "RGB" , size = image1 ._pil_image .size )
249
287
mask_pil_image .save (mask_path , format = "PNG" )
250
- mask_image = vision_models .Image .load_from_file (mask_path )
288
+ mask_image = preview_vision_models .Image .load_from_file (mask_path )
251
289
252
290
# Test generating image from base image
253
291
with mock .patch .object (
@@ -408,7 +446,7 @@ def test_upscale_image_on_provided_image(self):
408
446
assert image_upscale_parameters ["mode" ] == "upscale"
409
447
410
448
assert upscaled_image ._image_bytes
411
- assert isinstance (upscaled_image , vision_models .GeneratedImage )
449
+ assert isinstance (upscaled_image , preview_vision_models .GeneratedImage )
412
450
413
451
def test_upscale_image_raises_if_not_1024x1024 (self ):
414
452
"""Tests image upscaling on generated images."""
@@ -457,7 +495,7 @@ def test_get_captions(self):
457
495
image_path = os .path .join (temp_dir , "image.png" )
458
496
pil_image = PIL_Image .new (mode = "RGB" , size = (100 , 100 ))
459
497
pil_image .save (image_path , format = "PNG" )
460
- image = vision_models .Image .load_from_file (image_path )
498
+ image = preview_vision_models .Image .load_from_file (image_path )
461
499
462
500
with mock .patch .object (
463
501
target = prediction_service_client .PredictionServiceClient ,
@@ -544,7 +582,7 @@ def test_image_embedding_model_with_only_image(self):
544
582
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
545
583
),
546
584
) as mock_get_publisher_model :
547
- model = vision_models .MultiModalEmbeddingModel .from_pretrained (
585
+ model = preview_vision_models .MultiModalEmbeddingModel .from_pretrained (
548
586
"multimodalembedding@001"
549
587
)
550
588
@@ -583,7 +621,7 @@ def test_image_embedding_model_with_image_and_text(self):
583
621
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
584
622
),
585
623
):
586
- model = vision_models .MultiModalEmbeddingModel .from_pretrained (
624
+ model = preview_vision_models .MultiModalEmbeddingModel .from_pretrained (
587
625
"multimodalembedding@001"
588
626
)
589
627
@@ -715,7 +753,7 @@ def test_get_captions(self):
715
753
image_path = os .path .join (temp_dir , "image.png" )
716
754
pil_image = PIL_Image .new (mode = "RGB" , size = (100 , 100 ))
717
755
pil_image .save (image_path , format = "PNG" )
718
- image = vision_models .Image .load_from_file (image_path )
756
+ image = preview_vision_models .Image .load_from_file (image_path )
719
757
720
758
with mock .patch .object (
721
759
target = prediction_service_client .PredictionServiceClient ,
0 commit comments