23
23
import os
24
24
import tempfile
25
25
from typing import Any , Dict
26
+ import unittest
26
27
from unittest import mock
27
28
28
29
from google .cloud import aiplatform
@@ -175,7 +176,9 @@ def test_generate_images(self):
175
176
model = self ._get_image_generation_model ()
176
177
177
178
width = 1024
178
- height = 768
179
+ # TODO(b/295946075) The service stopped supporting image sizes.
180
+ # height = 768
181
+ height = 1024
179
182
number_of_images = 4
180
183
seed = 1
181
184
guidance_scale = 15
@@ -200,8 +203,9 @@ def test_generate_images(self):
200
203
# Optional:
201
204
negative_prompt = negative_prompt1 ,
202
205
number_of_images = number_of_images ,
203
- width = width ,
204
- height = height ,
206
+ # TODO(b/295946075) The service stopped supporting image sizes.
207
+ # width=width,
208
+ # height=height,
205
209
seed = seed ,
206
210
guidance_scale = guidance_scale ,
207
211
)
@@ -210,8 +214,9 @@ def test_generate_images(self):
210
214
actual_instance = predict_kwargs ["instances" ][0 ]
211
215
assert actual_instance ["prompt" ] == prompt1
212
216
assert actual_instance ["negativePrompt" ] == negative_prompt1
213
- assert actual_parameters ["sampleImageSize" ] == str (max (width , height ))
214
- assert actual_parameters ["aspectRatio" ] == f"{ width } :{ height } "
217
+ # TODO(b/295946075) The service stopped supporting image sizes.
218
+ # assert actual_parameters["sampleImageSize"] == str(max(width, height))
219
+ # assert actual_parameters["aspectRatio"] == f"{width}:{height}"
215
220
assert actual_parameters ["seed" ] == seed
216
221
assert actual_parameters ["guidanceScale" ] == guidance_scale
217
222
@@ -221,8 +226,9 @@ def test_generate_images(self):
221
226
assert image .generation_parameters
222
227
assert image .generation_parameters ["prompt" ] == prompt1
223
228
assert image .generation_parameters ["negative_prompt" ] == negative_prompt1
224
- assert image .generation_parameters ["width" ] == width
225
- assert image .generation_parameters ["height" ] == height
229
+ # TODO(b/295946075) The service stopped supporting image sizes.
230
+ # assert image.generation_parameters["width"] == width
231
+ # assert image.generation_parameters["height"] == height
226
232
assert image .generation_parameters ["seed" ] == seed
227
233
assert image .generation_parameters ["guidance_scale" ] == guidance_scale
228
234
assert image .generation_parameters ["index_of_image_in_batch" ] == idx
@@ -233,13 +239,13 @@ def test_generate_images(self):
233
239
image_path = os .path .join (temp_dir , "image.png" )
234
240
image_response [0 ].save (location = image_path )
235
241
image1 = vision_models .GeneratedImage .load_from_file (image_path )
236
- assert image1 ._pil_image .size == (width , height )
242
+ # assert image1._pil_image.size == (width, height)
237
243
assert image1 .generation_parameters
238
244
assert image1 .generation_parameters ["prompt" ] == prompt1
239
245
240
246
# Preparing mask
241
247
mask_path = os .path .join (temp_dir , "mask.png" )
242
- mask_pil_image = PIL_Image .new (mode = "RGB" , size = ( width , height ) )
248
+ mask_pil_image = PIL_Image .new (mode = "RGB" , size = image1 . _pil_image . size )
243
249
mask_pil_image .save (mask_path , format = "PNG" )
244
250
mask_image = vision_models .Image .load_from_file (mask_path )
245
251
@@ -273,6 +279,7 @@ def test_generate_images(self):
273
279
assert image .generation_parameters ["base_image_hash" ]
274
280
assert image .generation_parameters ["mask_hash" ]
275
281
282
+ @unittest .skip (reason = "b/295946075 The service stopped supporting image sizes." )
276
283
def test_generate_images_requests_square_images_by_default (self ):
277
284
"""Tests that the model class generates square image by default."""
278
285
model = self ._get_image_generation_model ()
0 commit comments