Skip to content

Commit 9cbd510

Browse files
bsekachevAlx-Wo
andauthored
Added Segment Anything interactor for GPU/CPU (#6008)
Idea of the PR is to finish this one #5990 Deploy for GPU: ``./deploy_gpu.sh pytorch/facebookresearch/sam/nuclio/`` Deploy for CPU: ``./deploy_cpu.sh pytorch/facebookresearch/sam/nuclio/`` If you want to use GPU, be sure you setup docker for this [guide](https://github.com/NVIDIA/nvidia-docker/blob/master/README.md#quickstart). Resolved issue #5984 But the interface probably can be improved Co-authored-by: Alx-Wo <[email protected]>
1 parent 6852cae commit 9cbd510

File tree

6 files changed

+242
-0
lines changed

6 files changed

+242
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## \[2.5.0] - Unreleased
99
### Added
1010
- Add support for Azure Blob Storage connection string authentication(<https://github.com/openvinotoolkit/cvat/pull/4649>)
11+
- Added Segment Anything interactor for CPU/GPU (<https://github.com/opencv/cvat/pull/6008>)
1112

1213
### Changed
1314
- Moving a task from a project to another project is disabled (<https://github.com/opencv/cvat/pull/5901>)

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ up to 10x. Here is a list of the algorithms we support, and the platforms they c
184184

185185
| Name | Type | Framework | CPU | GPU |
186186
| ------------------------------------------------------------------------------------------------------- | ---------- | ---------- | --- | --- |
187+
| [Segment Anything](/serverless/pytorch/facebookresearch/sam/nuclio/) | interactor | PyTorch | ✔️ | ✔️ |
187188
| [Deep Extreme Cut](/serverless/openvino/dextr/nuclio) | interactor | OpenVINO | ✔️ | |
188189
| [Faster RCNN](/serverless/openvino/omz/public/faster_rcnn_inception_v2_coco/nuclio) | detector | OpenVINO | ✔️ | |
189190
| [Mask RCNN](/serverless/openvino/omz/public/mask_rcnn_inception_resnet_v2_atrous_coco/nuclio) | detector | OpenVINO | ✔️ | |
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (C) 2023 CVAT.ai Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
metadata:
6+
name: pth.facebookresearch.sam.vit_h
7+
namespace: cvat
8+
annotations:
9+
name: Segment Anything
10+
version: 2
11+
type: interactor
12+
spec:
13+
framework: pytorch
14+
min_pos_points: 1
15+
min_neg_points: 0
16+
animated_gif: https://raw.githubusercontent.com/opencv/cvat/develop/site/content/en/images/hrnet_example.gif
17+
help_message: The interactor allows to get a mask of an object using at least one positive, and any negative points inside it
18+
19+
spec:
20+
description: Interactive object segmentation with Segment-Anything
21+
runtime: 'python:3.8'
22+
handler: main:handler
23+
eventTimeout: 30s
24+
env:
25+
- name: PYTHONPATH
26+
value: /opt/nuclio/sam
27+
28+
build:
29+
image: cvat.pth.facebookresearch.sam.vit_h
30+
baseImage: ubuntu:22.04
31+
32+
directives:
33+
preCopy:
34+
# disable interactive frontend
35+
- kind: ENV
36+
value: DEBIAN_FRONTEND=noninteractive
37+
# set workdir
38+
- kind: WORKDIR
39+
value: /opt/nuclio/sam
40+
# install basic deps
41+
- kind: RUN
42+
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
43+
# install sam deps
44+
- kind: RUN
45+
value: pip3 install torch torchvision torchaudio opencv-python pycocotools matplotlib onnxruntime onnx
46+
# install sam code
47+
- kind: RUN
48+
value: pip3 install git+https://github.com/facebookresearch/segment-anything.git
49+
# download sam weights
50+
- kind: RUN
51+
value: curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
52+
# map pip3 and python3 to pip and python
53+
- kind: RUN
54+
value: ln -s /usr/bin/pip3 /usr/local/bin/pip && ln -s /usr/bin/python3 /usr/bin/python
55+
triggers:
56+
myHttpTrigger:
57+
maxWorkers: 1
58+
kind: 'http'
59+
workerAvailabilityTimeoutMilliseconds: 10000
60+
attributes:
61+
maxRequestBodySize: 33554432 # 32MB
62+
resources:
63+
limits:
64+
nvidia.com/gpu: 1
65+
66+
platform:
67+
attributes:
68+
restartPolicy:
69+
name: always
70+
maximumRetryCount: 3
71+
mountMode: volume
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (C) 2023 CVAT.ai Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
metadata:
6+
name: pth.facebookresearch.sam.vit_h
7+
namespace: cvat
8+
annotations:
9+
name: Segment Anything
10+
version: 2
11+
type: interactor
12+
spec:
13+
framework: pytorch
14+
min_pos_points: 1
15+
min_neg_points: 0
16+
animated_gif: https://raw.githubusercontent.com/opencv/cvat/develop/site/content/en/images/hrnet_example.gif
17+
help_message: The interactor allows to get a mask of an object using at least one positive, and any negative points inside it
18+
19+
spec:
20+
description: Interactive object segmentation with Segment-Anything
21+
runtime: 'python:3.8'
22+
handler: main:handler
23+
eventTimeout: 30s
24+
env:
25+
- name: PYTHONPATH
26+
value: /opt/nuclio/sam
27+
28+
build:
29+
image: cvat.pth.facebookresearch.sam.vit_h
30+
baseImage: ubuntu:22.04
31+
32+
directives:
33+
preCopy:
34+
# disable interactive frontend
35+
- kind: ENV
36+
value: DEBIAN_FRONTEND=noninteractive
37+
# set workdir
38+
- kind: WORKDIR
39+
value: /opt/nuclio/sam
40+
# install basic deps
41+
- kind: RUN
42+
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
43+
# install sam deps
44+
- kind: RUN
45+
value: pip3 install torch torchvision torchaudio opencv-python pycocotools matplotlib onnxruntime onnx
46+
# install sam code
47+
- kind: RUN
48+
value: pip3 install git+https://github.com/facebookresearch/segment-anything.git
49+
# download sam weights
50+
- kind: RUN
51+
value: curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
52+
# map pip3 and python3 to pip and python
53+
- kind: RUN
54+
value: ln -s /usr/bin/pip3 /usr/local/bin/pip && ln -s /usr/bin/python3 /usr/bin/python
55+
triggers:
56+
myHttpTrigger:
57+
maxWorkers: 2
58+
kind: 'http'
59+
workerAvailabilityTimeoutMilliseconds: 10000
60+
attributes:
61+
maxRequestBodySize: 33554432 # 32MB
62+
63+
platform:
64+
attributes:
65+
restartPolicy:
66+
name: always
67+
maximumRetryCount: 3
68+
mountMode: volume
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (C) 2023 CVAT.ai Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
import json
6+
import base64
7+
from PIL import Image
8+
import io
9+
from model_handler import ModelHandler
10+
11+
def init_context(context):
12+
context.logger.info("Init context... 0%")
13+
model = ModelHandler()
14+
context.user_data.model = model
15+
context.logger.info("Init context...100%")
16+
17+
def handler(context, event):
18+
context.logger.info("call handler")
19+
data = event.body
20+
pos_points = data["pos_points"]
21+
neg_points = data["neg_points"]
22+
buf = io.BytesIO(base64.b64decode(data["image"]))
23+
image = Image.open(buf)
24+
image = image.convert("RGB") # to make sure image comes in RGB
25+
mask, polygon = context.user_data.model.handle(image, pos_points, neg_points)
26+
return context.Response(body=json.dumps({
27+
'points': polygon,
28+
'mask': mask.tolist(),
29+
}),
30+
headers={},
31+
content_type='application/json',
32+
status_code=200
33+
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (C) 2023 CVAT.ai Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
import numpy as np
6+
import cv2
7+
import torch
8+
from segment_anything import sam_model_registry, SamPredictor
9+
10+
def convert_mask_to_polygon(mask):
11+
contours = None
12+
if int(cv2.__version__.split('.')[0]) > 3:
13+
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[0]
14+
else:
15+
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[1]
16+
17+
contours = max(contours, key=lambda arr: arr.size)
18+
if contours.shape.count(1):
19+
contours = np.squeeze(contours)
20+
if contours.size < 3 * 2:
21+
raise Exception('Less then three point have been detected. Can not build a polygon.')
22+
23+
polygon = []
24+
for point in contours:
25+
polygon.append([int(point[0]), int(point[1])])
26+
27+
return polygon
28+
29+
class ModelHandler:
30+
def __init__(self):
31+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32+
self.sam_checkpoint = "/opt/nuclio/sam/sam_vit_h_4b8939.pth"
33+
self.model_type = "vit_h"
34+
self.latest_image = None
35+
self.latest_low_res_masks = None
36+
sam_model = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
37+
sam_model.to(device=self.device)
38+
self.predictor = SamPredictor(sam_model)
39+
40+
def handle(self, image, pos_points, neg_points):
41+
# latest image is kept in memory because function is always run-time after startup
42+
# we use to avoid computing emeddings twice for the same image
43+
is_the_same_image = self.latest_image is not None and np.array_equal(np.array(image), self.latest_image)
44+
if not is_the_same_image:
45+
self.latest_low_res_masks = None
46+
numpy_image = np.array(image)
47+
self.predictor.set_image(numpy_image)
48+
self.latest_image = numpy_image
49+
# we assume that pos_points and neg_points are of type:
50+
# np.array[[x, y], [x, y], ...]
51+
input_points = np.array(pos_points)
52+
input_labels = np.array([1] * len(pos_points))
53+
54+
if len(neg_points):
55+
input_points = np.concatenate([input_points, neg_points], axis=0)
56+
input_labels = np.concatenate([input_labels, np.array([0] * len(neg_points))], axis=0)
57+
58+
masks, _, low_res_masks = self.predictor.predict(
59+
point_coords=input_points,
60+
point_labels=input_labels,
61+
mask_input = self.latest_low_res_masks,
62+
multimask_output=False
63+
)
64+
self.latest_low_res_masks = low_res_masks
65+
object_mask = np.array(masks[0], dtype=np.uint8)
66+
cv2.normalize(object_mask, object_mask, 0, 255, cv2.NORM_MINMAX)
67+
polygon = convert_mask_to_polygon(object_mask)
68+
return object_mask, polygon

0 commit comments

Comments
 (0)