
Commit 9bbf1ea

Ark-kun authored and copybara-github committed
feat: LVM - Released the Large Vision Models SDK to public preview
# Large Vision Models

## Features:

* Image captioning
* Image Q&A
* Multimodal embedding

## Usage:

### Image captioning

```python
from vertexai.preview.vision_models import Image, ImageCaptioningModel

model = ImageCaptioningModel.from_pretrained("imagetext@001")
image = Image.load_from_file("image.png")
captions = model.get_captions(
    image=image,
    # Optional:
    number_of_results=1,
    language="en",
)
print(captions)
```

### Image Q&A

```python
from vertexai.preview.vision_models import Image, ImageQnAModel

model = ImageQnAModel.from_pretrained("imagetext@001")
image = Image.load_from_file("image.png")
answers = model.ask_question(
    image=image,
    question="What color is the car in this image?",
    # Optional:
    number_of_results=1,
)
print(answers)
```

### Multimodal embedding

```python
from vertexai.preview.vision_models import Image, MultiModalEmbeddingModel

model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
image = Image.load_from_file("image.png")
embeddings = model.get_embeddings(
    image=image,
    # Optional:
    contextual_text="this is a car",
)
print(len(embeddings.image_embedding))
print(len(embeddings.text_embedding))
```

PiperOrigin-RevId: 550804479
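As a quick follow-up to the multimodal embedding example (not part of the original commit message): once both vectors are returned, they can be compared directly, for instance with cosine similarity. A minimal sketch, assuming `embeddings.image_embedding` and `embeddings.text_embedding` are plain lists of floats of equal length, as the tests in this commit expect (length 1408):

```python
import math

def cosine_similarity(a, b):
    # Standard cosine similarity between two equal-length float vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# `embeddings` as returned by model.get_embeddings(...) above.
print(cosine_similarity(embeddings.image_embedding, embeddings.text_embedding))
```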
1 parent e4b23a2 commit 9bbf1ea

4 files changed: +673 -0 lines changed
@@ -0,0 +1,85 @@

```python
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=protected-access

import os
import tempfile

from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base
from vertexai.preview import vision_models
from PIL import Image as PIL_Image


def _create_blank_image(
    width: int = 100,
    height: int = 100,
) -> vision_models.Image:
    with tempfile.TemporaryDirectory() as temp_dir:
        image_path = os.path.join(temp_dir, "image.png")
        pil_image = PIL_Image.new(mode="RGB", size=(width, height))
        pil_image.save(image_path, format="PNG")
        return vision_models.Image.load_from_file(image_path)


class VisionModelTestSuite(e2e_base.TestEndToEnd):
    """System tests for vision models."""

    _temp_prefix = "temp_vision_models_test_"

    def test_image_captioning_model_get_captions(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        model = vision_models.ImageCaptioningModel.from_pretrained("imagetext")
        image = _create_blank_image()
        captions = model.get_captions(
            image=image,
            # Optional:
            number_of_results=2,
            language="en",
        )
        assert len(captions) == 2

    def test_image_q_and_a_model_ask_question(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        model = vision_models.ImageQnAModel.from_pretrained("imagetext")
        image = _create_blank_image()
        answers = model.ask_question(
            image=image,
            question="What color is the car in this image?",
            # Optional:
            number_of_results=2,
        )
        assert len(answers) == 2

    def test_multi_modal_embedding_model(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        model = vision_models.MultiModalEmbeddingModel.from_pretrained(
            "multimodalembedding@001"
        )
        image = _create_blank_image()
        embeddings = model.get_embeddings(
            image=image,
            # Optional:
            contextual_text="this is a car",
        )
        # The service is expected to return embeddings of size 1408.
        assert len(embeddings.image_embedding) == 1408
        assert len(embeddings.text_embedding) == 1408
```
+266
@@ -0,0 +1,266 @@

```python
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the vision models."""

# pylint: disable=protected-access,bad-continuation

import importlib
import os
import tempfile
from unittest import mock

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.compat.services import (
    model_garden_service_client,
)
from google.cloud.aiplatform.compat.services import prediction_service_client
from google.cloud.aiplatform.compat.types import (
    prediction_service as gca_prediction_service,
)
from google.cloud.aiplatform.compat.types import (
    publisher_model as gca_publisher_model,
)
from vertexai.preview import vision_models

from PIL import Image as PIL_Image
import pytest

_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"

_IMAGE_TEXT_PUBLISHER_MODEL_DICT = {
    "name": "publishers/google/models/imagetext",
    "version_id": "001",
    "open_source_category": "PROPRIETARY",
    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
    "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/imagetext@001",
    "predict_schemata": {
        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml",
        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/vision_reasoning_model_1.0.0.yaml",
        "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/vision_reasoning_model_1.0.0.yaml",
    },
}

_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT = {
    "name": "publishers/google/models/multimodalembedding",
    "version_id": "001",
    "open_source_category": "PROPRIETARY",
    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
    "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/multimodalembedding@001",
    "predict_schemata": {
        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/vision_embedding_model_1.0.0.yaml",
        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/vision_embedding_model_1.0.0.yaml",
        "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/vision_embedding_model_1.0.0.yaml",
    },
}


def generate_image_from_file(
    width: int = 100, height: int = 100
) -> vision_models.Image:
    with tempfile.TemporaryDirectory() as temp_dir:
        image_path = os.path.join(temp_dir, "image.png")
        pil_image = PIL_Image.new(mode="RGB", size=(width, height))
        pil_image.save(image_path, format="PNG")
        return vision_models.Image.load_from_file(image_path)


@pytest.mark.usefixtures("google_auth_mock")
class ImageCaptioningModelTests:
    """Unit tests for the image captioning models."""

    def setup_method(self):
        importlib.reload(initializer)
        importlib.reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_get_captions(self):
        """Tests the image captioning model."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _IMAGE_TEXT_PUBLISHER_MODEL_DICT
            ),
        ):
            model = vision_models.ImageCaptioningModel.from_pretrained("imagetext@001")

        image_captions = [
            "Caption 1",
            "Caption 2",
        ]
        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.extend(image_captions)

        with tempfile.TemporaryDirectory() as temp_dir:
            image_path = os.path.join(temp_dir, "image.png")
            pil_image = PIL_Image.new(mode="RGB", size=(100, 100))
            pil_image.save(image_path, format="PNG")
            image = vision_models.Image.load_from_file(image_path)

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ):
            actual_captions = model.get_captions(image=image, number_of_results=2)
            assert actual_captions == image_captions


@pytest.mark.usefixtures("google_auth_mock")
class ImageQnAModelTests:
    """Unit tests for the image Q&A models."""

    def setup_method(self):
        importlib.reload(initializer)
        importlib.reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_ask_question(self):
        """Tests the image Q&A model."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _IMAGE_TEXT_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = vision_models.ImageQnAModel.from_pretrained("imagetext@001")

        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/imagetext@001",
            retry=base._DEFAULT_RETRY,
        )

        image_answers = [
            "Black square",
            "Black Square by Malevich",
        ]
        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.extend(image_answers)

        image = generate_image_from_file()

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ):
            actual_answers = model.ask_question(
                image=image,
                question="What is this painting?",
                number_of_results=2,
            )
            assert actual_answers == image_answers


@pytest.mark.usefixtures("google_auth_mock")
class TestMultiModalEmbeddingModels:
    """Unit tests for the multimodal embedding models."""

    def setup_method(self):
        importlib.reload(initializer)
        importlib.reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_image_embedding_model_with_only_image(self):
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = vision_models.MultiModalEmbeddingModel.from_pretrained(
                "multimodalembedding@001"
            )

        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/multimodalembedding@001",
            retry=base._DEFAULT_RETRY,
        )

        test_image_embeddings = [0, 0]
        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.append(
            {"imageEmbedding": test_image_embeddings}
        )

        image = generate_image_from_file()

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ):
            embedding_response = model.get_embeddings(image=image)

        assert embedding_response.image_embedding == test_image_embeddings
        assert not embedding_response.text_embedding

    def test_image_embedding_model_with_image_and_text(self):
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
            ),
        ):
            model = vision_models.MultiModalEmbeddingModel.from_pretrained(
                "multimodalembedding@001"
            )

        test_embeddings = [0, 0]
        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.append(
            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
        )

        image = generate_image_from_file()

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ):
            embedding_response = model.get_embeddings(
                image=image, contextual_text="hello world"
            )

        assert embedding_response.image_embedding == test_embeddings
        assert embedding_response.text_embedding == test_embeddings
```
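One note on the mocking pattern above: each embedding test builds a `PredictResponse` by hand and appends a prediction dict. If more embedding tests were added, that construction could be factored out; a hypothetical helper (not in the commit) might look like this, using only the same `PredictResponse` API the tests already exercise:

```python
def make_embedding_predict_response(image_vec, text_vec=None):
    # Builds the same kind of mocked PredictResponse the tests above
    # construct inline. image_vec/text_vec are plain lists of floats.
    response = gca_prediction_service.PredictResponse()
    prediction = {"imageEmbedding": image_vec}
    if text_vec is not None:
        prediction["textEmbedding"] = text_vec
    response.predictions.append(prediction)
    return response
```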

vertexai/preview/vision_models.py

+31
@@ -0,0 +1,31 @@

```python
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for working with vision models."""

from vertexai.vision_models._vision_models import (
    Image,
    ImageCaptioningModel,
    ImageQnAModel,
    MultiModalEmbeddingModel,
    MultiModalEmbeddingResponse,
)

__all__ = [
    "Image",
    "ImageCaptioningModel",
    "ImageQnAModel",
    "MultiModalEmbeddingModel",
    "MultiModalEmbeddingResponse",
]
```
