Skip to content

[WIP] Update RKDE #2441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions configs/model/rkde.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,3 @@ model:
n_pca_components: 16
feature_scaling_method: SCALE
max_training_points: 40000

task: detection
33 changes: 13 additions & 20 deletions src/anomalib/models/image/rkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from torchvision.transforms.v2 import Compose, Resize, Transform

from anomalib import LearningType
from anomalib.data import Batch, InferenceBatch
from anomalib.metrics import Evaluator
from anomalib.models.components import AnomalibModule, MemoryBankMixin
from anomalib.models.components.classification import FeatureScalingMethod
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor

from .region_extractor import RoiStage
from .torch_model import RkdeModel

logger = logging.getLogger(__name__)
Expand All @@ -30,11 +30,9 @@ class Rkde(MemoryBankMixin, AnomalibModule):
"""Region Based Anomaly Detection With Real-Time Training and Analysis.

Args:
roi_stage (RoiStage, optional): Processing stage from which rois are extracted.
Defaults to ``RoiStage.RCNN``.
roi_score_threshold (float, optional): Mimumum confidence score for the region proposals.
roi_score_threshold (float, optional): Minimum confidence score for the region proposals.
Defaults to ``0.001``.
min_size (int, optional): Minimum size in pixels for the region proposals.
min_box_size (int, optional): Minimum size in pixels for the region proposals.
Defaults to ``25``.
iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS.
Defaults to ``0.3``.
Expand All @@ -55,7 +53,6 @@ class Rkde(MemoryBankMixin, AnomalibModule):

def __init__(
self,
roi_stage: RoiStage = RoiStage.RCNN,
roi_score_threshold: float = 0.001,
min_box_size: int = 25,
iou_threshold: float = 0.3,
Expand All @@ -70,7 +67,6 @@ def __init__(
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)

self.model: RkdeModel = RkdeModel(
roi_stage=roi_stage,
roi_score_threshold=roi_score_threshold,
min_box_size=min_box_size,
iou_threshold=iou_threshold,
Expand All @@ -86,11 +82,11 @@ def configure_optimizers() -> None:
"""RKDE doesn't require optimization, therefore returns no optimizers."""
return

def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None:
def training_step(self, batch: Batch, *args, **kwargs) -> None:
"""Perform a training Step of RKDE. For each batch, features are extracted from the CNN.

Args:
batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask
batch (Batch): Batch containing image filename, image, label and mask
args: Additional arguments.
kwargs: Additional keyword arguments.

Expand All @@ -99,7 +95,7 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
"""
del args, kwargs # These variables are not used.

features = self.model(batch["image"])
features = self.model(batch.image)
self.embeddings.append(features)

def fit(self) -> None:
Expand All @@ -109,13 +105,13 @@ def fit(self) -> None:
logger.info("Fitting a KDE model to the embedding collected from the training set.")
self.model.fit(embeddings)

def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Perform a validation Step of RKde.

Similar to the training step, features are extracted from the CNN for each batch.

Args:
batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask
batch (Batch): Batch containing image filename, image, label and mask
args: Additional arguments.
kwargs: Additional keyword arguments.

Expand All @@ -125,15 +121,12 @@ def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs)
del args, kwargs # These variables are not used.

# get batched model predictions
boxes, scores = self.model(batch["image"])

# convert batched predictions to list format
image: torch.Tensor = batch["image"]
batch_size = image.shape[0]
indices = boxes[:, 0]
batch["pred_boxes"] = [boxes[indices == i, 1:] for i in range(batch_size)]
batch["box_scores"] = [scores[indices == i] for i in range(batch_size)]
predictions: InferenceBatch = self.model(batch.image)

batch.update(
pred_score=predictions.pred_score,
anomaly_map=predictions.anomaly_map,
)
return batch

@property
Expand Down
116 changes: 42 additions & 74 deletions src/anomalib/models/image/rkde/region_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,20 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from enum import Enum

import torch
from torch import nn
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.ops import boxes as box_ops

from anomalib.data.utils.boxes import scale_boxes


class RoiStage(str, Enum):
"""Processing stage from which rois are extracted."""

RCNN = "rcnn"
RPN = "rpn"
from torchvision.models.detection import (
MaskRCNN_ResNet50_FPN_V2_Weights,
maskrcnn_resnet50_fpn_v2,
)
from torchvision.ops import box_area, nms


