Skip to content

Commit 52897e6

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LVM - Removed the width and height parameters from ImageGenerationModel.generate_images since the service has dropped support for image sizes and aspect ratios
PiperOrigin-RevId: 558246815
1 parent ce60cf7 commit 52897e6

File tree

3 files changed

+37
-27
lines changed

3 files changed

+37
-27
lines changed

tests/system/aiplatform/test_vision_models.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ def test_image_generation_model_generate_images(self):
9191
"imagegeneration@001"
9292
)
9393

94-
width = 1024
95-
height = 768
94+
# TODO(b/295946075): The service stopped supporting image sizes.
95+
# width = 1024
96+
# height = 768
9697
number_of_images = 4
9798
seed = 1
9899
guidance_scale = 15
@@ -104,20 +105,23 @@ def test_image_generation_model_generate_images(self):
104105
# Optional:
105106
negative_prompt=negative_prompt1,
106107
number_of_images=number_of_images,
107-
width=width,
108-
height=height,
108+
# TODO(b/295946075): The service stopped supporting image sizes.
109+
# width=width,
110+
# height=height,
109111
seed=seed,
110112
guidance_scale=guidance_scale,
111113
)
112114

113115
assert len(image_response.images) == number_of_images
114116
for idx, image in enumerate(image_response):
115-
assert image._pil_image.size == (width, height)
117+
# TODO(b/295946075): The service stopped supporting image sizes.
118+
# assert image._pil_image.size == (width, height)
116119
assert image.generation_parameters
117120
assert image.generation_parameters["prompt"] == prompt1
118121
assert image.generation_parameters["negative_prompt"] == negative_prompt1
119-
assert image.generation_parameters["width"] == width
120-
assert image.generation_parameters["height"] == height
122+
# TODO(b/295946075): The service stopped supporting image sizes.
123+
# assert image.generation_parameters["width"] == width
124+
# assert image.generation_parameters["height"] == height
121125
assert image.generation_parameters["seed"] == seed
122126
assert image.generation_parameters["guidance_scale"] == guidance_scale
123127
assert image.generation_parameters["index_of_image_in_batch"] == idx
@@ -127,13 +131,13 @@ def test_image_generation_model_generate_images(self):
127131
image_path = os.path.join(temp_dir, "image.png")
128132
image_response[0].save(location=image_path)
129133
image1 = vision_models.GeneratedImage.load_from_file(image_path)
130-
assert image1._pil_image.size == (width, height)
134+
# assert image1._pil_image.size == (width, height)
131135
assert image1.generation_parameters
132136
assert image1.generation_parameters["prompt"] == prompt1
133137

134138
# Preparing mask
135139
mask_path = os.path.join(temp_dir, "mask.png")
136-
mask_pil_image = PIL_Image.new(mode="RGB", size=(width, height))
140+
mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
137141
mask_pil_image.save(mask_path, format="PNG")
138142
mask_image = vision_models.Image.load_from_file(mask_path)
139143

@@ -150,7 +154,8 @@ def test_image_generation_model_generate_images(self):
150154
)
151155
assert len(image_response2.images) == number_of_images
152156
for idx, image in enumerate(image_response2):
153-
assert image._pil_image.size == (width, height)
157+
# TODO(b/295946075): The service stopped supporting image sizes.
158+
# assert image._pil_image.size == (width, height)
154159
assert image.generation_parameters
155160
assert image.generation_parameters["prompt"] == prompt2
156161
assert image.generation_parameters["seed"] == seed

tests/unit/aiplatform/test_vision_models.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424
import tempfile
2525
from typing import Any, Dict
26+
import unittest
2627
from unittest import mock
2728

2829
from google.cloud import aiplatform
@@ -175,7 +176,9 @@ def test_generate_images(self):
175176
model = self._get_image_generation_model()
176177

