Skip to content

Commit 791eff5

Browse files
vertex-sdk-bot authored and
copybara-github committed
feat: LVM - Added multi-language support for ImageGenerationModel
PiperOrigin-RevId: 580285009
1 parent 9c4decc commit 791eff5

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

tests/system/aiplatform/test_vision_models.py

+6
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def test_image_generation_model_generate_images(self):
9797
number_of_images = 4
9898
seed = 1
9999
guidance_scale = 15
100+
language = "en"
100101

101102
prompt1 = "Astronaut riding a horse"
102103
negative_prompt1 = "bad quality"
@@ -110,6 +111,7 @@ def test_image_generation_model_generate_images(self):
110111
# height=height,
111112
seed=seed,
112113
guidance_scale=guidance_scale,
114+
language=language,
113115
)
114116

115117
assert len(image_response.images) == number_of_images
@@ -125,6 +127,7 @@ def test_image_generation_model_generate_images(self):
125127
assert image.generation_parameters["seed"] == seed
126128
assert image.generation_parameters["guidance_scale"] == guidance_scale
127129
assert image.generation_parameters["index_of_image_in_batch"] == idx
130+
assert image.generation_parameters["language"] == language
128131

129132
# Test saving and loading images
130133
with tempfile.TemporaryDirectory() as temp_dir:
@@ -134,6 +137,7 @@ def test_image_generation_model_generate_images(self):
134137
# assert image1._pil_image.size == (width, height)
135138
assert image1.generation_parameters
136139
assert image1.generation_parameters["prompt"] == prompt1
140+
assert image1.generation_parameters["language"] == language
137141

138142
# Preparing mask
139143
mask_path = os.path.join(temp_dir, "mask.png")
@@ -151,6 +155,7 @@ def test_image_generation_model_generate_images(self):
151155
guidance_scale=guidance_scale,
152156
base_image=image1,
153157
mask=mask_image,
158+
language=language,
154159
)
155160
assert len(image_response2.images) == number_of_images
156161
for idx, image in enumerate(image_response2):
@@ -161,5 +166,6 @@ def test_image_generation_model_generate_images(self):
161166
assert image.generation_parameters["seed"] == seed
162167
assert image.generation_parameters["guidance_scale"] == guidance_scale
163168
assert image.generation_parameters["index_of_image_in_batch"] == idx
169+
assert image.generation_parameters["language"] == language
164170
assert "base_image_hash" in image.generation_parameters
165171
assert "mask_hash" in image.generation_parameters

tests/unit/aiplatform/test_vision_models.py

+9
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def test_generate_images(self):
220220
number_of_images = 4
221221
seed = 1
222222
guidance_scale = 15
223+
language = "en"
223224

224225
image_generation_response = make_image_generation_response(
225226
width=width, height=height, count=number_of_images
@@ -246,6 +247,7 @@ def test_generate_images(self):
246247
# height=height,
247248
seed=seed,
248249
guidance_scale=guidance_scale,
250+
language=language,
249251
)
250252
predict_kwargs = mock_predict.call_args[1]
251253
actual_parameters = predict_kwargs["parameters"]
@@ -257,6 +259,7 @@ def test_generate_images(self):
257259
# assert actual_parameters["aspectRatio"] == f"{width}:{height}"
258260
assert actual_parameters["seed"] == seed
259261
assert actual_parameters["guidanceScale"] == guidance_scale
262+
assert actual_parameters["language"] == language
260263

261264
assert len(image_response.images) == number_of_images
262265
for idx, image in enumerate(image_response):
@@ -269,6 +272,7 @@ def test_generate_images(self):
269272
# assert image.generation_parameters["height"] == height
270273
assert image.generation_parameters["seed"] == seed
271274
assert image.generation_parameters["guidance_scale"] == guidance_scale
275+
assert image.generation_parameters["language"] == language
272276
assert image.generation_parameters["index_of_image_in_batch"] == idx
273277
image.show()
274278

@@ -280,6 +284,7 @@ def test_generate_images(self):
280284
# assert image1._pil_image.size == (width, height)
281285
assert image1.generation_parameters
282286
assert image1.generation_parameters["prompt"] == prompt1
287+
assert image1.generation_parameters["language"] == language
283288

