Skip to content

Commit 4109ea8

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: LVM - Add GCS URI support for Imagen Models (imagetext, imagegeneration)
PiperOrigin-RevId: 606401323
1 parent 32c7197 commit 4109ea8

File tree

2 files changed

+227
-37
lines changed

2 files changed

+227
-37
lines changed

tests/unit/aiplatform/test_vision_models.py

+125
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ def make_image_generation_response(
114114
return {"predictions": predictions}
115115

116116

117+
def make_image_generation_response_gcs(count: int = 1) -> Dict[str, Any]:
118+
predictions = []
119+
for _ in range(count):
120+
predictions.append(
121+
{
122+
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
123+
"mimeType": "image/png",
124+
}
125+
)
126+
return {"predictions": predictions}
127+
128+
117129
def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
118130
predictions = {
119131
"bytesBase64Encoded": make_image_base64(upscale_size, upscale_size),
@@ -122,6 +134,14 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
122134
return {"predictions": [predictions]}
123135

124136

137+
def make_image_upscale_response_gcs() -> Dict[str, Any]:
138+
predictions = {
139+
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
140+
"mimeType": "image/png",
141+
}
142+
return {"predictions": [predictions]}
143+
144+
125145
def generate_image_from_file(
126146
width: int = 100, height: int = 100
127147
) -> ga_vision_models.Image:
@@ -332,6 +352,111 @@ def test_generate_images(self):
332352
assert image.generation_parameters["mask_hash"]
333353
assert image.generation_parameters["language"] == language
334354

355+
def test_generate_images_gcs(self):
356+
"""Tests the image generation model."""
357+
model = self._get_image_generation_model()
358+
359+
# TODO(b/295946075) The service stopped supporting image sizes.
360+
# height = 768
361+
number_of_images = 4
362+
seed = 1
363+
guidance_scale = 15
364+
language = "en"
365+
output_gcs_uri = "gs://test-bucket/"
366+
367+
image_generation_response = make_image_generation_response_gcs(
368+
count=number_of_images
369+
)
370+
gca_predict_response = gca_prediction_service.PredictResponse()
371+
gca_predict_response.predictions.extend(
372+
image_generation_response["predictions"]
373+
)
374+
375+
with mock.patch.object(
376+
target=prediction_service_client.PredictionServiceClient,
377+
attribute="predict",
378+
return_value=gca_predict_response,
379+
) as mock_predict:
380+
prompt1 = "Astronaut riding a horse"
381+
negative_prompt1 = "bad quality"
382+
image_response = model.generate_images(
383+
prompt=prompt1,
384+
# Optional:
385+
negative_prompt=negative_prompt1,
386+
number_of_images=number_of_images,
387+
# TODO(b/295946075) The service stopped supporting image sizes.
388+
# width=width,
389+
# height=height,
390+
seed=seed,
391+
guidance_scale=guidance_scale,
392+
language=language,
393+
output_gcs_uri=output_gcs_uri,
394+
)
395+
predict_kwargs = mock_predict.call_args[1]
396+
actual_parameters = predict_kwargs["parameters"]
397+
actual_instance = predict_kwargs["instances"][0]
398+
assert actual_instance["prompt"] == prompt1
399+
assert actual_parameters["negativePrompt"] == negative_prompt1
400+
# TODO(b/295946075) The service stopped supporting image sizes.
401+
# assert actual_parameters["sampleImageSize"] == str(max(width, height))
402+
# assert actual_parameters["aspectRatio"] == f"{width}:{height}"
403+
assert actual_parameters["seed"] == seed
404+
assert actual_parameters["guidanceScale"] == guidance_scale
405+
assert actual_parameters["language"] == language
406+
assert actual_parameters["storageUri"] == output_gcs_uri
407+
408+
assert len(image_response.images) == number_of_images
409+
for idx, image in enumerate(image_response):
410+
assert image.generation_parameters
411+
assert image.generation_parameters["prompt"] == prompt1
412+
assert image.generation_parameters["negative_prompt"] == negative_prompt1
413+
# TODO(b/295946075) The service stopped supporting image sizes.
414+
# assert image.generation_parameters["width"] == width
415+
# assert image.generation_parameters["height"] == height
416+
assert image.generation_parameters["seed"] == seed
417+
assert image.generation_parameters["guidance_scale"] == guidance_scale
418+
assert image.generation_parameters["language"] == language
419+
assert image.generation_parameters["index_of_image_in_batch"] == idx
420+
assert image.generation_parameters["storage_uri"] == output_gcs_uri
421+
422+
image1 = generate_image_from_gcs_uri()
423+
mask_image = generate_image_from_gcs_uri()
424+
425+
# Test generating image from base image
426+
with mock.patch.object(
427+
target=prediction_service_client.PredictionServiceClient,
428+
attribute="predict",
429+
return_value=gca_predict_response,
430+
) as mock_predict:
431+
prompt2 = "Ancient book style"
432+
image_response2 = model.edit_image(
433+
prompt=prompt2,
434+
# Optional:
435+
number_of_images=number_of_images,
436+
seed=seed,
437+
guidance_scale=guidance_scale,
438+
base_image=image1,
439+
mask=mask_image,
440+
language=language,
441+
output_gcs_uri=output_gcs_uri,
442+
)
443+
predict_kwargs = mock_predict.call_args[1]
444+
actual_parameters = predict_kwargs["parameters"]
445+
actual_instance = predict_kwargs["instances"][0]
446+
assert actual_instance["prompt"] == prompt2
447+
assert actual_instance["image"]["gcsUri"]
448+
assert actual_instance["mask"]["image"]["gcsUri"]
449+
assert actual_parameters["language"] == language
450+
451+
assert len(image_response2.images) == number_of_images
452+
for image in image_response2:
453+
assert image.generation_parameters
454+
assert image.generation_parameters["prompt"] == prompt2
455+
assert image.generation_parameters["base_image_uri"]
456+
assert image.generation_parameters["mask_uri"]
457+
assert image.generation_parameters["language"] == language
458+
assert image.generation_parameters["storage_uri"] == output_gcs_uri
459+
335460
@unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
336461
def test_generate_images_requests_square_images_by_default(self):
337462
"""Tests that the model class generates square image by default."""

0 commit comments

Comments
 (0)