
Commit 38ec40a

sararob authored and copybara-github committed
feat: add support for providing only text to MultiModalEmbeddingModel.get_embeddings()
PiperOrigin-RevId: 553809703
1 parent ff47513 commit 38ec40a

2 files changed: 42 additions, 8 deletions

tests/unit/aiplatform/test_vision_models.py (+30)
@@ -264,3 +264,33 @@ def test_image_embedding_model_with_image_and_text(self):
 
         assert embedding_response.image_embedding == test_embeddings
         assert embedding_response.text_embedding == test_embeddings
+
+    def test_image_embedding_model_with_only_text(self):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = vision_models.MultiModalEmbeddingModel.from_pretrained(
+                "multimodalembedding@001"
+            )
+
+        test_embeddings = [0, 0]
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.append({"textEmbedding": test_embeddings})
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ):
+            embedding_response = model.get_embeddings(contextual_text="hello world")
+
+        assert not embedding_response.image_embedding
+        assert embedding_response.text_embedding == test_embeddings
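To exercise just this new test locally, a pytest invocation along these lines should work (the file path is the one shown above; the -k expression simply selects the test by name):

    pytest tests/unit/aiplatform/test_vision_models.py -k test_image_embedding_model_with_only_text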

vertexai/vision_models/_vision_models.py (+12, -8)
@@ -234,28 +234,32 @@ class MultiModalEmbeddingModel(_model_garden_models._ModelGardenModel):
     )
 
     def get_embeddings(
-        self, image: Image, contextual_text: Optional[str] = None
+        self, image: Optional[Image] = None, contextual_text: Optional[str] = None
     ) -> "MultiModalEmbeddingResponse":
         """Gets embedding vectors from the provided image.
 
         Args:
             image (Image):
-                The image to generate embeddings for.
+                Optional. The image to generate embeddings for. One of `image` or `contextual_text` is required.
             contextual_text (str):
                 Optional. Contextual text for your input image. If provided, the model will also
                 generate an embedding vector for the provided contextual text. The returned image
                 and text embedding vectors are in the same semantic space with the same dimensionality,
                 and the vectors can be used interchangeably for use cases like searching image by text
-                or searching text by image.
+                or searching text by image. One of `image` or `contextual_text` is required.
 
         Returns:
             ImageEmbeddingResponse:
                 The image and text embedding vectors.
         """
 
-        instance = {
-            "image": {"bytesBase64Encoded": image._as_base64_string()},
-        }
+        if not image and not contextual_text:
+            raise ValueError("One of `image` or `contextual_text` is required.")
+
+        instance = {}
+
+        if image:
+            instance["image"] = {"bytesBase64Encoded": image._as_base64_string()}
 
         if contextual_text:
             instance["text"] = contextual_text

@@ -280,11 +284,11 @@ class MultiModalEmbeddingResponse:
 
     Attributes:
         image_embedding (List[float]):
-            The emebedding vector generated from your image.
+            Optional. The embedding vector generated from your image.
        text_embedding (List[float]):
            Optional. The embedding vector generated from the contextual text provided for your image.
    """
 
-    image_embedding: List[float]
     _prediction_response: Any
+    image_embedding: Optional[List[float]] = None
     text_embedding: Optional[List[float]] = None
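Taken together, the change lets callers request a text-only embedding. A minimal usage sketch, assuming the public vertexai.vision_models package re-exports MultiModalEmbeddingModel and that a real project and location are available (the values below are placeholders, not part of this commit):

    from google.cloud import aiplatform
    from vertexai.vision_models import MultiModalEmbeddingModel

    # Placeholder project settings; substitute your own.
    aiplatform.init(project="my-project", location="us-central1")

    # Same publisher model name as in the unit test above.
    model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")

    # New in this commit: `image` may be omitted entirely.
    response = model.get_embeddings(contextual_text="hello world")

    # With no image supplied, image_embedding is left unset (it now defaults to None).
    assert not response.image_embedding
    print(response.text_embedding)

Before this commit, omitting `image` failed with a TypeError because it was a required parameter; now either argument alone is accepted, and passing neither raises the ValueError added above.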