class RegionExtractor(nn.Module):
"""Extracts regions from the image.

Args:
stage (RoiStage, optional): Processing stage from which rois are extracted.
Defaults to ``RoiStage.RCNN``.
score_threshold (float, optional): Mimumum confidence score for the region proposals.
score_threshold (float, optional): Minimum confidence score for the region proposals.
Defaults to ``0.001``.
min_size (int, optional): Minimum size in pixels for the region proposals.
Defaults to ``25``.
Expand All @@ -41,30 +31,24 @@ class RegionExtractor(nn.Module):

def __init__(
self,
stage: RoiStage = RoiStage.RCNN,
score_threshold: float = 0.001,
min_size: int = 25,
iou_threshold: float = 0.3,
max_detections_per_image: int = 100,
) -> None:
super().__init__()

# Affects global behaviour of the region extractor
self.stage = stage
self.min_size = min_size
self.iou_threshold = iou_threshold
self.max_detections_per_image = max_detections_per_image

# Affects behaviour depending on roi stage
rpn_top_n = max_detections_per_image if self.stage == RoiStage.RPN else 1000
rpn_score_thresh = score_threshold if self.stage == RoiStage.RPN else 0.0

# Create the model
self.faster_rcnn = fasterrcnn_resnet50_fpn(
pretrained=True,
rpn_post_nms_top_n_test=rpn_top_n,
rpn_score_thresh=rpn_score_thresh,

self.backbone = maskrcnn_resnet50_fpn_v2(
weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT,
rpn_post_nms_top_n_test=1000,
rpn_score_thresh=0.0,
box_score_thresh=score_threshold,
rpn_nms_thresh=0.3,
box_nms_thresh=1.0, # this disables nms (we apply custom label-agnostic nms during post-processing)
box_detections_per_img=1000, # this disables filtering top-k predictions (we apply our own after nms)
)
Expand All @@ -80,66 +64,50 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
ValueError: When ``stage`` is not one of ``rcnn`` or ``rpn``.

Returns:
Tensor: Predicted regions, tensor of shape [N, 5] where N is the number of predicted regions in the batch,
and where each row describes the index of the image in the batch and the 4 bounding box coordinates.
List of dictionaries containing the processed box predictions.
The dictionary has the keys ``boxes``, ``masks``, ``scores``, and ``labels``.
"""
if self.training:
msg = "Should not be in training mode"
raise ValueError(msg)

if self.stage == RoiStage.RCNN:
# get rois from rcnn output
predictions = self.faster_rcnn(batch)
all_regions = [prediction["boxes"] for prediction in predictions]
all_scores = [prediction["scores"] for prediction in predictions]
elif self.stage == RoiStage.RPN:
# get rois from region proposal network
images, _ = self.faster_rcnn.transform(batch)
features = self.faster_rcnn.backbone(images.tensors)
proposals, _ = self.faster_rcnn.rpn(images, features)
# post-process raw rpn predictions
all_regions = [box_ops.clip_boxes_to_image(boxes, images.tensors.shape[-2:]) for boxes in proposals]
all_regions = [scale_boxes(boxes, images.tensors.shape[-2:], batch.shape[-2:]) for boxes in all_regions]
all_scores = [torch.ones(boxes.shape[0]).to(boxes.device) for boxes in all_regions]
else:
msg = f"Unknown region extractor stage: {self.stage}"
raise ValueError(msg)
regions: list[dict[str, torch.Tensor]] = self.backbone(batch)
return self.post_process_box_predictions(regions)

regions = self.post_process_box_predictions(all_regions, all_scores)

# convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2]
indices = torch.repeat_interleave(
torch.arange(len(regions)),
torch.Tensor([rois.shape[0] for rois in regions]).int(),
)
return torch.cat([indices.unsqueeze(1).to(batch.device), torch.cat(regions)], dim=1)

def post_process_box_predictions(self, pred_boxes: torch.Tensor, pred_scores: torch.Tensor) -> list[torch.Tensor]:
def post_process_box_predictions(self, regions: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]:
"""Post-processes the box predictions.

The post-processing consists of removing small boxes, applying nms, and
keeping only the k boxes with the highest confidence score.

Args:
pred_boxes (torch.Tensor): Box predictions of shape (N, 4).
pred_scores (torch.Tensor): torch.Tensor of shape () with a confidence score for each box prediction.
regions (list[dict[str, torch.Tensor]]): List of dictionaries containing the box predictions.

Returns:
list[torch.Tensor]: Post-processed box predictions of shape (N, 4).
list[dict[str, torch.Tensor]]: List of dictionaries containing the processed box predictions.
"""
processed_boxes_list: list[torch.Tensor] = []
for boxes, scores in zip(pred_boxes, pred_scores, strict=True):
new_regions = []
for _region in regions:
boxes = _region["boxes"]
masks = _region["masks"]
scores = _region["scores"]
labels = _region["labels"]
# remove small boxes
keep = box_ops.remove_small_boxes(boxes, min_size=self.min_size)
processed_boxes, processed_scores = boxes[keep], scores[keep]

