@@ -114,6 +114,18 @@ def make_image_generation_response(
114
114
return {"predictions" : predictions }
115
115
116
116
117
+ def make_image_generation_response_gcs (count : int = 1 ) -> Dict [str , Any ]:
118
+ predictions = []
119
+ for _ in range (count ):
120
+ predictions .append (
121
+ {
122
+ "gcsUri" : "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png" ,
123
+ "mimeType" : "image/png" ,
124
+ }
125
+ )
126
+ return {"predictions" : predictions }
127
+
128
+
117
129
def make_image_upscale_response (upscale_size : int ) -> Dict [str , Any ]:
118
130
predictions = {
119
131
"bytesBase64Encoded" : make_image_base64 (upscale_size , upscale_size ),
@@ -122,6 +134,14 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
122
134
return {"predictions" : [predictions ]}
123
135
124
136
137
+ def make_image_upscale_response_gcs () -> Dict [str , Any ]:
138
+ predictions = {
139
+ "gcsUri" : "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png" ,
140
+ "mimeType" : "image/png" ,
141
+ }
142
+ return {"predictions" : [predictions ]}
143
+
144
+
125
145
def generate_image_from_file (
126
146
width : int = 100 , height : int = 100
127
147
) -> ga_vision_models .Image :
@@ -332,6 +352,111 @@ def test_generate_images(self):
332
352
assert image .generation_parameters ["mask_hash" ]
333
353
assert image .generation_parameters ["language" ] == language
334
354
355
+ def test_generate_images_gcs (self ):
356
+ """Tests the image generation model."""
357
+ model = self ._get_image_generation_model ()
358
+
359
+ # TODO(b/295946075) The service stopped supporting image sizes.
360
+ # height = 768
361
+ number_of_images = 4
362
+ seed = 1
363
+ guidance_scale = 15
364
+ language = "en"
365
+ output_gcs_uri = "gs://test-bucket/"
366
+
367
+ image_generation_response = make_image_generation_response_gcs (
368
+ count = number_of_images
369
+ )
370
+ gca_predict_response = gca_prediction_service .PredictResponse ()
371
+ gca_predict_response .predictions .extend (
372
+ image_generation_response ["predictions" ]
373
+ )
374
+
375
+ with mock .patch .object (
376
+ target = prediction_service_client .PredictionServiceClient ,
377
+ attribute = "predict" ,
378
+ return_value = gca_predict_response ,
379
+ ) as mock_predict :
380
+ prompt1 = "Astronaut riding a horse"
381
+ negative_prompt1 = "bad quality"
382
+ image_response = model .generate_images (
383
+ prompt = prompt1 ,
384
+ # Optional:
385
+ negative_prompt = negative_prompt1 ,
386
+ number_of_images = number_of_images ,
387
+ # TODO(b/295946075) The service stopped supporting image sizes.
388
+ # width=width,
389
+ # height=height,
390
+ seed = seed ,
391
+ guidance_scale = guidance_scale ,
392
+ language = language ,
393
+ output_gcs_uri = output_gcs_uri ,
394
+ )
395
+ predict_kwargs = mock_predict .call_args [1 ]
396
+ actual_parameters = predict_kwargs ["parameters" ]
397
+ actual_instance = predict_kwargs ["instances" ][0 ]
398
+ assert actual_instance ["prompt" ] == prompt1
399
+ assert actual_parameters ["negativePrompt" ] == negative_prompt1
400
+ # TODO(b/295946075) The service stopped supporting image sizes.
401
+ # assert actual_parameters["sampleImageSize"] == str(max(width, height))
402
+ # assert actual_parameters["aspectRatio"] == f"{width}:{height}"
403
+ assert actual_parameters ["seed" ] == seed
404
+ assert actual_parameters ["guidanceScale" ] == guidance_scale
405
+ assert actual_parameters ["language" ] == language
406
+ assert actual_parameters ["storageUri" ] == output_gcs_uri
407
+
408
+ assert len (image_response .images ) == number_of_images
409
+ for idx , image in enumerate (image_response ):
410
+ assert image .generation_parameters
411
+ assert image .generation_parameters ["prompt" ] == prompt1
412
+ assert image .generation_parameters ["negative_prompt" ] == negative_prompt1
413
+ # TODO(b/295946075) The service stopped supporting image sizes.
414
+ # assert image.generation_parameters["width"] == width
415
+ # assert image.generation_parameters["height"] == height
416
+ assert image .generation_parameters ["seed" ] == seed
417
+ assert image .generation_parameters ["guidance_scale" ] == guidance_scale
418
+ assert image .generation_parameters ["language" ] == language
419
+ assert image .generation_parameters ["index_of_image_in_batch" ] == idx
420
+ assert image .generation_parameters ["storage_uri" ] == output_gcs_uri
421
+
422
+ image1 = generate_image_from_gcs_uri ()
423
+ mask_image = generate_image_from_gcs_uri ()
424
+
425
+ # Test generating image from base image
426
+ with mock .patch .object (
427
+ target = prediction_service_client .PredictionServiceClient ,
428
+ attribute = "predict" ,
429
+ return_value = gca_predict_response ,
430
+ ) as mock_predict :
431
+ prompt2 = "Ancient book style"
432
+ image_response2 = model .edit_image (
433
+ prompt = prompt2 ,
434
+ # Optional:
435
+ number_of_images = number_of_images ,
436
+ seed = seed ,
437
+ guidance_scale = guidance_scale ,
438
+ base_image = image1 ,
439
+ mask = mask_image ,
440
+ language = language ,
441
+ output_gcs_uri = output_gcs_uri ,
442
+ )
443
+ predict_kwargs = mock_predict .call_args [1 ]
444
+ actual_parameters = predict_kwargs ["parameters" ]
445
+ actual_instance = predict_kwargs ["instances" ][0 ]
446
+ assert actual_instance ["prompt" ] == prompt2
447
+ assert actual_instance ["image" ]["gcsUri" ]
448
+ assert actual_instance ["mask" ]["image" ]["gcsUri" ]
449
+ assert actual_parameters ["language" ] == language
450
+
451
+ assert len (image_response2 .images ) == number_of_images
452
+ for image in image_response2 :
453
+ assert image .generation_parameters
454
+ assert image .generation_parameters ["prompt" ] == prompt2
455
+ assert image .generation_parameters ["base_image_uri" ]
456
+ assert image .generation_parameters ["mask_uri" ]
457
+ assert image .generation_parameters ["language" ] == language
458
+ assert image .generation_parameters ["storage_uri" ] == output_gcs_uri
459
+
335
460
@unittest .skip (reason = "b/295946075 The service stopped supporting image sizes." )
336
461
def test_generate_images_requests_square_images_by_default (self ):
337
462
"""Tests that the model class generates square image by default."""
0 commit comments