Commit 1d9bd23

vertex-sdk-bot authored and copybara-github committed
feat: LVM - Added the MultiModalEmbeddingModel.get_embeddings(dimension=...) parameter
PiperOrigin-RevId: 599457605
1 parent 80d5c56 commit 1d9bd23
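
A minimal usage sketch of the new parameter (illustrative only, not part of this commit; the project, location, and image path below are placeholders):

import vertexai
from vertexai.preview.vision_models import Image, MultiModalEmbeddingModel

vertexai.init(project="my-project", location="us-central1")  # placeholder values

model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
image = Image.load_from_file("image.png")  # placeholder path

# Request 128-dimensional vectors instead of the 1408-dimensional default.
response = model.get_embeddings(
    image=image,
    contextual_text="hello world",
    dimension=128,
)
print(len(response.image_embedding))  # expected 128
print(len(response.text_embedding))   # expected 128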

File tree

tests/unit/aiplatform/test_vision_models.py
vertexai/vision_models/_vision_models.py

2 files changed: +63 −11

tests/unit/aiplatform/test_vision_models.py (+37)

@@ -684,6 +684,43 @@ def test_image_embedding_model_with_only_text(self):
         assert not embedding_response.image_embedding
         assert embedding_response.text_embedding == test_embeddings
 
+    def test_image_embedding_model_with_lower_dimensions(self):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
+                "multimodalembedding@001"
+            )
+
+        dimension = 128
+        test_embeddings = [0] * dimension
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.append(
+            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+        )
+
+        image = generate_image_from_file()
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ):
+            embedding_response = model.get_embeddings(
+                image=image, contextual_text="hello world", dimension=dimension
+            )
+
+        assert embedding_response.image_embedding == test_embeddings
+        assert embedding_response.text_embedding == test_embeddings
+
 
 @pytest.mark.usefixtures("google_auth_mock")
 class ImageTextModelTests:
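Note on the test above: it patches both ModelGardenServiceClient.get_publisher_model and PredictionServiceClient.predict, so no network calls are made. The assertions verify only that the 128-element vectors stubbed into PredictResponse are surfaced unchanged on the response object; the forwarding of `dimension` into the request parameters is not asserted directly here.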

vertexai/vision_models/_vision_models.py (+26 −11)

@@ -142,7 +142,7 @@ def _generate_images(
         seed: Optional[int] = None,
         base_image: Optional["Image"] = None,
         mask: Optional["Image"] = None,
-        language:Optional[str] = None,
+        language: Optional[str] = None,
     ) -> "ImageGenerationResponse":
         """Generates images from text prompt.
 

@@ -641,19 +641,27 @@ class MultiModalEmbeddingModel(_model_garden_models._ModelGardenModel):
         )
 
     def get_embeddings(
-        self, image: Optional[Image] = None, contextual_text: Optional[str] = None
+        self,
+        image: Optional[Image] = None,
+        contextual_text: Optional[str] = None,
+        dimension: Optional[int] = None,
     ) -> "MultiModalEmbeddingResponse":
         """Gets embedding vectors from the provided image.
 
         Args:
-            image (Image):
-                Optional. The image to generate embeddings for. One of `image` or `contextual_text` is required.
-            contextual_text (str):
-                Optional. Contextual text for your input image. If provided, the model will also
-                generate an embedding vector for the provided contextual text. The returned image
-                and text embedding vectors are in the same semantic space with the same dimensionality,
-                and the vectors can be used interchangeably for use cases like searching image by text
-                or searching text by image. One of `image` or `contextual_text` is required.
+            image (Image): Optional. The image to generate embeddings for. One of
+                `image` or `contextual_text` is required.
+            contextual_text (str): Optional. Contextual text for your input image.
+                If provided, the model will also generate an embedding vector for the
+                provided contextual text. The returned image and text embedding
+                vectors are in the same semantic space with the same dimensionality,
+                and the vectors can be used interchangeably for use cases like
+                searching image by text or searching text by image. One of `image` or
+                `contextual_text` is required.
+            dimension (int): Optional. The number of embedding dimensions. Lower
+                values offer decreased latency when using these embeddings for
+                subsequent tasks, while higher values offer better accuracy. Available
+                values: `128`, `256`, `512`, and `1408` (default).
 
         Returns:
             ImageEmbeddingResponse:

@@ -671,7 +679,14 @@ def get_embeddings(
         if contextual_text:
             instance["text"] = contextual_text
 
-        response = self._endpoint.predict(instances=[instance])
+        parameters = {}
+        if dimension:
+            parameters["dimension"] = dimension
+
+        response = self._endpoint.predict(
+            instances=[instance],
+            parameters=parameters,
+        )
         image_embedding = response.predictions[0].get("imageEmbedding")
         text_embedding = (
             response.predictions[0].get("textEmbedding")
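A quick sketch of what the new forwarding logic amounts to when `dimension=128` is passed; the `"image"` key and its payload shape are built by code outside this hunk, so they are assumptions here:

# Python sketch, not part of the commit; instance/parameter shapes as assumed.
instance = {
    "image": {"bytesBase64Encoded": "..."},  # assumed key; populated outside this hunk
    "text": "hello world",                   # set by the `if contextual_text:` branch above
}
parameters = {"dimension": 128}              # present only when a dimension was given
# self._endpoint.predict(instances=[instance], parameters=parameters)
# response.predictions[0] then carries the "imageEmbedding" / "textEmbedding" lists.

Note that `if dimension:` checks truthiness, so the default `dimension=None` sends an empty parameters dict, presumably leaving the service at its 1408-dimensional default.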
