Auto Bounding Box Selection with Segment Anything Model (SAM) #185


Merged · 7 commits · Aug 1, 2024
11 changes: 11 additions & 0 deletions README.md
@@ -54,6 +54,7 @@ Annotate Lab is an open-source application designed for image annotation, compri
- [Settings](#settings-documentation-page)
- [Configurations (Optional)](#configurations-optional-documentation-page)
- [Demo](#demo-v20)
- [Auto Bounding Box Selection with Segment Anything Model (SAM)](#auto-bounding-box-selection-with-segment-anything-model-samdocumentation-page)
- [Outputs](#outputs-documentation-page)
- [YOLO Format](#yolo-format-documentation-page)
- [Normalization process of YOLO annotations](#normalization-process-of-yolo-annotations-documentation-page)
@@ -242,6 +243,7 @@ You can customize various aspects of Annotate-Lab through configuration settings
```python
# config.py
MASK_BACKGROUND_COLOR = (0, 0, 0) # Black background for masks
SAM_MODEL_ENABLED = False # Segment Anything Model for Auto Annotation
```

```Javascript
@@ -254,6 +256,7 @@ const config = {
    CIRCLE: 2,
    BOUNDING_BOX: 2
  },
  SAM_MODEL_ENABLED: false, // displays button that allows auto annotation using SAM
  SHOW_CLASS_DISTRIBUTION: true // displays annotated class distribution bar chart
};
```
@@ -265,6 +268,14 @@ const config = {
</a>
</p>

## Auto Bounding Box Selection with Segment Anything Model (SAM)[[documentation page]](https://annotate-docs.dwaste.live/example/auto-bounding-box-selection-with-segment-anything-model-sam)

Automatic selection of bounding boxes is made possible with the [Segment Anything Model (SAM)](https://segment-anything.com/). The feature can be toggled in both the server and client configuration. When enabled, a wand icon appears in the toolbar; clicking it initiates auto-annotation and displays the results.

<p align="center">
  <img src="./sample_images/sam_example.png" alt="auto_annotation" />
</p>
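
Under the hood, the wand button calls the server's new `/get_auto_annotations` endpoint with the name of the currently selected image. Below is a minimal sketch of calling that endpoint directly; the server URL, port, and image name are placeholder assumptions for a default local setup.

```python
import requests

# Hypothetical local server URL; adjust host and port to your deployment.
SERVER_URL = "http://localhost:5000"

# Ask the server to auto-annotate an image that has already been uploaded.
response = requests.post(
    f"{SERVER_URL}/get_auto_annotations",
    json={"image_name": "example.png"},
)
response.raise_for_status()

# The response is a list with one entry per image, each containing
# "image_name", "image_source", and "regions" (normalized bounding boxes).
for annotation in response.json():
    print(annotation["image_name"], len(annotation["regions"]))
```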

## Outputs [[documentation page]](https://annotate-docs.dwaste.live/fundamentals/set-up-and-run/outputs)
A sample annotated image along with its mask and settings is shown below.

19 changes: 19 additions & 0 deletions client/src/Annotator/reducers/general-reducer.js
@@ -896,6 +896,25 @@ export default (state, action) => {
        (activeImage.regions || []).filter((r) => !r.highlighted),
      )
    }
    case "AUTO_ANNOTATE_IMAGE": {
      const { annotations } = action
      // Replace the regions of every image that has a matching auto-annotation response
      const updatedImages = state.images.map((image) => {
        const matchingResponse = annotations.find(
          (item) => item.image_source === decodeURI(image.src),
        )
        if (matchingResponse) {
          return {
            ...image,
            regions: matchingResponse.regions,
          }
        }
        return image
      })
      return setIn(state, ["images"], updatedImages)
    }

    case "HEADER_BUTTON_CLICKED": {
      const buttonName = action.buttonName.toLowerCase()
      switch (buttonName) {
3 changes: 3 additions & 0 deletions client/src/Localization/translation-de-DE.js
@@ -64,6 +64,9 @@ const translationDeDE = {
  helptext_tags: "Tags anzeigen/verstecken",
  helptext_boundingbox: "Begrenzungsrahmen hinzufügen",
  helptext_polypolygon: "Polygon hinzufügen",
  helptext_auto_bounding_box: "Automatische Auswahl des Begrenzungsrahmens",
  auto_bounding_box_processing: "Wird verarbeitet…",
  auto_bounding_box_done: "Auswahl des automatischen Begrenzungsrahmens abgeschlossen",
  helptext_circle: "Kreis hinzufügen",
  comment_placeholder: "Kommentar hier schreiben...",
  image_tags: "Bild-Tags",
3 changes: 3 additions & 0 deletions client/src/Localization/translation-en-EN.js
@@ -64,6 +64,9 @@ const translationEnEN = {
  helptext_boundingbox: "Add Bounding Box",
  helptext_polypolygon: "Add Polygon",
  helptext_circle: "Add Circle",
  helptext_auto_bounding_box: "Auto Bounding Box Selection",
  auto_bounding_box_processing: "Processing...",
  auto_bounding_box_done: "Auto Bounding Box Selection Completed",
  comment_placeholder: "Write comment here...",
  image_tags: "Image Tags",
  image_tags_classification_placeholder: "Image Classification",
4 changes: 4 additions & 0 deletions client/src/MainLayout/icon-dictionary.jsx
@@ -9,6 +9,7 @@ import {
  faEdit,
  faGripLines,
  faHandPaper,
  faWandMagicSparkles,
  faMousePointer,
  faSearch,
  faCircleDot,
@@ -67,6 +68,9 @@ export const iconDictionary = {
  "modify-allowed-area": () => (
    <FontAwesomeIcon style={faStyle} size="xs" fixedWidth icon={faEdit} />
  ),
  "auto-annotate": () => (
    <FontAwesomeIcon style={faStyle} size="xs" fixedWidth icon={faWandMagicSparkles} />
  ),
  "create-keypoints": AccessibilityNewIcon,
}

25 changes: 25 additions & 0 deletions client/src/MainLayout/index.jsx
@@ -26,6 +26,7 @@ import { Save, ExitToApp } from "@mui/icons-material"
import capitalize from "lodash/capitalize"
import { useTranslation } from "react-i18next"
import { useSnackbar } from "../SnackbarContext"
import { getAutoAnnotation } from "../utils/send-data-to-server"
import ClassDistributionSidebarBox from "../ClassDistributionSidebarBox"
import config from "../config"

@@ -166,9 +167,27 @@
    />
  )

  const onAutoAnnotate = useEventCallback(async () => {
    const image = state.images[state.selectedImage]
    const imageName = decodeURIComponent(image.src.split("/").pop())

    showSnackbar(t("auto_bounding_box_processing"), "info")
    try {
      const response = await getAutoAnnotation({ image_name: imageName })
      dispatch({ type: "AUTO_ANNOTATE_IMAGE", annotations: response })

      showSnackbar(t("auto_bounding_box_done"), "success")
    } catch (error) {
      showSnackbar(error.message, "error")
    }
  }, [])

  const onClickIconSidebarItem = useEventCallback((item) => {
    const { selectedTool } = state
    if (selectedTool.length > 0 && item.name !== null) {
      if (item.name === "auto-annotate") {
        onAutoAnnotate()
      }
      dispatch({ type: "SELECT_TOOL", selectedTool: item.name })
    }
  })
@@ -297,6 +316,12 @@
      helperText:
        t("helptext_polypolygon") + getHotkeyHelpText("create_polygon"),
    },
    {
      name: "auto-annotate",
      alwaysShowing: config.SAM_MODEL_ENABLED,
      helperText:
        t("helptext_auto_bounding_box") + getHotkeyHelpText("auto_annotate"),
    },
    {
      name: "create-line",
      helperText: "Add Line",
7 changes: 7 additions & 0 deletions client/src/ShortcutsManager/index.jsx
@@ -106,6 +106,13 @@ export const useDispatchHotkeyHandlers = ({ dispatch }) => {
      selectedTool: "create-polygon",
    })
  },

  auto_annotate: () => {
    dispatch({
      type: "SELECT_TOOL",
      selectedTool: "auto-annotate",
    })
  },
  create_pixel: () => {
    dispatch({
      type: "SELECT_TOOL",
2 changes: 1 addition & 1 deletion client/src/SnackbarContext/index.jsx
@@ -28,7 +28,7 @@ export const SnackbarProvider = ({ children }) => {
      <Snackbar
        key={msg.id + index}
        open
        autoHideDuration={6000}
        autoHideDuration={null}
        anchorOrigin={{ vertical: "bottom", horizontal: "left" }}
        onClose={() => handleClose(index)}
        sx={{ mb: index * 8 }}
1 change: 1 addition & 0 deletions client/src/config.js
@@ -16,6 +16,7 @@ const config = {
    BOUNDING_BOX: 2,
  },
  SHOW_CLASS_DISTRIBUTION: true,
  SAM_MODEL_ENABLED: false
}

export default config
13 changes: 13 additions & 0 deletions client/src/utils/send-data-to-server.js
@@ -43,6 +43,19 @@ export const getImagesAnnotation = (imageData) => {
  })
}

export const getAutoAnnotation = (imageData) => {
  return new Promise((resolve, reject) => {
    axios
      .post(`${config.SERVER_URL}/get_auto_annotations`, imageData)
      .then((response) => {
        resolve(response.data)
      })
      .catch((error) => {
        reject(error.response.data)
      })
  })
}

export const saveSettings = (settings) => {
  return new Promise((resolve, reject) => {
    axios
Binary file added sample_images/sam_example.png
51 changes: 51 additions & 0 deletions server/app.py
@@ -13,10 +13,23 @@
import shutil
import zipfile
import math
from utils import load_image_from_url, format_regions_for_frontend

app = Flask(__name__)
app.config.from_object("config")

if app.config['SAM_MODEL_ENABLED']:
    # Import lazily so the heavy SAM dependencies are only loaded when the feature is enabled
    from sam_model import SamModel

    # URL of the SAM checkpoint to download
    model_type = 'vit_h'
    model_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
    model_path = 'sam_model.pth'  # Path to save the model

    sam_model = SamModel(model_url, model_path, model_type)
else:
    sam_model = None

# Get the CLIENT_URL environment variable, set a default to 80
client_url = os.getenv("CLIENT_URL", "http://localhost")
@@ -147,6 +160,43 @@ def update_settings():
    return jsonify({"message": "Settings updated successfully"})


@app.route("/get_auto_annotations", methods=["POST"])
@cross_origin(origin="*", headers=["Content-Type"])
def get_auto_annotations():
    try:
        data = request.get_json()
        image_name = data.get("image_name")
        if not image_name:
            raise ValueError("Invalid JSON data format: 'image_name' not found.")

        image_annotations = []
        base_url = request.host_url + "uploads/"
        image_url = base_url + image_name
        image = load_image_from_url(image_url)
        regions = sam_model.predict(image)
        formatted_regions = format_regions_for_frontend(
            regions, image_url, image.shape[1], image.shape[0]
        )
        image_annotations.append(
            {
                "image_name": image_name,
                "image_source": image_url,
                "regions": formatted_regions,
            }
        )
        return jsonify(image_annotations), 200

    except ValueError as ve:
        print("ValueError:", ve)
        traceback.print_exc()
        return jsonify({"error": str(ve)}), 400
    except requests.exceptions.RequestException as re:
        print("RequestException:", re)
        traceback.print_exc()
        return jsonify({"error": "Error fetching image from URL"}), 500
    except Exception as e:
        print("General error:", e)
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

@app.route("/settings/reset", methods=["POST"])
@cross_origin(origin=client_url, headers=["Content-Type"])
def reset_settings():
@@ -1196,4 +1246,5 @@ def main():

# If the file is run directly, start the app.
if __name__ == "__main__":
    print("Starting server...")
    app.run(debug=False)
1 change: 1 addition & 0 deletions server/config.py
@@ -1 +1,2 @@
MASK_BACKGROUND_COLOR = (0, 0, 0)
SAM_MODEL_ENABLED = False
4 changes: 4 additions & 0 deletions server/requirements.txt
@@ -1,4 +1,8 @@
git+https://github.com/facebookresearch/segment-anything.git
torch
torchvision
opencv-python
supervision
flask-cors
pandas
pillow
Expand Down
55 changes: 55 additions & 0 deletions server/sam_model.py
@@ -0,0 +1,55 @@
import torch
import os
import requests
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import numpy as np
import supervision as sv
import cv2

class SamModel:
    def __init__(self, model_url, model_path, model_type):
        self.model_path = model_path
        self.model_type = model_type
        if not os.path.exists(self.model_path):
            print(f"Downloading model from {model_url}")
            self.download_model(model_url, self.model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = sam_model_registry[model_type](checkpoint=self.model_path)
        self.model.to(self.device)
        self.model.eval()

    def download_model(self, url, path):
        # Stream the checkpoint to disk so the whole file is never held in memory
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Ensure we notice bad responses
        with open(path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
        print(f"Model downloaded to {path}")

    def load_model(self, model_path):
        # Build a SAM model of the configured type from a checkpoint file
        model = sam_model_registry[self.model_type](checkpoint=model_path)
        model.to(self.device)
        return model

    def get_annotations(self, masks):
        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=masks),
            mask=masks
        )
        # Keep only the largest detected region
        detections = detections[detections.area == np.max(detections.area)]
        return detections

    def predict(self, image):
        mask_generator = SamAutomaticMaskGenerator(
            self.model,
            points_per_side=24,
            pred_iou_thresh=0.9,
            stability_score_thresh=0.95,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=200)

        # OpenCV loads images as BGR; SAM expects RGB
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        sam_result = mask_generator.generate(image_rgb)
        detections = sv.Detections.from_sam(sam_result=sam_result)
        return detections

31 changes: 31 additions & 0 deletions server/utils.py
@@ -0,0 +1,31 @@
import requests
import cv2
import numpy as np
import uuid

def load_image_from_url(url):
    response = requests.get(url)
    response.raise_for_status()  # Fail early if the image cannot be fetched
    image = np.asarray(bytearray(response.content), dtype="uint8")
    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
    return image

def format_regions_for_frontend(detections, image_src, image_width, image_height):
    # Convert detections into the region format expected by the client,
    # with coordinates normalized to the 0-1 range.
    formatted_regions = []
    for i in range(len(detections.xyxy)):
        x1, y1, x2, y2 = detections.xyxy[i]
        region = {
            "cls": "",
            "comment": "",
            "color": "#f44336",
            # Cast to plain floats so the values are JSON serializable
            "h": float((y2 - y1) / image_height),
            "id": str(uuid.uuid4()),  # Generate a unique, JSON-serializable ID
            "image-src": image_src,
            "tags": "",
            "type": "box",
            "w": float((x2 - x1) / image_width),
            "x": float(x1 / image_width),
            "y": float(y1 / image_height)
        }
        formatted_regions.append(region)
    return formatted_regions
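
To make the normalization concrete, here is a small sketch of `format_regions_for_frontend` applied to a single detection. It uses a lightweight stand-in for the supervision `Detections` object, since only its `xyxy` attribute is read, and assumes it is run from the `server` directory so `utils` is importable.

```python
from types import SimpleNamespace

from utils import format_regions_for_frontend

# Stand-in detection: one box from (100, 50) to (300, 250) on a 640x480 image.
detections = SimpleNamespace(xyxy=[(100.0, 50.0, 300.0, 250.0)])

regions = format_regions_for_frontend(
    detections, "http://localhost/uploads/example.png", 640, 480
)

# Yields one region of type "box" with x=0.15625, y~0.104, w=0.3125, h~0.417.
print(regions[0])
```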