284289
# Preparing mask
285290
mask_path = os.path.join(temp_dir, "mask.png")
@@ -302,12 +307,15 @@ def test_generate_images(self):
302307
guidance_scale=guidance_scale,
303308
base_image=image1,
304309
mask=mask_image,
310+
language=language,
305311
)
306312
predict_kwargs = mock_predict.call_args[1]
313+
actual_parameters = predict_kwargs["parameters"]
307314
actual_instance = predict_kwargs["instances"][0]
308315
assert actual_instance["prompt"] == prompt2
309316
assert actual_instance["image"]["bytesBase64Encoded"]
310317
assert actual_instance["mask"]["image"]["bytesBase64Encoded"]
318+
assert actual_parameters["language"] == language
311319

312320
assert len(image_response2.images) == number_of_images
313321
for image in image_response2:
@@ -316,6 +324,7 @@ def test_generate_images(self):
316324
assert image.generation_parameters["prompt"] == prompt2
317325
assert image.generation_parameters["base_image_hash"]
318326
assert image.generation_parameters["mask_hash"]
327+
assert image.generation_parameters["language"] == language
319328

320329
@unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
321330
def test_generate_images_requests_square_images_by_default(self):

vertexai/vision_models/_vision_models.py

+18
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _generate_images(
142142
seed: Optional[int] = None,
143143
base_image: Optional["Image"] = None,
144144
mask: Optional["Image"] = None,
145+
language:Optional[str] = None,
145146
) -> "ImageGenerationResponse":
146147
"""Generates images from text prompt.
147148
@@ -160,6 +161,9 @@ def _generate_images(
160161
seed: Image generation random seed.
161162
base_image: Base image to use for the image generation.
162163
mask: Mask for the base image.
164+
language: Language of the text prompt for the image. Default: None.
165+
Supported values are `"en"` for English, `"hi"` for Hindi,
166+
`"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
163167
164168
Returns:
165169
An `ImageGenerationResponse` object.
@@ -216,6 +220,10 @@ def _generate_images(
216220
parameters["guidanceScale"] = guidance_scale
217221
shared_generation_parameters["guidance_scale"] = guidance_scale
218222

223+
if language is not None:
224+
parameters["language"] = language
225+
shared_generation_parameters["language"] = language
226+
219227
response = self._endpoint.predict(
220228
instances=[instance],
221229
parameters=parameters,
@@ -241,6 +249,7 @@ def generate_images(
241249
negative_prompt: Optional[str] = None,
242250
number_of_images: int = 1,
243251
guidance_scale: Optional[float] = None,
252+
language: Optional[str] = None,
244253
seed: Optional[int] = None,
245254
) -> "ImageGenerationResponse":
246255
"""Generates images from text prompt.
@@ -255,6 +264,9 @@ def generate_images(
255264
* 0-9 (low strength)
256265
* 10-20 (medium strength)
257266
* 21+ (high strength)
267+
language: Language of the text prompt for the image. Default: None.
268+
Supported values are `"en"` for English, `"hi"` for Hindi,
269+
`"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
258270
seed: Image generation random seed.
259271
260272
Returns:
@@ -268,6 +280,7 @@ def generate_images(
268280
width=None,
269281
height=None,
270282
guidance_scale=guidance_scale,
283+
language=language,
271284
seed=seed,
272285
)
273286

@@ -280,6 +293,7 @@ def edit_image(
280293
negative_prompt: Optional[str] = None,
281294
number_of_images: int = 1,
282295
guidance_scale: Optional[float] = None,
296+
language: Optional[str] = None,
283297
seed: Optional[int] = None,
284298
) -> "ImageGenerationResponse":
285299
"""Edits an existing image based on text prompt.
@@ -296,6 +310,9 @@ def edit_image(
296310
* 0-9 (low strength)
297311
* 10-20 (medium strength)
298312
* 21+ (high strength)
313+
language: Language of the text prompt for the image. Default: None.
314+
Supported values are `"en"` for English, `"hi"` for Hindi,
315+
`"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
299316
seed: Image generation random seed.
300317
301318
Returns:
@@ -309,6 +326,7 @@ def edit_image(
309326
seed=seed,
310327
base_image=base_image,
311328
mask=mask,
329+
language=language,
312330
)
313331

314332
def upscale_image(

0 commit comments

Comments
 (0)