Skip to content

Commit ae63a43

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Vision Models - onboard Image Segmentation.
Generates masks by segmenting a base image. Supports several different modes and input modalities, along with parameters to customize the prediction response. More information is available in the model card at https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/image-segmentation-001

PiperOrigin-RevId: 686264882
1 parent 507e988 commit ae63a43

File tree

3 files changed

+291
-0
lines changed

3 files changed

+291
-0
lines changed

tests/unit/aiplatform/test_vision_models.py

+75
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,19 @@
101101
},
102102
}
103103

104+
# Fake publisher-model payload for image-segmentation-001. The segmentation
# test wraps it in `gca_publisher_model.PublisherModel` and returns it from the
# mocked `get_publisher_model` call.
_IMAGE_SEGMENTATION_PUBLISHER_MODEL_DICT = {
    "name": "publishers/google/models/image-segmentation-001",
    "version_id": "default",
    "open_source_category": "PROPRIETARY",
    "launch_stage": (gca_publisher_model.PublisherModel.LaunchStage.PRIVATE_PREVIEW),
    "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/image-segmentation-001",
    "predict_schemata": {
        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/image_segmentation_model_1.0.0.yaml",
        # Bucket name previously misspelled as "google-cloud-aiplatfrom"; it
        # now matches the instance and prediction schema URIs.
        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/image_segmentation_model_1.0.0.yaml",
        "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/image_segmentation_model_1.0.0.yaml",
    },
}
116+
104117

105118
def make_image_base64(width: int, height: int) -> str:
106119
image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(width, height))
@@ -173,6 +186,20 @@ def make_image_upscale_response_gcs() -> Dict[str, Any]:
173186
return {"predictions": [predictions]}
174187

175188

189+
def make_image_segmentation_response(
    width: int, height: int, count: int = 1
) -> Dict[str, Any]:
    """Builds a fake image segmentation prediction payload.

    Args:
        width: Width forwarded to `make_image_base64` for every entry.
        height: Height forwarded to `make_image_base64` for every entry.
        count: Number of prediction entries to generate.

    Returns:
        A dict whose "predictions" key holds `count` entries, each carrying a
        base64-encoded image string and the "image/png" MIME type.
    """
    return {
        "predictions": [
            {
                "bytesBase64Encoded": make_image_base64(width, height),
                "mimeType": "image/png",
            }
            for _ in range(count)
        ]
    }
201+
202+
176203
def generate_image_from_file(
177204
width: int = 100, height: int = 100
178205
) -> ga_vision_models.Image:
@@ -1018,6 +1045,54 @@ def test_get_image_verification_results(self):
10181045
assert actual_results == [gca_prediction_response, "REJECT"]
10191046

10201047

