diff --git a/configs/model/rkde.yaml b/configs/model/rkde.yaml
index b421ffed7d..36ca63c261 100644
--- a/configs/model/rkde.yaml
+++ b/configs/model/rkde.yaml
@@ -9,5 +9,3 @@ model:
     n_pca_components: 16
     feature_scaling_method: SCALE
     max_training_points: 40000
-
-task: detection
diff --git a/src/anomalib/models/image/rkde/lightning_model.py b/src/anomalib/models/image/rkde/lightning_model.py
index 20a18496fc..f20f65288d 100644
--- a/src/anomalib/models/image/rkde/lightning_model.py
+++ b/src/anomalib/models/image/rkde/lightning_model.py
@@ -14,13 +14,13 @@
 from torchvision.transforms.v2 import Compose, Resize, Transform
 
 from anomalib import LearningType
+from anomalib.data import Batch, InferenceBatch
 from anomalib.metrics import Evaluator
 from anomalib.models.components import AnomalibModule, MemoryBankMixin
 from anomalib.models.components.classification import FeatureScalingMethod
 from anomalib.post_processing import PostProcessor
 from anomalib.pre_processing import PreProcessor
 
-from .region_extractor import RoiStage
 from .torch_model import RkdeModel
 
 logger = logging.getLogger(__name__)
@@ -30,11 +30,9 @@ class Rkde(MemoryBankMixin, AnomalibModule):
     """Region Based Anomaly Detection With Real-Time Training and Analysis.
 
     Args:
-        roi_stage (RoiStage, optional): Processing stage from which rois are extracted.
-            Defaults to ``RoiStage.RCNN``.
-        roi_score_threshold (float, optional): Mimumum confidence score for the region proposals.
+        roi_score_threshold (float, optional): Minimum confidence score for the region proposals.
             Defaults to ``0.001``.
-        min_size (int, optional): Minimum size in pixels for the region proposals.
+        min_box_size (int, optional): Minimum size in pixels for the region proposals.
             Defaults to ``25``.
         iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS.
             Defaults to ``0.3``.
@@ -55,7 +53,6 @@ class Rkde(MemoryBankMixin, AnomalibModule):
 
     def __init__(
        self,
-        roi_stage: RoiStage = RoiStage.RCNN,
         roi_score_threshold: float = 0.001,
         min_box_size: int = 25,
         iou_threshold: float = 0.3,
@@ -70,7 +67,6 @@ def __init__(
         super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
 
         self.model: RkdeModel = RkdeModel(
-            roi_stage=roi_stage,
             roi_score_threshold=roi_score_threshold,
             min_box_size=min_box_size,
             iou_threshold=iou_threshold,
@@ -86,11 +82,11 @@ def configure_optimizers() -> None:
         """RKDE doesn't require optimization, therefore returns no optimizers."""
         return
 
-    def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None:
-        """Perform a training Step of RKDE. For each batch, features are extracted from the CNN.
+    def training_step(self, batch: Batch, *args, **kwargs) -> None:
+        """Perform a training step of RKDE. For each batch, features are extracted from the CNN.
 
         Args:
-            batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask
+            batch (Batch): Batch containing image filename, image, label and mask
             args: Additional arguments.
             kwargs: Additional keyword arguments.
 
@@ -99,7 +95,7 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
         """
         del args, kwargs  # These variables are not used.
 
-        features = self.model(batch["image"])
+        features = self.model(batch.image)
         self.embeddings.append(features)
 
     def fit(self) -> None:
@@ -109,13 +105,13 @@ def fit(self) -> None:
         logger.info("Fitting a KDE model to the embedding collected from the training set.")
         self.model.fit(embeddings)
 
-    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
-        """Perform a validation Step of RKde.
+    def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
+        """Perform a validation step of RKDE.
 
         Similar to the training step, features are extracted from the CNN for each batch.
 
         Args:
-            batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask
+            batch (Batch): Batch containing image filename, image, label and mask
             args: Additional arguments.
             kwargs: Additional keyword arguments.
 
@@ -125,15 +121,12 @@ def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs)
         """
         del args, kwargs  # These variables are not used.
 
         # get batched model predictions
-        boxes, scores = self.model(batch["image"])
-
-        # convert batched predictions to list format
-        image: torch.Tensor = batch["image"]
-        batch_size = image.shape[0]
-        indices = boxes[:, 0]
-        batch["pred_boxes"] = [boxes[indices == i, 1:] for i in range(batch_size)]
-        batch["box_scores"] = [scores[indices == i] for i in range(batch_size)]
+        predictions: InferenceBatch = self.model(batch.image)
+        batch.update(
+            pred_score=predictions.pred_score,
+            anomaly_map=predictions.anomaly_map,
+        )
         return batch
 
     @property
diff --git a/src/anomalib/models/image/rkde/region_extractor.py b/src/anomalib/models/image/rkde/region_extractor.py
index 8471ec4edb..4dad3844b6 100644
--- a/src/anomalib/models/image/rkde/region_extractor.py
+++ b/src/anomalib/models/image/rkde/region_extractor.py
@@ -6,30 +6,20 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from enum import Enum
-
 import torch
 from torch import nn
-from torchvision.models.detection import fasterrcnn_resnet50_fpn
-from torchvision.ops import boxes as box_ops
-
-from anomalib.data.utils.boxes import scale_boxes
-
-
-class RoiStage(str, Enum):
-    """Processing stage from which rois are extracted."""
-
-    RCNN = "rcnn"
-    RPN = "rpn"
+from torchvision.models.detection import (
+    MaskRCNN_ResNet50_FPN_V2_Weights,
+    maskrcnn_resnet50_fpn_v2,
+)
+from torchvision.ops import nms, remove_small_boxes
 
 
 class RegionExtractor(nn.Module):
     """Extracts regions from the image.
 
     Args:
-        stage (RoiStage, optional): Processing stage from which rois are extracted.
-            Defaults to ``RoiStage.RCNN``.
-        score_threshold (float, optional): Mimumum confidence score for the region proposals.
+        score_threshold (float, optional): Minimum confidence score for the region proposals.
             Defaults to ``0.001``.
         min_size (int, optional): Minimum size in pixels for the region proposals.
             Defaults to ``25``.
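Aside (not part of the patch): the hunk above swaps Faster R-CNN for torchvision's Mask R-CNN v2, and the rest of this file relies on its eval-mode output contract of one prediction dictionary per image. A minimal sketch of that contract, using the same weights enum as the patch; the input tensor is an arbitrary example:

    # eval-mode output contract of maskrcnn_resnet50_fpn_v2 (torchvision >= 0.13)
    import torch
    from torchvision.models.detection import (
        MaskRCNN_ResNet50_FPN_V2_Weights,
        maskrcnn_resnet50_fpn_v2,
    )

    detector = maskrcnn_resnet50_fpn_v2(weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
    detector.eval()
    with torch.no_grad():
        predictions = detector([torch.rand(3, 224, 224)])  # one dict per input image
    # "boxes" [N, 4], "labels" [N], "scores" [N], and soft "masks" [N, 1, H, W] at image size
    print({key: value.shape for key, value in predictions[0].items()})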
@@ -41,30 +31,24 @@ class RegionExtractor(nn.Module):
 
     def __init__(
         self,
-        stage: RoiStage = RoiStage.RCNN,
         score_threshold: float = 0.001,
         min_size: int = 25,
         iou_threshold: float = 0.3,
         max_detections_per_image: int = 100,
     ) -> None:
         super().__init__()
-
-        # Affects global behaviour of the region extractor
-        self.stage = stage
         self.min_size = min_size
         self.iou_threshold = iou_threshold
         self.max_detections_per_image = max_detections_per_image
 
-        # Affects behaviour depending on roi stage
-        rpn_top_n = max_detections_per_image if self.stage == RoiStage.RPN else 1000
-        rpn_score_thresh = score_threshold if self.stage == RoiStage.RPN else 0.0
-
-        # Create the model
-        self.faster_rcnn = fasterrcnn_resnet50_fpn(
-            pretrained=True,
-            rpn_post_nms_top_n_test=rpn_top_n,
-            rpn_score_thresh=rpn_score_thresh,
+        # Create the detection backbone. Internal nms and top-k filtering are
+        # relaxed here because custom label-agnostic nms is applied in post-processing.
+        self.backbone = maskrcnn_resnet50_fpn_v2(
+            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT,
+            rpn_post_nms_top_n_test=1000,
+            rpn_score_thresh=0.0,
             box_score_thresh=score_threshold,
+            rpn_nms_thresh=0.3,
             box_nms_thresh=1.0,  # this disables nms (we apply custom label-agnostic nms during post-processing)
             box_detections_per_img=1000,  # this disables filtering top-k predictions (we apply our own after nms)
         )
@@ -80,66 +64,50 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
-            ValueError: When ``stage`` is not one of ``rcnn`` or ``rpn``.
+            ValueError: When the model is in training mode.
 
         Returns:
-            Tensor: Predicted regions, tensor of shape [N, 5] where N is the number of predicted regions in the batch,
-                and where each row describes the index of the image in the batch and the 4 bounding box coordinates.
+            List of dictionaries containing the processed box predictions.
+                Each dictionary has the keys ``boxes``, ``masks``, ``scores``, and ``labels``.
         """
         if self.training:
             msg = "Should not be in training mode"
             raise ValueError(msg)
 
-        if self.stage == RoiStage.RCNN:
-            # get rois from rcnn output
-            predictions = self.faster_rcnn(batch)
-            all_regions = [prediction["boxes"] for prediction in predictions]
-            all_scores = [prediction["scores"] for prediction in predictions]
-        elif self.stage == RoiStage.RPN:
-            # get rois from region proposal network
-            images, _ = self.faster_rcnn.transform(batch)
-            features = self.faster_rcnn.backbone(images.tensors)
-            proposals, _ = self.faster_rcnn.rpn(images, features)
-            # post-process raw rpn predictions
-            all_regions = [box_ops.clip_boxes_to_image(boxes, images.tensors.shape[-2:]) for boxes in proposals]
-            all_regions = [scale_boxes(boxes, images.tensors.shape[-2:], batch.shape[-2:]) for boxes in all_regions]
-            all_scores = [torch.ones(boxes.shape[0]).to(boxes.device) for boxes in all_regions]
-        else:
-            msg = f"Unknown region extractor stage: {self.stage}"
-            raise ValueError(msg)
+        regions: list[dict[str, torch.Tensor]] = self.backbone(batch)
+        return self.post_process_box_predictions(regions)
 
-        regions = self.post_process_box_predictions(all_regions, all_scores)
-
-        # convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2]
-        indices = torch.repeat_interleave(
-            torch.arange(len(regions)),
-            torch.Tensor([rois.shape[0] for rois in regions]).int(),
-        )
-        return torch.cat([indices.unsqueeze(1).to(batch.device), torch.cat(regions)], dim=1)
-
-    def post_process_box_predictions(self, pred_boxes: torch.Tensor, pred_scores: torch.Tensor) -> list[torch.Tensor]:
+    def post_process_box_predictions(self, regions: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]:
         """Post-processes the box predictions.
 
         The post-processing consists of removing small boxes, applying nms, and keeping only the k boxes with the
         highest confidence score.
 
         Args:
-            pred_boxes (torch.Tensor): Box predictions of shape (N, 4).
-            pred_scores (torch.Tensor): torch.Tensor of shape () with a confidence score for each box prediction.
+            regions (list[dict[str, torch.Tensor]]): List of dictionaries containing the box predictions.
 
         Returns:
-            list[torch.Tensor]: Post-processed box predictions of shape (N, 4).
+            list[dict[str, torch.Tensor]]: List of dictionaries containing the processed box predictions.
         """
-        processed_boxes_list: list[torch.Tensor] = []
-        for boxes, scores in zip(pred_boxes, pred_scores, strict=True):
+        new_regions: list[dict[str, torch.Tensor]] = []
+        for region in regions:
+            boxes = region["boxes"]
+            masks = region["masks"]
+            scores = region["scores"]
+            labels = region["labels"]
             # remove small boxes
-            keep = box_ops.remove_small_boxes(boxes, min_size=self.min_size)
-            processed_boxes, processed_scores = boxes[keep], scores[keep]
-
-            # non-maximum suppression, all boxes together
-            keep = box_ops.nms(processed_boxes, processed_scores, self.iou_threshold)
-
-            # keep only top-k scoring predictions
+            keep = remove_small_boxes(boxes, min_size=self.min_size)
+            boxes = boxes[keep]
+            masks = masks[keep]
+            scores = scores[keep]
+            labels = labels[keep]
+            # apply label-agnostic non-maximum suppression
+            keep = nms(boxes, scores, self.iou_threshold)
+            # keep only the top-k scoring predictions
             keep = keep[: self.max_detections_per_image]
-            processed_boxes = processed_boxes[keep]
-
-            processed_boxes_list.append(processed_boxes)
-
-        return processed_boxes_list
+            processed_boxes = {
+                "boxes": boxes[keep],
+                "masks": masks[keep],
+                "scores": scores[keep],
+                "labels": labels[keep],
+            }
+            new_regions.append(processed_boxes)
+
+        return new_regions
diff --git a/src/anomalib/models/image/rkde/torch_model.py b/src/anomalib/models/image/rkde/torch_model.py
index ee574bf1ac..2d1e475d8d 100644
--- a/src/anomalib/models/image/rkde/torch_model.py
+++ b/src/anomalib/models/image/rkde/torch_model.py
@@ -8,10 +8,11 @@
 import torch
 from torch import nn
 
+from anomalib.data import InferenceBatch
 from anomalib.models.components.classification import FeatureScalingMethod, KDEClassifier
 
 from .feature_extractor import FeatureExtractor
-from .region_extractor import RegionExtractor, RoiStage
+from .region_extractor import RegionExtractor
 
 logger = logging.getLogger(__name__)
 
@@ -20,11 +21,9 @@ class RkdeModel(nn.Module):
     """Torch Model for the Region-based Anomaly Detection Model.
 
     Args:
-        roi_stage (RoiStage, optional): Processing stage from which rois are extracted.
-            Defaults to ``RoiStage.RCNN``.
-        roi_score_threshold (float, optional): Mimumum confidence score for the region proposals.
+        roi_score_threshold (float, optional): Minimum confidence score for the region proposals.
             Defaults to ``0.001``.
-        min_size (int, optional): Minimum size in pixels for the region proposals.
+        min_box_size (int, optional): Minimum size in pixels for the region proposals.
             Defaults to ``25``.
         iou_threshold (float, optional): Intersection-Over-Union threshold used during NMS.
             Defaults to ``0.3``.
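Aside (illustrative, hypothetical values): the forward pass below flattens the per-image ``boxes`` into the [index-in-batch, x1, y1, x2, y2] layout that the RoI feature extractor consumes. A self-contained sketch of that conversion:

    import torch

    # three images with 3, 0 and 2 detected boxes respectively (made-up data)
    boxes_list = [torch.rand(3, 4), torch.rand(0, 4), torch.rand(2, 4)]
    indices = torch.repeat_interleave(
        torch.arange(len(boxes_list)),
        torch.tensor([boxes.shape[0] for boxes in boxes_list]),
    )
    rois = torch.cat([indices.unsqueeze(1).float(), torch.cat(boxes_list)], dim=1)
    print(rois.shape, indices.tolist())  # torch.Size([5, 5]) [0, 0, 0, 2, 2]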
@@ -43,7 +42,6 @@ class RkdeModel(nn.Module):
     def __init__(
         self,
         # roi params
-        roi_stage: RoiStage = RoiStage.RCNN,
         roi_score_threshold: float = 0.001,
         min_box_size: int = 25,
         iou_threshold: float = 0.3,
@@ -56,7 +54,6 @@ def __init__(
         super().__init__()
 
         self.region_extractor = RegionExtractor(
-            stage=roi_stage,
             score_threshold=roi_score_threshold,
             min_size=min_box_size,
             iou_threshold=iou_threshold,
@@ -82,7 +79,7 @@ def fit(self, embeddings: torch.Tensor) -> bool:
         """
         return self.classifier.fit(embeddings)
 
-    def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Prediction by normality model.
 
         Args:
@@ -96,14 +93,24 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, tor
         self.feature_extractor.eval()
 
         # 1. apply region extraction
-        rois = self.region_extractor(batch)
+        with torch.no_grad():
+            regions: list[dict[str, torch.Tensor]] = self.region_extractor(batch)
+
+        # convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2]
+        boxes_list = [batch_item["boxes"] for batch_item in regions]
+        indices = torch.repeat_interleave(
+            torch.arange(len(regions), device=batch.device),
+            torch.tensor([boxes.shape[0] for boxes in boxes_list], device=batch.device),
+        )
+        rois = torch.cat([indices.unsqueeze(1), torch.cat(boxes_list)], dim=1)
 
         # 2. apply feature extraction
         if rois.shape[0] == 0:
             # cannot extract features when no rois are retrieved
             features = torch.empty((0, 4096)).to(batch.device)
         else:
-            features = self.feature_extractor(batch, rois.clone())
+            with torch.no_grad():
+                features = self.feature_extractor(batch, rois)
 
         if self.training:
             return features
@@ -111,4 +118,25 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, tor
         # 3. apply density estimation
         scores = self.classifier(features)
 
-        return rois, scores
+        # 4. compute anomaly map
+        masks = torch.cat([region["masks"] for region in regions])
+        # for each image, take the per-pixel maximum over its score-weighted masks
+        anomaly_map = torch.stack(
+            [
+                torch.amax(masks[indices == i] * scores[indices == i].view(-1, 1, 1, 1), dim=0)
+                if i in indices
+                else torch.zeros((1, *batch.shape[-2:]), device=batch.device)
+                for i in range(len(regions))
+            ],
+        )
+
+        # 5. compute box scores
+        pred_scores = torch.stack([
+            torch.amax(scores[indices == i]) if i in indices else torch.tensor(0.0, device=scores.device)
+            for i in range(len(regions))
+        ])
+
+        return InferenceBatch(
+            pred_score=pred_scores,
+            anomaly_map=anomaly_map,
+        )
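Aside (illustrative, made-up shapes): the score-weighted mask aggregation in step 4 can be exercised in isolation. Each image's anomaly map is the per-pixel maximum over its regions' soft masks, each weighted by that region's KDE score:

    import torch

    masks = torch.rand(4, 1, 8, 8)               # [N, 1, H, W] soft masks for N = 4 regions
    scores = torch.tensor([0.2, 0.9, 0.5, 0.7])  # one KDE score per region
    indices = torch.tensor([0, 0, 1, 1])         # image index of each region (batch of 2)

    anomaly_map = torch.stack([
        torch.amax(masks[indices == i] * scores[indices == i].view(-1, 1, 1, 1), dim=0)
        for i in range(2)
    ])
    print(anomaly_map.shape)  # torch.Size([2, 1, 8, 8])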