@@ -1398,3 +1398,209 @@ def verify_image(self, image: Image) -> WatermarkVerificationResponse:
1398
1398
_prediction_response = response ,
1399
1399
watermark_verification_result = verification_likelihood ,
1400
1400
)
1401
+
1402
+
1403
+ class Scribble :
1404
+ """Input scribble for image segmentation."""
1405
+
1406
+ __module__ = "vertexai.preview.vision_models"
1407
+
1408
+ _image_ : Optional [Image ] = None
1409
+
1410
+ def __init__ (
1411
+ self ,
1412
+ image_bytes : Optional [bytes ],
1413
+ gcs_uri : Optional [str ] = None ,
1414
+ ):
1415
+ """Creates a `Scribble` object.
1416
+
1417
+ Args:
1418
+ image_bytes: Mask image file bytes.
1419
+ gcs_uri: Mask image file Google Cloud Storage uri.
1420
+ """
1421
+ if bool (image_bytes ) == bool (gcs_uri ):
1422
+ raise ValueError ("Either image_bytes or gcs_uri must be provided." )
1423
+
1424
+ self ._image_ = Image (image_bytes , gcs_uri )
1425
+
1426
+ @property
1427
+ def image (self ) -> Optional [Image ]:
1428
+ """The scribble image."""
1429
+ return self ._image_
1430
+
1431
+
1432
+ @dataclasses .dataclass
1433
+ class EntityLabel :
1434
+ """Entity label holding a text label and any associated confidence score."""
1435
+
1436
+ __module__ = "vertexai.preview.vision_models"
1437
+
1438
+ label : Optional [str ] = None
1439
+ score : Optional [float ] = None
1440
+
1441
+
1442
+ class GeneratedMask (Image ):
1443
+ """Generated image mask."""
1444
+
1445
+ __module__ = "vertexai.preview.vision_models"
1446
+
1447
+ __labels__ : Optional [List [EntityLabel ]] = None
1448
+
1449
+ def __init__ (
1450
+ self ,
1451
+ image_bytes : Optional [bytes ],
1452
+ gcs_uri : Optional [str ] = None ,
1453
+ labels : Optional [List [EntityLabel ]] = None ,
1454
+ ):
1455
+ """Creates a `GeneratedMask` object.
1456
+
1457
+ Args:
1458
+ image_bytes: Mask image file bytes.
1459
+ gcs_uri: Mask image file Google Cloud Storage uri.
1460
+ labels: Generated entity labels. Each text label might be associated
1461
+ with a confidence score.
1462
+ """
1463
+
1464
+ super ().__init__ (
1465
+ image_bytes = image_bytes ,
1466
+ gcs_uri = gcs_uri ,
1467
+ )
1468
+ self .__labels__ = labels
1469
+
1470
+ @property
1471
+ def labels (self ) -> Optional [List [EntityLabel ]]:
1472
+ """The entity labels of the masked object."""
1473
+ return self .__labels__
1474
+
1475
+
1476
+ @dataclasses .dataclass
1477
+ class ImageSegmentationResponse :
1478
+ """Image Segmentation response.
1479
+
1480
+ Attributes:
1481
+ masks: The list of generated masks.
1482
+ """
1483
+
1484
+ __module__ = "vertexai.preview.vision_models"
1485
+
1486
+ _prediction_response : Any
1487
+ masks : List [GeneratedMask ]
1488
+
1489
+ def __iter__ (self ) -> typing .Iterator [GeneratedMask ]:
1490
+ """Iterates through the generated masks."""
1491
+ yield from self .masks
1492
+
1493
+ def __getitem__ (self , idx : int ) -> GeneratedMask :
1494
+ """Gets the generated masks by index."""
1495
+ return self .masks [idx ]
1496
+
1497
+
1498
+ class ImageSegmentationModel (_model_garden_models ._ModelGardenModel ):
1499
+ """Segments an image."""
1500
+
1501
+ __module__ = "vertexai.preview.vision_models"
1502
+
1503
+ _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/image_segmentation_model_1.0.0.yaml"
1504
+
1505
+ def segment_image (
1506
+ self ,
1507
+ base_image : Image ,
1508
+ prompt : Optional [str ] = None ,
1509
+ scribble : Optional [Scribble ] = None ,
1510
+ mode : Literal [
1511
+ "foreground" , "background" , "semantic" , "prompt" , "interactive"
1512
+ ] = "foreground" ,
1513
+ max_predictions : Optional [int ] = None ,
1514
+ confidence_threshold : Optional [float ] = 0.1 ,
1515
+ mask_dilation : Optional [float ] = None ,
1516
+ ) -> ImageSegmentationResponse :
1517
+ """Segments an image.
1518
+
1519
+ Args:
1520
+ base_image: The base image to segment.
1521
+ prompt: The prompt to guide the segmentation. Valid for the prompt and
1522
+ semantic modes.
1523
+ scribble: The scribble in the form of an image mask to guide the
1524
+ segmentation. Valid for the interactive mode. The scribble image
1525
+ should be a black-and-white PNG file equal in size to the base
1526
+ image. White pixels represent the scribbled brush stroke which
1527
+ select objects in the base image to segment.
1528
+ mode: The segmentation mode. Supported values are:
1529
+ * foreground: segment the foreground object of an image
1530
+ * background: segment the background of an image
1531
+ * semantic: specify the objects to segment with a comma delimited
1532
+ list of objects from the class set in the prompt.
1533
+ * prompt: use an open-vocabulary text prompt to select objects to
1534
+ segment.
1535
+ * interactive: draw scribbles with a brush stroke to guide the
1536
+ segmentation. The default is foreground.
1537
+ max_predictions: The maximum number of predictions to make. Valid for
1538
+ the prompt mode. Default is unlimited.
1539
+ confidence_threshold: A threshold to filter predictions by confidence
1540
+ score. The value must be in the range of 0.0 and 1.0. The default is
1541
+ 0.1.
1542
+ mask_dilation: A value to dilate the masks by. The value must be in the
1543
+ range of 0.0 (no dilation) and 1.0 (the whole image will be masked).
1544
+ The default is 0.0.
1545
+
1546
+ Returns:
1547
+ An `ImageSegmentationResponse` object with the generated masks,
1548
+ entities, and labels (if any).
1549
+ """
1550
+ if not base_image :
1551
+ raise ValueError ("Base image is required." )
1552
+ instance = {}
1553
+
1554
+ if base_image ._gcs_uri :
1555
+ instance ["image" ] = {"gcsUri" : base_image ._gcs_uri }
1556
+ else :
1557
+ instance ["image" ] = {"bytesBase64Encoded" : base_image ._as_base64_string ()}
1558
+
1559
+ if prompt :
1560
+ instance ["prompt" ] = prompt
1561
+
1562
+ parameters = {}
1563
+ if scribble and scribble .image :
1564
+ scribble_image = scribble .image
1565
+ if scribble_image ._gcs_uri :
1566
+ instance ["scribble" ] = {"image" : {"gcsUri" : scribble_image ._gcs_uri }}
1567
+ else :
1568
+ instance ["scribble" ] = {
1569
+ "image" : {"bytesBase64Encoded" : scribble_image ._as_base64_string ()}
1570
+ }
1571
+ parameters ["mode" ] = mode
1572
+ if max_predictions :
1573
+ parameters ["maxPredictions" ] = max_predictions
1574
+ if confidence_threshold :
1575
+ parameters ["confidenceThreshold" ] = confidence_threshold
1576
+ if mask_dilation :
1577
+ parameters ["maskDilation" ] = mask_dilation
1578
+
1579
+ response = self ._endpoint .predict (
1580
+ instances = [instance ],
1581
+ parameters = parameters ,
1582
+ )
1583
+
1584
+ masks : List [GeneratedMask ] = []
1585
+ for prediction in response .predictions :
1586
+ encoded_bytes = prediction .get ("bytesBase64Encoded" )
1587
+ labels = []
1588
+ if "labels" in prediction :
1589
+ for label in prediction ["labels" ]:
1590
+ labels .append (
1591
+ EntityLabel (
1592
+ label = label .get ("label" ),
1593
+ score = label .get ("score" ),
1594
+ )
1595
+ )
1596
+ generated_image = GeneratedMask (
1597
+ image_bytes = base64 .b64decode (encoded_bytes ) if encoded_bytes else None ,
1598
+ gcs_uri = prediction .get ("gcsUri" ),
1599
+ labels = labels ,
1600
+ )
1601
+ masks .append (generated_image )
1602
+
1603
+ return ImageSegmentationResponse (
1604
+ _prediction_response = response ,
1605
+ masks = masks ,
1606
+ )
0 commit comments