Skip to content

Commit e2efdbe

Browse files
vertex-sdk-bot, copybara-github
authored and committed
feat: Add safety filter levels, watermark support and person generation support for Imagen 2
Changelog: - Added an `add_watermark` option to the `generate_images` call for adding a SynthID watermark to generated images. - Added an `edit_mode` option to the `edit_image` call. Can now choose between 4 edit modes: - `inpainting-insert`: Edit the image within the masked region. Needs both mask and prompt - `inpainting-remove`: Remove objects within the masked region. Needs only mask - `outpainting`: Extend the image based on the mask area. - `product-image`: Changes the background for the primary subject of the image - Added a `mask_mode` option to the `edit_image` call. Can now choose between 3 mask generation modes, instead of providing masks: - `background`: Select everything except the primary subject(s) of the image - `foreground`: Select the primary subject(s) of the image - `semantic`: Segment one or more of the segmentation classes using class ID - Added a `segmentation_classes` option for passing a list of class IDs when the `semantic` mask_mode is used. Can send up to 5 classes - Added a `mask_dilation` option for setting the dilation percentage of the mask - Added a `product_position` option to allow repositioning of products in the image. Supported values are: - `reposition`: Products can be repositioned - `fixed`: Product location is fixed - Added an `output_mime_type` option to select the image format in which the output is returned. Supported values are: - `image/png` - `image/jpeg` - Added a `compression_quality` option to select the compression quality when the output is `image/jpeg`. - Added a safety filter level for selecting the level of prompt and image filtering by Responsible AI filters. Supported values are: - `"block_most"`: The strictest filter. Blocks most prompts and images - `"block_some"`: Second most strict filter. Blocks some prompts and images - `"block_few"`: Blocks a few prompts and images - `"block_fewest"`: Blocks fewest prompts and images - Added an option to control person generation.
Supported values are: - `"dont_allow"`: Don't generate people at all - `"allow_adult"`: Generate adults, but not children - `"allow_all"`: Allow all person generation - Added the WatermarkVerificationModel to check if an image has a SynthID watermark. The publisher model is `imageverification@001`. The model object contains just one call, `verify_image`. `verify_image` takes only an image as the input and returns a string with one of 2 values: - `ACCEPT`: The image contains a watermark - `REJECT`: The image does not contain a watermark PiperOrigin-RevId: 617924430
1 parent 181dc7a commit e2efdbe

File tree

4 files changed

+675
-35
lines changed

4 files changed

+675
-35
lines changed

tests/system/aiplatform/test_vision_models.py

+174-1
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,83 @@ def test_image_generation_model_generate_images(self):
162162
assert image.generation_parameters["index_of_image_in_batch"] == idx
163163
assert image.generation_parameters["language"] == language
164164

165+
for width, height in [(1, 1), (9, 16), (16, 9), (4, 3), (3, 4)]:
166+
prompt_aspect_ratio = "A street lit up on a rainy night"
167+
model = vision_models.ImageGenerationModel.from_pretrained(
168+
"imagegeneration@006"
169+
)
170+
171+
number_of_images = 4
172+
seed = 1
173+
guidance_scale = 15
174+
language = "en"
175+
aspect_ratio = f"{width}:{height}"
176+
177+
image_response = model.generate_images(
178+
prompt=prompt_aspect_ratio,
179+
number_of_images=number_of_images,
180+
aspect_ratio=aspect_ratio,
181+
seed=seed,
182+
guidance_scale=guidance_scale,
183+
language=language,
184+
)
185+
186+
assert len(image_response.images) == number_of_images
187+
for idx, image in enumerate(image_response):
188+
assert image.generation_parameters
189+
assert image.generation_parameters["prompt"] == prompt_aspect_ratio
190+
assert image.generation_parameters["aspect_ratio"] == aspect_ratio
191+
assert image.generation_parameters["seed"] == seed
192+
assert image.generation_parameters["guidance_scale"] == guidance_scale
193+
assert image.generation_parameters["index_of_image_in_batch"] == idx
194+
assert image.generation_parameters["language"] == language
195+
assert (
196+
abs(
197+
float(image.size[0]) / float(image.size[1])
198+
- float(width) / float(height)
199+
)
200+
<= 0.001
201+
)
202+
203+
person_generation_prompts = [
204+
"A street lit up on a rainy night",
205+
"A woman walking down a street lit up on a rainy night",
206+
"A child walking down a street lit up on a rainy night",
207+
"A man walking down a street lit up on a rainy night",
208+
]
209+
210+
person_generation_levels = ["dont_allow", "allow_adult", "allow_all"]
211+
212+
for i in range(0, 3):
213+
for j in range(0, i + 1):
214+
image_response = model.generate_images(
215+
prompt=person_generation_prompts[j],
216+
number_of_images=number_of_images,
217+
seed=seed,
218+
guidance_scale=guidance_scale,
219+
language=language,
220+
person_generation=person_generation_levels[j],
221+
)
222+
if i == j:
223+
assert len(image_response.images) == number_of_images
224+
else:
225+
assert len(image_response.images) < number_of_images
226+
for idx, image in enumerate(image_response):
227+
assert (
228+
image.generation_parameters["person_generation"]
229+
== person_generation_levels[j]
230+
)
231+
assert (
232+
image.generation_parameters["prompt"]
233+
== person_generation_prompts[j]
234+
)
235+
assert image.generation_parameters["seed"] == seed
236+
assert (
237+
image.generation_parameters["guidance_scale"] == guidance_scale
238+
)
239+
assert image.generation_parameters["index_of_image_in_batch"] == idx
240+
assert image.generation_parameters["language"] == language
241+
165242
# Test saving and loading images
166243
with tempfile.TemporaryDirectory() as temp_dir:
167244
image_path = os.path.join(temp_dir, "image.png")
@@ -178,8 +255,14 @@ def test_image_generation_model_generate_images(self):
178255
mask_pil_image.save(mask_path, format="PNG")
179256
mask_image = vision_models.Image.load_from_file(mask_path)
180257