# non-maximum suppression, all boxes together
keep = box_ops.nms(processed_boxes, processed_scores, self.iou_threshold)

# keep only top-k scoring predictions
keep = torch.where(box_area(boxes) > self.min_size)
boxes = boxes[keep]
masks = masks[keep]
scores = scores[keep]
labels = labels[keep]
# # non-maximum suppression
keep = nms(boxes, scores, self.iou_threshold)
# # keep only top-k scoring predictions
keep = keep[: self.max_detections_per_image]
processed_boxes = processed_boxes[keep]

processed_boxes_list.append(processed_boxes)

return processed_boxes_list
processed_boxes = {
"boxes": boxes[keep],
"masks": masks[keep],
"scores": scores[keep],
"labels": labels[keep],
}
new_regions.append(processed_boxes)

return new_regions
50 changes: 39 additions & 11 deletions src/anomalib/models/image/rkde/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import torch
from torch import nn

from anomalib.data import InferenceBatch
from anomalib.models.components.classification import FeatureScalingMethod, KDEClassifier

from .feature_extractor import FeatureExtractor
from .region_extractor import RegionExtractor, RoiStage
from .region_extractor import RegionExtractor

logger = logging.getLogger(__name__)

Expand All @@ -20,11 +21,9 @@ class RkdeModel(nn.Module):
"""Torch Model for the Region-based Anomaly Detection Model.

Args:
roi_stage (RoiStage, optional): Processing stage from which rois are extracted.
Defaults to ``RoiStage.RCNN``.
roi_score_threshold (float, optional): Mimumum confidence score for the region proposals.
roi_score_threshold (float, optional): Minimum confidence score for the region proposals.
Defaults to ``0.001``.
min_size (int, optional): Minimum size in pixels for the region proposals.
min_box_size (int, optional): Minimum size in pixels for the region proposals.
Defaults to ``25``.
iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS.
Defaults to ``0.3``.
Expand All @@ -43,7 +42,6 @@ class RkdeModel(nn.Module):
def __init__(
self,
# roi params
roi_stage: RoiStage = RoiStage.RCNN,
roi_score_threshold: float = 0.001,
min_box_size: int = 25,
iou_threshold: float = 0.3,
Expand All @@ -56,7 +54,6 @@ def __init__(
super().__init__()

self.region_extractor = RegionExtractor(
stage=roi_stage,
score_threshold=roi_score_threshold,
min_size=min_box_size,
iou_threshold=iou_threshold,
Expand All @@ -82,7 +79,7 @@ def fit(self, embeddings: torch.Tensor) -> bool:
"""
return self.classifier.fit(embeddings)

def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
"""Prediction by normality model.

Args:
Expand All @@ -96,19 +93,50 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, tor
self.feature_extractor.eval()

# 1. apply region extraction
rois = self.region_extractor(batch)
with torch.no_grad():
regions: list[dict[str, torch.Tensor]] = self.region_extractor(batch)

# convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2]
boxes_list = [batch_item["boxes"] for batch_item in regions]
indices = torch.repeat_interleave(
torch.arange(len(regions)),
torch.Tensor([boxes.shape[0] for boxes in boxes_list]).int(),
)
rois = torch.cat([indices.unsqueeze(1).to(batch.device), torch.cat(boxes_list)], dim=1)

# 2. apply feature extraction
if rois.shape[0] == 0:
# cannot extract features when no rois are retrieved
features = torch.empty((0, 4096)).to(batch.device)
else:
features = self.feature_extractor(batch, rois.clone())
with torch.no_grad():
features = self.feature_extractor(batch, rois)

if self.training:
return features

# 3. apply density estimation
scores = self.classifier(features)

return rois, scores
# 4. Compute anomaly map
masks = torch.cat([region["masks"] for region in regions])
# Select the mask with the highest score for each region
anomaly_map = torch.stack(
[
torch.amax(masks[indices == i] * scores[indices == i].view(-1, 1, 1, 1), dim=0)
if i in indices
else torch.zeros_like(masks[0])
for i in range(len(regions))
],
)

# 5. Compute box scores
pred_scores = torch.stack([
torch.amax(scores[indices == i]) if i in indices else torch.tensor(0.0, device=scores[0].device)
for i in range(len(regions))
])

return InferenceBatch(
pred_score=pred_scores,
anomaly_map=anomaly_map,
)