1048+
@pytest.mark.usefixtures("google_auth_mock")
class ImageSegmentationModelTests:
    """Unit tests for the image segmentation models."""

    # NOTE(review): pytest's default `python_classes` pattern only collects
    # classes whose names start with "Test" (compare
    # `TestMultiModalEmbeddingModels` in this file). Unless the project
    # configuration overrides that pattern, this class is never collected;
    # consider renaming it to `TestImageSegmentationModel`.

    def setup_method(self):
        # Reload so state from other tests does not leak into this one.
        importlib.reload(initializer)
        importlib.reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_get_image_segmentation_results(self):
        """Tests the image segmentation model."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _IMAGE_SEGMENTATION_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = ga_vision_models.ImageSegmentationModel.from_pretrained(
                "image-segmentation-001"
            )
            mock_get_publisher_model.assert_called_once_with(
                name="publishers/google/models/image-segmentation-001",
                retry=base._DEFAULT_RETRY,
            )

        image = generate_image_from_file()
        image_segmentation_response = make_image_segmentation_response(640, 640)
        gca_prediction_response = gca_prediction_service.PredictResponse()
        # `image_segmentation_response["predictions"]` is a list of prediction
        # dicts. `segment_image` calls `.get(...)` on each element of
        # `response.predictions`, so every dict must be its own element:
        # extend with the list rather than appending the list as one item.
        gca_prediction_response.predictions.extend(
            image_segmentation_response["predictions"]
        )

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_prediction_response,
        ):
            segmentation_response = model.segment_image(base_image=image)
            # `ImageSegmentationResponse` defines `__iter__` and `__getitem__`
            # but no `__len__`, so count the masks list directly.
            assert len(segmentation_response.masks) == 1
1094+
1095+
10211096
@pytest.mark.usefixtures("google_auth_mock")
10221097
class TestMultiModalEmbeddingModels:
10231098
"""Unit tests for the image generation models."""

vertexai/preview/vision_models.py

+10
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@
1515
"""Classes for working with vision models."""
1616

1717
from vertexai.vision_models._vision_models import (
18+
EntityLabel,
1819
GeneratedImage,
20+
GeneratedMask,
1921
Image,
2022
ImageCaptioningModel,
2123
ImageGenerationModel,
2224
ImageGenerationResponse,
2325
ImageQnAModel,
26+
ImageSegmentationModel,
27+
ImageSegmentationResponse,
2428
ImageTextModel,
2529
MultiModalEmbeddingModel,
2630
MultiModalEmbeddingResponse,
31+
Scribble,
2732
Video,
2833
VideoEmbedding,
2934
VideoSegmentConfig,
@@ -32,16 +37,21 @@
3237
)
3338

3439
__all__ = [
40+
"EntityLabel",
41+
"GeneratedMask",
3542
"Image",
3643
"ImageGenerationModel",
3744
"ImageGenerationResponse",
3845
"ImageCaptioningModel",
3946
"ImageQnAModel",
47+
"ImageSegmentationModel",
48+
"ImageSegmentationResponse",
4049
"ImageTextModel",
4150
"WatermarkVerificationModel",
4251
"GeneratedImage",
4352
"MultiModalEmbeddingModel",
4453
"MultiModalEmbeddingResponse",
54+
"Scribble",
4555
"Video",
4656
"VideoEmbedding",
4757
"VideoSegmentConfig",

vertexai/vision_models/_vision_models.py

+206
Original file line numberDiff line numberDiff line change
@@ -1398,3 +1398,209 @@ def verify_image(self, image: Image) -> WatermarkVerificationResponse:
13981398
_prediction_response=response,
13991399
watermark_verification_result=verification_likelihood,
14001400
)
1401+
1402+
1403+
class Scribble:
    """Input scribble for image segmentation."""

    __module__ = "vertexai.preview.vision_models"

    _image_: Optional[Image] = None

    def __init__(
        self,
        image_bytes: Optional[bytes] = None,
        gcs_uri: Optional[str] = None,
    ):
        """Creates a `Scribble` object.

        Exactly one of `image_bytes` or `gcs_uri` must be supplied.

        Args:
            image_bytes: Mask image file bytes.
            gcs_uri: Mask image file Google Cloud Storage uri.

        Raises:
            ValueError: If neither or both of `image_bytes` and `gcs_uri` are
                provided.
        """
        has_bytes = bool(image_bytes)
        has_uri = bool(gcs_uri)
        if has_bytes and has_uri:
            # Previously this case reported that nothing had been provided.
            raise ValueError("Only one of image_bytes or gcs_uri may be provided.")
        if not has_bytes and not has_uri:
            raise ValueError("Either image_bytes or gcs_uri must be provided.")

        self._image_ = Image(image_bytes, gcs_uri)

    @property
    def image(self) -> Optional[Image]:
        """The scribble image."""
        return self._image_
1430+
1431+
1432+
@dataclasses.dataclass
class EntityLabel:
    """A text label for an entity, optionally paired with a confidence score.

    Attributes:
        label: The text label; `None` when not set.
        score: The confidence score associated with the label; `None` when
            not set.
    """

    __module__ = "vertexai.preview.vision_models"

    # Both fields are optional and default to None.
    label: Optional[str] = None
    score: Optional[float] = None
1440+
1441+
1442+
class GeneratedMask(Image):
    """Generated image mask."""

    __module__ = "vertexai.preview.vision_models"

    __labels__: Optional[List[EntityLabel]] = None

    def __init__(
        self,
        image_bytes: Optional[bytes] = None,
        gcs_uri: Optional[str] = None,
        labels: Optional[List[EntityLabel]] = None,
    ):
        """Creates a `GeneratedMask` object.

        Args:
            image_bytes: Mask image file bytes. Defaults to `None` so a mask
                that only has a Google Cloud Storage uri can be built with
                keyword arguments alone.
            gcs_uri: Mask image file Google Cloud Storage uri.
            labels: Generated entity labels. Each text label might be associated
                with a confidence score.
        """
        # NOTE(review): any validation of image_bytes / gcs_uri is left to
        # `Image.__init__` — confirm its contract.
        super().__init__(
            image_bytes=image_bytes,
            gcs_uri=gcs_uri,
        )
        self.__labels__ = labels

    @property
    def labels(self) -> Optional[List[EntityLabel]]:
        """The entity labels of the masked object."""
        return self.__labels__
1474+
1475+
1476+
@dataclasses.dataclass
class ImageSegmentationResponse:
    """Result of an image segmentation request.

    Iterating over the response, or indexing into it, operates on `masks`.

    Attributes:
        masks: The list of generated masks.
    """

    __module__ = "vertexai.preview.vision_models"

    # Prediction response object this result was built from.
    _prediction_response: Any
    masks: List[GeneratedMask]

    def __iter__(self) -> typing.Iterator[GeneratedMask]:
        """Yields each generated mask in order."""
        for mask in self.masks:
            yield mask

    def __getitem__(self, idx: int) -> GeneratedMask:
        """Returns the generated mask stored at position `idx`."""
        selected = self.masks[idx]
        return selected
1496+
1497+
1498+
class ImageSegmentationModel(_model_garden_models._ModelGardenModel):
    """Segments an image."""

    __module__ = "vertexai.preview.vision_models"

    _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/image_segmentation_model_1.0.0.yaml"

    @staticmethod
    def _image_to_payload(image: Image) -> dict:
        """Returns the request payload for `image`: its GCS uri if set, else inline base64 bytes."""
        if image._gcs_uri:
            return {"gcsUri": image._gcs_uri}
        return {"bytesBase64Encoded": image._as_base64_string()}

    def segment_image(
        self,
        base_image: Image,
        prompt: Optional[str] = None,
        scribble: Optional[Scribble] = None,
        mode: Literal[
            "foreground", "background", "semantic", "prompt", "interactive"
        ] = "foreground",
        max_predictions: Optional[int] = None,
        confidence_threshold: Optional[float] = 0.1,
        mask_dilation: Optional[float] = None,
    ) -> ImageSegmentationResponse:
        """Segments an image.

        Args:
            base_image: The base image to segment.
            prompt: The prompt to guide the segmentation. Valid for the prompt and
                semantic modes.
            scribble: The scribble in the form of an image mask to guide the
                segmentation. Valid for the interactive mode. The scribble image
                should be a black-and-white PNG file equal in size to the base
                image. White pixels represent the scribbled brush stroke which
                select objects in the base image to segment.
            mode: The segmentation mode. Supported values are:
                * foreground: segment the foreground object of an image
                * background: segment the background of an image
                * semantic: specify the objects to segment with a comma delimited
                  list of objects from the class set in the prompt.
                * prompt: use an open-vocabulary text prompt to select objects to
                  segment.
                * interactive: draw scribbles with a brush stroke to guide the
                  segmentation.
                The default is foreground.
            max_predictions: The maximum number of predictions to make. Valid for
                the prompt mode. Default is unlimited.
            confidence_threshold: A threshold to filter predictions by confidence
                score. The value must be in the range of 0.0 and 1.0. The default is
                0.1. Pass `None` to omit it from the request.
            mask_dilation: A value to dilate the masks by. The value must be in the
                range of 0.0 (no dilation) and 1.0 (the whole image will be masked).
                The default is 0.0.

        Returns:
            An `ImageSegmentationResponse` object with the generated masks,
            entities, and labels (if any).

        Raises:
            ValueError: If `base_image` is missing.
        """
        if not base_image:
            raise ValueError("Base image is required.")

        instance = {"image": self._image_to_payload(base_image)}
        if prompt:
            instance["prompt"] = prompt
        if scribble and scribble.image:
            instance["scribble"] = {"image": self._image_to_payload(scribble.image)}

        parameters = {"mode": mode}
        if max_predictions:
            parameters["maxPredictions"] = max_predictions
        # 0.0 is inside the documented range for both values, so test against
        # None; a truthiness check would silently drop an explicit 0.0.
        if confidence_threshold is not None:
            parameters["confidenceThreshold"] = confidence_threshold
        if mask_dilation is not None:
            parameters["maskDilation"] = mask_dilation

        response = self._endpoint.predict(
            instances=[instance],
            parameters=parameters,
        )

        masks: List[GeneratedMask] = []
        for prediction in response.predictions:
            encoded_bytes = prediction.get("bytesBase64Encoded")
            # A prediction without a "labels" entry yields an empty label list.
            labels = [
                EntityLabel(label=entry.get("label"), score=entry.get("score"))
                for entry in (prediction.get("labels") or [])
            ]
            masks.append(
                GeneratedMask(
                    image_bytes=(
                        base64.b64decode(encoded_bytes) if encoded_bytes else None
                    ),
                    gcs_uri=prediction.get("gcsUri"),
                    labels=labels,
                )
            )

        return ImageSegmentationResponse(
            _prediction_response=response,
            masks=masks,
        )

0 commit comments

Comments
 (0)