177178
width = 1024
178-
height = 768
179+
# TODO(b/295946075) The service stopped supporting image sizes.
180+
# height = 768
181+
height = 1024
179182
number_of_images = 4
180183
seed = 1
181184
guidance_scale = 15
@@ -200,8 +203,9 @@ def test_generate_images(self):
200203
# Optional:
201204
negative_prompt=negative_prompt1,
202205
number_of_images=number_of_images,
203-
width=width,
204-
height=height,
206+
# TODO(b/295946075) The service stopped supporting image sizes.
207+
# width=width,
208+
# height=height,
205209
seed=seed,
206210
guidance_scale=guidance_scale,
207211
)
@@ -210,8 +214,9 @@ def test_generate_images(self):
210214
actual_instance = predict_kwargs["instances"][0]
211215
assert actual_instance["prompt"] == prompt1
212216
assert actual_instance["negativePrompt"] == negative_prompt1
213-
assert actual_parameters["sampleImageSize"] == str(max(width, height))
214-
assert actual_parameters["aspectRatio"] == f"{width}:{height}"
217+
# TODO(b/295946075) The service stopped supporting image sizes.
218+
# assert actual_parameters["sampleImageSize"] == str(max(width, height))
219+
# assert actual_parameters["aspectRatio"] == f"{width}:{height}"
215220
assert actual_parameters["seed"] == seed
216221
assert actual_parameters["guidanceScale"] == guidance_scale
217222

@@ -221,8 +226,9 @@ def test_generate_images(self):
221226
assert image.generation_parameters
222227
assert image.generation_parameters["prompt"] == prompt1
223228
assert image.generation_parameters["negative_prompt"] == negative_prompt1
224-
assert image.generation_parameters["width"] == width
225-
assert image.generation_parameters["height"] == height
229+
# TODO(b/295946075) The service stopped supporting image sizes.
230+
# assert image.generation_parameters["width"] == width
231+
# assert image.generation_parameters["height"] == height
226232
assert image.generation_parameters["seed"] == seed
227233
assert image.generation_parameters["guidance_scale"] == guidance_scale
228234
assert image.generation_parameters["index_of_image_in_batch"] == idx
@@ -233,13 +239,13 @@ def test_generate_images(self):
233239
image_path = os.path.join(temp_dir, "image.png")
234240
image_response[0].save(location=image_path)
235241
image1 = vision_models.GeneratedImage.load_from_file(image_path)
236-
assert image1._pil_image.size == (width, height)
242+
# assert image1._pil_image.size == (width, height)
237243
assert image1.generation_parameters
238244
assert image1.generation_parameters["prompt"] == prompt1
239245

240246
# Preparing mask
241247
mask_path = os.path.join(temp_dir, "mask.png")
242-
mask_pil_image = PIL_Image.new(mode="RGB", size=(width, height))
248+
mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
243249
mask_pil_image.save(mask_path, format="PNG")
244250
mask_image = vision_models.Image.load_from_file(mask_path)
245251

@@ -273,6 +279,7 @@ def test_generate_images(self):
273279
assert image.generation_parameters["base_image_hash"]
274280
assert image.generation_parameters["mask_hash"]
275281

282+
@unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
276283
def test_generate_images_requests_square_images_by_default(self):
277284
"""Tests that the model class generates square image by default."""
278285
model = self._get_image_generation_model()

vertexai/vision_models/_vision_models.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ def _generate_images(
166166
instance = {"prompt": prompt}
167167
shared_generation_parameters = {
168168
"prompt": prompt,
169-
"width": width,
170-
"height": height,
169+
# b/295946075 The service stopped supporting image sizes.
170+
# "width": width,
171+
# "height": height,
171172
"number_of_images_in_batch": number_of_images,
172173
}
173174

@@ -238,8 +239,6 @@ def generate_images(
238239
*,
239240
negative_prompt: Optional[str] = None,
240241
number_of_images: int = 1,
241-
width: Optional[int] = None,
242-
height: Optional[int] = None,
243242
guidance_scale: Optional[float] = None,
244243
seed: Optional[int] = None,
245244
) -> "ImageGenerationResponse":
@@ -250,8 +249,6 @@ def generate_images(
250249
negative_prompt: A description of what you want to omit in
251250
the generated images.
252251
number_of_images: Number of images to generate. Range: 1..8.
253-
width: Width of the image. One of the sizes must be 256 or 1024.
254-
height: Height of the image. One of the sizes must be 256 or 1024.
255252
guidance_scale: Controls the strength of the prompt.
256253
Suggested values are:
257254
* 0-9 (low strength)
@@ -266,8 +263,9 @@ def generate_images(
266263
prompt=prompt,
267264
negative_prompt=negative_prompt,
268265
number_of_images=number_of_images,
269-
width=width,
270-
height=height,
266+
# b/295946075 The service stopped supporting image sizes.
267+
width=None,
268+
height=None,
271269
guidance_scale=guidance_scale,
272270
seed=seed,
273271
)

0 commit comments

Comments
 (0)