181-
# Test generating image from base image
258+
# Test generating image from base image
182259
prompt2 = "Ancient book style"
260+
edit_mode = "inpainting-insert"
261+
mask_mode = "foreground"
262+
mask_dilation = 0.06
263+
product_position = "fixed"
264+
output_mime_type = "image/jpeg"
265+
compression_quality = 0.90
183266
image_response2 = model.edit_image(
184267
prompt=prompt2,
185268
# Optional:
@@ -188,6 +271,12 @@ def test_image_generation_model_generate_images(self):
188271
guidance_scale=guidance_scale,
189272
base_image=image1,
190273
mask=mask_image,
274+
edit_mode=edit_mode,
275+
mask_mode=mask_mode,
276+
mask_dilation=mask_dilation,
277+
product_position=product_position,
278+
output_mime_type=output_mime_type,
279+
compression_quality=compression_quality,
191280
language=language,
192281
)
193282
assert len(image_response2.images) == number_of_images
@@ -199,6 +288,90 @@ def test_image_generation_model_generate_images(self):
199288
assert image.generation_parameters["seed"] == seed
200289
assert image.generation_parameters["guidance_scale"] == guidance_scale
201290
assert image.generation_parameters["index_of_image_in_batch"] == idx
291+
assert image.generation_parameters["edit_mode"] == edit_mode
292+
assert image.generation_parameters["mask_mode"] == mask_mode
293+
assert image.generation_parameters["mask_dilation"] == mask_dilation
294+
assert image.generation_parameters["product_position"] == product_position
295+
assert image.generation_parameters["mime_type"] == output_mime_type
296+
assert (
297+
image.generation_parameters["compression_quality"]
298+
== compression_quality
299+
)
300+
assert image.generation_parameters["language"] == language
301+
assert "base_image_hash" in image.generation_parameters
302+
assert "mask_hash" in image.generation_parameters
303+
304+
prompt3 = "Chocolate chip cookies"
305+
edit_mode = "inpainting-insert"
306+
mask_mode = "semantic"
307+
segmentation_classes = [1, 13, 17, 9, 18]
308+
product_position = "fixed"
309+
output_mime_type = "image/png"
310+
311+
image_response3 = model.edit_image(
312+
prompt=prompt3,
313+
number_of_images=number_of_images,
314+
seed=seed,
315+
guidance_scale=guidance_scale,
316+
base_image=image1,
317+
mask=mask_image,
318+
edit_mode=edit_mode,
319+
mask_mode=mask_mode,
320+
segmentation_classes=segmentation_classes,
321+
product_position=product_position,
322+
output_mime_type=output_mime_type,
323+
language=language,
324+
)
325+
326+
assert len(image_response3.images) == number_of_images
327+
for idx, image in enumerate(image_response3):
328+
assert image.generation_parameters
329+
assert image.generation_parameters["prompt"] == prompt3
330+
assert image.generation_parameters["seed"] == seed
331+
assert image.generation_parameters["guidance_scale"] == guidance_scale
332+
assert image.generation_parameters["index_of_image_in_batch"] == idx
333+
assert image.generation_parameters["edit_mode"] == edit_mode
334+
assert image.generation_parameters["mask_mode"] == mask_mode
335+
assert (
336+
image.generation_parameters["segmentation_classes"]
337+
== segmentation_classes
338+
)
339+
assert image.generation_parameters["product_position"] == product_position
340+
assert image.generation_parameters["mime_type"] == output_mime_type
202341
assert image.generation_parameters["language"] == language
203342
assert "base_image_hash" in image.generation_parameters
204343
assert "mask_hash" in image.generation_parameters
344+
345+
def test_image_verification_model_verify_image(self):
    """Exercise watermark verification on a blank and a watermarked image.

    The test expects the verification model to answer ``REJECT`` for the
    image returned by ``_create_blank_image()`` and ``ACCEPT`` for an image
    generated with ``add_watermark=True``.
    """
    # Load the verifier first, then the generator, before any prediction call.
    verifier = vision_models.ImageVerificationModel.from_pretrained(
        "imageverification@001"
    )
    generator = vision_models.ImageGenerationModel.from_pretrained(
        "imagegeneration@005"
    )

    # The blank helper image is expected to be reported as not watermarked.
    blank_verdict = verifier.verify_image(image=_create_blank_image())
    assert blank_verdict["decision"] == "REJECT"

    # Request exactly one image with watermarking switched on.
    generated = generator.generate_images(
        prompt="A street lit up on a rainy night",
        number_of_images=1,
        seed=1,
        guidance_scale=15,
        language="en",
        add_watermark=True,
    )
    assert len(generated.images) == 1

    # Re-wrap the generated bytes in a plain Image before verifying it.
    watermarked = vision_models.Image(generated.images[0].image_bytes)
    watermarked_verdict = verifier.verify_image(watermarked)
    assert watermarked_verdict["decision"] == "ACCEPT"

0 commit comments

Comments
 (0)