Skip to content

Commit d4c3c66

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: LVM - Update a few flag values and interaction between different parameters
Changed flags: `maskMode` is it's own struct in the API, encapsulating `maskType` and `classes` `maskMode` can only be set (even to an empty value) if two conditions are satisfied: - `mask` is None - `editMode` is not `product-image` The logic now reflects this PiperOrigin-RevId: 620092866
1 parent 60b44f2 commit d4c3c66

File tree

2 files changed

+62
-17
lines changed

2 files changed

+62
-17
lines changed

tests/unit/aiplatform/test_vision_models.py

+53-10
Original file line numberDiff line numberDiff line change
@@ -458,9 +458,7 @@ def test_generate_images_gcs(self):
458458
) as mock_predict:
459459
prompt2 = "Ancient book style"
460460
edit_mode = "inpainting-insert"
461-
mask_mode = "background"
462461
mask_dilation = 0.06
463-
product_position = "fixed"
464462
output_mime_type = "image/jpeg"
465463
compression_quality = 80
466464
safety_filter_level = "block_fewest"
@@ -474,9 +472,7 @@ def test_generate_images_gcs(self):
474472
base_image=image1,
475473
mask=mask_image,
476474
edit_mode=edit_mode,
477-
mask_mode=mask_mode,
478475
mask_dilation=mask_dilation,
479-
product_position=product_position,
480476
output_mime_type=output_mime_type,
481477
compression_quality=compression_quality,
482478
safety_filter_level=safety_filter_level,
@@ -489,11 +485,7 @@ def test_generate_images_gcs(self):
489485
assert actual_instance["image"]["gcsUri"]
490486
assert actual_instance["mask"]["image"]["gcsUri"]
491487
assert actual_parameters["editConfig"]["editMode"] == edit_mode
492-
assert actual_parameters["editConfig"]["maskMode"] == mask_mode
493488
assert actual_parameters["editConfig"]["maskDilation"] == mask_dilation
494-
assert (
495-
actual_parameters["editConfig"]["productPosition"] == product_position
496-
)
497489
assert actual_parameters["outputOptions"]["mimeType"] == output_mime_type
498490
assert (
499491
actual_parameters["outputOptions"]["compressionQuality"]
@@ -509,9 +501,7 @@ def test_generate_images_gcs(self):
509501
assert image.generation_parameters["base_image_uri"]
510502
assert image.generation_parameters["mask_uri"]
511503
assert image.generation_parameters["edit_mode"] == edit_mode
512-
assert image.generation_parameters["mask_mode"] == mask_mode
513504
assert image.generation_parameters["mask_dilation"] == mask_dilation
514-
assert image.generation_parameters["product_position"] == product_position
515505
assert image.generation_parameters["mime_type"] == output_mime_type
516506
assert (
517507
image.generation_parameters["compression_quality"]
@@ -522,6 +512,59 @@ def test_generate_images_gcs(self):
522512
== safety_filter_level
523513
)
524514
assert image.generation_parameters["person_generation"] == person_generation
515+
with mock.patch.object(
516+
target=prediction_service_client.PredictionServiceClient,
517+
attribute="predict",
518+
return_value=gca_predict_response,
519+
) as mock_predict:
520+
prompt3 = "Ancient book style"
521+
edit_mode = "inpainting-insert"
522+
mask_dilation = 0.06
523+
output_mime_type = "image/jpeg"
524+
compression_quality = 80
525+
safety_filter_level = "block_fewest"
526+
person_generation = "allow_all"
527+
mask_mode = "background"
528+
529+
image_response3 = model.edit_image(
530+
prompt=prompt3,
531+
base_image=image1,
532+
number_of_images=number_of_images,
533+
edit_mode=edit_mode,
534+
mask_dilation=mask_dilation,
535+
mask_mode=mask_mode,
536+
output_mime_type=output_mime_type,
537+
compression_quality=compression_quality,
538+
safety_filter_level=safety_filter_level,
539+
person_generation=person_generation,
540+
)
541+
542+
predict_kwargs = mock_predict.call_args[1]
543+
actual_parameters = predict_kwargs["parameters"]
544+
actual_instance = predict_kwargs["instances"][0]
545+
assert actual_instance["prompt"] == prompt3
546+
assert actual_instance["image"]["gcsUri"]
547+
assert actual_parameters["editConfig"]["editMode"] == edit_mode
548+
assert actual_parameters["editConfig"]["maskMode"]["maskType"] == mask_mode
549+
assert actual_parameters["editConfig"]["maskDilation"] == mask_dilation
550+
assert actual_parameters["outputOptions"]["mimeType"] == output_mime_type
551+
assert (
552+
actual_parameters["outputOptions"]["compressionQuality"]
553+
== compression_quality
554+
)
555+
556+
assert len(image_response3.images) == number_of_images
557+
for image in image_response3:
558+
assert image.generation_parameters
559+
assert image.generation_parameters["prompt"] == prompt3
560+
assert image.generation_parameters["base_image_uri"]
561+
assert image.generation_parameters["edit_mode"] == edit_mode
562+
assert image.generation_parameters["mask_dilation"] == mask_dilation
563+
assert image.generation_parameters["mime_type"] == output_mime_type
564+
assert (
565+
image.generation_parameters["compression_quality"]
566+
== compression_quality
567+
)
525568

526569
@unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
527570
def test_generate_images_requests_square_images_by_default(self):

vertexai/vision_models/_vision_models.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -497,13 +497,15 @@ class ID
497497
parameters["editConfig"]["editMode"] = edit_mode
498498
shared_generation_parameters["edit_mode"] = edit_mode
499499

500-
if mask_mode is not None:
501-
parameters["editConfig"]["maskMode"] = mask_mode
502-
shared_generation_parameters["mask_mode"] = mask_mode
503-
504-
if segmentation_classes is not None:
505-
parameters["editConfig"]["classes"] = segmentation_classes
506-
shared_generation_parameters["classes"] = segmentation_classes
500+
if mask is None and edit_mode != "product-image":
501+
parameters["editConfig"]["maskMode"] = {}
502+
if mask_mode is not None:
503+
parameters["editConfig"]["maskMode"]["maskType"] = mask_mode
504+
shared_generation_parameters["mask_mode"] = mask_mode
505+
506+
if segmentation_classes is not None:
507+
parameters["editConfig"]["maskMode"]["classes"] = segmentation_classes
508+
shared_generation_parameters["classes"] = segmentation_classes
507509

508510
if mask_dilation is not None:
509511
parameters["editConfig"]["maskDilation"] = mask_dilation

0 commit comments

Comments
 (0)