diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index f11786f822..56c02d6109 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -66,6 +66,7 @@ Fastflow, Fre, Ganomaly, + Glass, Padim, Patchcore, ReverseDistillation, @@ -102,6 +103,7 @@ class UnknownModelError(ModuleNotFoundError): "Fastflow", "Fre", "Ganomaly", + "Glass", "Padim", "Patchcore", "ReverseDistillation", diff --git a/src/anomalib/models/components/feature_extractors/__init__.py b/src/anomalib/models/components/feature_extractors/__init__.py index 66a2f36c34..b9936e793d 100644 --- a/src/anomalib/models/components/feature_extractors/__init__.py +++ b/src/anomalib/models/components/feature_extractors/__init__.py @@ -28,7 +28,6 @@ from .timm import TimmFeatureExtractor from .utils import dryrun_find_featuremap_dims - __all__ = [ "dryrun_find_featuremap_dims", "TimmFeatureExtractor", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index 2717290f3a..6da89cbb12 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -55,6 +55,7 @@ from .fastflow import Fastflow from .fre import Fre from .ganomaly import Ganomaly +from .glass import Glass from .padim import Padim from .patchcore import Patchcore from .reverse_distillation import ReverseDistillation @@ -76,6 +77,7 @@ "Fastflow", "Fre", "Ganomaly", + "Glass", "Padim", "Patchcore", "ReverseDistillation", @@ -83,5 +85,5 @@ "Supersimplenet", "Uflow", "VlmAd", - "WinClip", + "WinClip" ] diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py new file mode 100644 index 0000000000..a3070bacf4 --- /dev/null +++ b/src/anomalib/models/image/glass/__init__.py @@ -0,0 +1,23 @@ +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. + +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Glass + +__all__ = ["Glass"] diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py new file mode 100644 index 0000000000..99e82e489b --- /dev/null +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -0,0 +1,324 @@ +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. 
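+
+Example:
+    Illustrative usage sketch; the ``anomaly_source_path`` below is a placeholder for a local DTD texture folder,
+    and ``datamodule`` stands for any already constructed Anomalib datamodule:
+
+    >>> from anomalib.engine import Engine
+    >>> from anomalib.models import Glass
+    >>> model = Glass(input_shape=(256, 256), anomaly_source_path="./datasets/dtd/images")
+    >>> engine = Engine()
+    >>> engine.fit(model=model, datamodule=datamodule)  # doctest: +SKIP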
+ +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim +from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize + +from anomalib import LearningType +from anomalib.data import Batch +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator +from anomalib.metrics import Evaluator +from anomalib.models.components import AnomalibModule +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer + +from .torch_model import GlassModel + + +class Glass(AnomalibModule): + """PyTorch Lightning Implementation of the GLASS Model. + + The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain + bias. + Global anomaly features are synthesized from adapted normal features using gradient ascent. + Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature + extractor and feature adaptor. + All three different features are passed to the discriminator trained using loss functions. + + Args: + input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the + input pipeline. + anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures + backbone (str, optional): Name of the CNN backbone used for feature extraction. + Defaults to `"resnet18"`. + pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before + adaptation. + Defaults to `1024`. + target_embed_dim (int, optional): Dimensionality of the target adapted features after projection. + Defaults to `1024`. + patchsize (int, optional): Size of the local patch used in feature aggregation (e.g., for neighborhood pooling). + Defaults to `3`. + patchstride (int, optional): Stride used when extracting patches for local feature aggregation. + Defaults to `1`. + pre_trained (bool, optional): Whether to use ImageNet pre-trained weights for the backbone network. + Defaults to `True`. + layers (list[str], optional): List of backbone layers to extract features from. + Defaults to `["layer1", "layer2", "layer3"]`. + pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before + discriminator). + Defaults to `1`. + dsc_layers (int, optional): Number of layers in the discriminator network. + Defaults to `2`. + dsc_hidden (int, optional): Number of hidden units in each discriminator layer. + Defaults to `1024`. + dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator + training. + Defaults to `0.5`. 
+        pre_processor (PreProcessor | bool, optional): Preprocessing module or flag to enable default preprocessing.
+            Set to `True` to apply default normalization and resizing.
+            Defaults to `True`.
+        post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output
+            smoothing or thresholding.
+            Defaults to `True`.
+        evaluator (Evaluator | bool, optional): Evaluation module for calculating metrics such as AUROC and PRO.
+            Defaults to `True`.
+        visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and
+            anomaly scores.
+            Defaults to `True`.
+        mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during
+            training.
+            Defaults to `1`.
+        noise (float, optional): Standard deviation of Gaussian noise used in feature-level anomaly synthesis.
+            Defaults to `0.015`.
+        radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy.
+            Determines the range for valid synthetic anomalies in the hypersphere or manifold.
+            Defaults to `0.75`.
+        p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation
+            choice.
+            Defaults to `0.5`.
+        lr (float, optional): Learning rate for training the feature adaptor and discriminator networks.
+            Defaults to `0.0001`.
+        step (int, optional): Number of gradient ascent steps for anomaly synthesis.
+            Defaults to `20`.
+        svd (int, optional): Flag to enable SVD-based feature projection.
+            Defaults to `0`.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int],
+        anomaly_source_path: str,
+        backbone: str = "resnet18",
+        pretrain_embed_dim: int = 1024,
+        target_embed_dim: int = 1024,
+        patchsize: int = 3,
+        patchstride: int = 1,
+        pre_trained: bool = True,
+        layers: list[str] | None = None,
+        pre_proj: int = 1,
+        dsc_layers: int = 2,
+        dsc_hidden: int = 1024,
+        dsc_margin: float = 0.5,
+        pre_processor: PreProcessor | bool = True,
+        post_processor: PostProcessor | bool = True,
+        evaluator: Evaluator | bool = True,
+        visualizer: Visualizer | bool = True,
+        mining: int = 1,
+        noise: float = 0.015,
+        radius: float = 0.75,
+        p: float = 0.5,
+        lr: float = 0.0001,
+        step: int = 20,
+        svd: int = 0,
+    ) -> None:
+        super().__init__(
+            pre_processor=pre_processor,
+            post_processor=post_processor,
+            evaluator=evaluator,
+            visualizer=visualizer,
+        )
+
+        if layers is None:
+            layers = ["layer1", "layer2", "layer3"]
+
+        self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
+
+        self.model = GlassModel(
+            input_shape=input_shape,
+            anomaly_source_path=anomaly_source_path,
+            pretrain_embed_dim=pretrain_embed_dim,
+            target_embed_dim=target_embed_dim,
+            backbone=backbone,
+            pre_trained=pre_trained,
+            patchsize=patchsize,
+            patchstride=patchstride,
+            layers=layers,
+            pre_proj=pre_proj,
+            dsc_layers=dsc_layers,
+            dsc_hidden=dsc_hidden,
+            dsc_margin=dsc_margin,
+            step=step,
+            svd=svd,
+            mining=mining,
+            noise=noise,
+            radius=radius,
+            p=p,
+        )
+
+        self.c = torch.tensor([1])
+        self.lr = lr
+
+        if pre_proj > 0:
+            self.proj_opt = optim.AdamW(
+                self.model.pre_projection.parameters(),
+                self.lr,
+                weight_decay=1e-5,
+            )
+        else:
+            self.proj_opt = None
+
+        if not pre_trained:
+            self.backbone_opt = optim.AdamW(
+                self.model.forward_modules["feature_aggregator"].backbone.parameters(),
+                self.lr,
+            )
+        else:
+            self.backbone_opt = None
+
+    @classmethod
+    def configure_pre_processor(
+        cls,
+        image_size: tuple[int, int] | None = None,
+        center_crop_size: tuple[int,
int] | None = None, + ) -> PreProcessor: + """Configure the default pre-processor for GLASS. + + If valid center_crop_size is provided, the pre-processor will + also perform center cropping, according to the paper. + + Args: + image_size (tuple[int, int] | None, optional): Target size for + resizing. Defaults to ``(256, 256)``. + center_crop_size (tuple[int, int] | None, optional): Size for center + cropping. Defaults to ``None``. + + Returns: + PreProcessor: Configured pre-processor instance. + + Raises: + ValueError: If at least one dimension of ``center_crop_size`` is larger + than correspondent ``image_size`` dimension. + + Example: + >>> pre_processor = Glass.configure_pre_processor( + ... image_size=(256, 256) + ... ) + >>> transformed_image = pre_processor(image) + """ + image_size = image_size or (256, 256) + + if center_crop_size is not None: + if center_crop_size[0] > image_size[0] or center_crop_size[1] > image_size[1]: + msg = f"Center crop size {center_crop_size} cannot be larger than image size {image_size}." + raise ValueError(msg) + transform = Compose([ + Resize(image_size, antialias=True), + CenterCrop(center_crop_size), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + else: + transform = Compose([ + Resize(image_size, antialias=True), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + return PreProcessor(transform=transform) + + def configure_optimizers(self) -> optim.Optimizer: + """Configure optimizer for the discriminator. + + Returns: + Optimizer: AdamW Optimizer for the discriminator. + """ + return optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) + + def training_step( + self, + batch: Batch, + batch_idx: int, + ) -> STEP_OUTPUT: + """Training step for GLASS model. + + Args: + batch (Batch): Input batch containing images and metadata + batch_idx (int): Index of the current batch + + Returns: + STEP_OUTPUT: Dictionary containing loss values and metrics + """ + del batch_idx + dsc_opt = self.optimizers() + + self.model.forward_modules.eval() + if self.model.pre_proj > 0: + self.model.pre_projection.train() + self.model.discriminator.train() + + dsc_opt.zero_grad() + if self.proj_opt is not None: + self.proj_opt.zero_grad() + if self.backbone_opt is not None: + self.backbone_opt.zero_grad() + + img = batch.image + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c) + loss.backward() + + if self.proj_opt is not None: + self.proj_opt.step() + if self.backbone_opt is not None: + self.backbone_opt.step() + dsc_opt.step() + + self.log("true_loss", true_loss, prog_bar=True) + self.log("gaus_loss", gaus_loss, prog_bar=True) + self.log("bce_loss", bce_loss, prog_bar=True) + self.log("focal_loss", focal_loss, prog_bar=True) + self.log("loss", loss, prog_bar=True) + + def on_train_start(self) -> None: + """Initialize model by computing mean feature representation across training dataset. + + This method is called at the start of training and computes a mean feature vector + that serves as a reference point for the normal class distribution. + """ + dataloader = self.trainer.train_dataloader + + with torch.no_grad(): + for i, batch in enumerate(dataloader): + if i == 0: + self.c = self.model.calculate_mean(batch.image.to(self.device)) + else: + self.c += self.model.calculate_mean(batch.image.to(self.device)) + + self.c /= len(dataloader) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. 
+ + Returns: + LearningType: Learning type (ONE_CLASS for GLASS) + """ + return LearningType.ONE_CLASS + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return GLASS trainer arguments. + + Returns: + dict[str, Any]: Dictionary containing trainer configuration + """ + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py new file mode 100644 index 0000000000..db3c9da5cb --- /dev/null +++ b/src/anomalib/models/image/glass/loss.py @@ -0,0 +1,169 @@ +"""Focal Loss for multi-class classification with optional label smoothing and class weighting. + +This loss function is designed to address class imbalance by down-weighting easy examples and focusing training +on hard, misclassified examples. It is based on the paper: +"Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002). + +The focal loss formula is: + FL(pt) = -alpha * (1 - pt) ** gamma * log(pt) + +where: + - pt is the predicted probability of the correct class + - alpha is a class balancing factor + - gamma is a focusing parameter + +Supports optional label smoothing and flexible alpha input (scalar or per-class tensor). Can be used with raw logits, +applying a specified non-linearity (e.g., softmax or sigmoid). + +Args: + apply_nonlin (nn.Module or None): Optional non-linearity to apply to the logits before loss computation. + For example, use `nn.Softmax(dim=1)` or `nn.Sigmoid()` if logits are not normalized. + alpha (float or torch.Tensor, optional): Class balancing factor. Can be: + - None: Equal weighting for all classes. + - float: Scalar for binary class weighting; applied to `balance_index`. + - Tensor: Per-class weights of shape (num_classes,). + gamma (float): Focusing parameter (> 0) to reduce the loss contribution from easy examples. Default is 2. + balance_index (int): Index of the class to apply `alpha` to when `alpha` is a float. + smooth (float): Label smoothing factor. A small value (e.g., 1e-5) helps prevent overconfidence. + size_average (bool): If True, average the loss over the batch; if False, sum the loss. + +Raises: + ValueError: If `smooth` is outside the range [0, 1]. + TypeError: If `alpha` is not a supported type. + +Inputs: + logit (torch.Tensor): Raw model outputs (logits) of shape (B, C, ...) where B is batch size and C is number of + classes. + target (torch.Tensor): Ground-truth class indices of shape (B, 1, ...) or broadcastable to match logit. + +Returns: + torch.Tensor: Scalar loss value (averaged or summed based on `size_average`). +""" + +# Original Code +# Copyright (c) 2021 @Hsuxu +# https://github.com/Hsuxu/Loss_ToolBox-PyTorch. +# SPDX-License-Identifier: Apache-2.0 +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +from torch import nn + + +class FocalLoss(nn.Module): + """Implementation of Focal Loss with support for smoothed label cross-entropy. + + As proposed in 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002). + The focal loss formula is: + Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt) + + Args: + num_class (int): Number of classes. + alpha (float or Tensor): Scalar or Tensor weight factor for class imbalance. If float, `balance_index` should be + set. + gamma (float): Focusing parameter that reduces the relative loss for well-classified examples (gamma > 0). + smooth (float): Label smoothing factor for cross-entropy. 
+        balance_index (int): Index of the class to balance when `alpha` is a float.
+        size_average (bool, optional): If True (default), the loss is averaged over the batch; otherwise, the loss is
+            summed.
+    """
+
+    def __init__(
+        self,
+        apply_nonlin: nn.Module | None = None,
+        alpha: float | torch.Tensor | None = None,
+        gamma: float = 2,
+        balance_index: int = 0,
+        smooth: float = 1e-5,
+        size_average: bool = True,
+    ) -> None:
+        """Initializes the FocalLoss instance.
+
+        Args:
+            apply_nonlin (nn.Module or None): Optional non-linearity to apply to logits (e.g., softmax or sigmoid).
+            alpha (float or torch.Tensor, optional): Weighting factor for class imbalance. Can be:
+                - None: Equal weighting.
+                - float: Class at `balance_index` is weighted by `alpha`, others by 1 - `alpha`.
+                - Tensor: Direct per-class weights.
+            gamma (float): Focusing parameter for down-weighting easy examples (gamma > 0).
+            balance_index (int): Index of the class to apply `alpha` to when `alpha` is a float.
+            smooth (float): Label smoothing factor (0 to 1).
+            size_average (bool): If True, average the loss over the batch. If False, sum the loss.
+        """
+        super().__init__()
+        self.apply_nonlin = apply_nonlin
+        self.alpha = alpha
+        self.gamma = gamma
+        self.balance_index = balance_index
+        self.smooth = smooth
+        self.size_average = size_average
+
+        if self.smooth is not None and (self.smooth < 0 or self.smooth > 1.0):
+            msg = "smooth value should be in [0,1]"
+            raise ValueError(msg)
+
+    def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Computes the focal loss between `logit` predictions and ground-truth `target`.
+
+        Args:
+            logit (torch.Tensor): The predicted logits of shape (B, C, ...) where B is batch size and C is the number of
+                classes.
+            target (torch.Tensor): The ground-truth class indices of shape (B, 1, ...) or broadcastable to logit.
+
+        Returns:
+            torch.Tensor: Computed focal loss value (averaged or summed depending on `size_average`).
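+
+        Example:
+            A minimal illustrative call with random tensors; shapes and values are arbitrary, and ``apply_nonlin``
+            is set to softmax so that the logits are converted to probabilities before the log is taken:
+
+            >>> criterion = FocalLoss(apply_nonlin=nn.Softmax(dim=1), gamma=2)
+            >>> logits = torch.randn(4, 2)
+            >>> targets = torch.randint(0, 2, (4, 1))
+            >>> loss = criterion(logits, targets)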
+ """ + if self.apply_nonlin is not None: + logit = self.apply_nonlin(logit) + num_class = logit.shape[1] + + if logit.dim() > 2: + logit = logit.view(logit.size(0), logit.size(1), -1) + logit = logit.permute(0, 2, 1).contiguous() + logit = logit.view(-1, logit.size(-1)) + target = torch.squeeze(target, 1) + target = target.view(-1, 1) + + alpha = self.alpha + if alpha is None: + alpha = torch.ones(num_class, 1) + elif isinstance(alpha, (list | np.ndarray)): + assert len(alpha) == num_class + alpha = torch.FloatTensor(alpha).view(num_class, 1) + alpha = alpha / alpha.sum() + elif isinstance(alpha, float): + alpha = torch.ones(num_class, 1) + alpha = alpha * (1 - self.alpha) + alpha[self.balance_index] = self.alpha + else: + msg = "Not support alpha type" + raise TypeError(msg) + + if alpha.device != logit.device: + alpha = alpha.to(logit.device) + + idx = target.cpu().long() + one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() + one_hot_key = one_hot_key.scatter_(1, idx, 1) + if one_hot_key.device != logit.device: + one_hot_key = one_hot_key.to(logit.device) + + if self.smooth: + one_hot_key = torch.clamp( + one_hot_key, + self.smooth / (num_class - 1), + 1.0 - self.smooth, + ) + pt = (one_hot_key * logit).sum(1) + self.smooth + logpt = pt.log() + + gamma = self.gamma + alpha = alpha[idx] + alpha = torch.squeeze(alpha) + loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt + + return loss.mean() if self.size_average else loss.sum() diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py new file mode 100644 index 0000000000..2f479587c7 --- /dev/null +++ b/src/anomalib/models/image/glass/torch_model.py @@ -0,0 +1,637 @@ +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. + +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch +import torch.nn.functional as f +from torch import nn + +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator +from anomalib.models.components import TimmFeatureExtractor +from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims + +from .loss import FocalLoss + + +def init_weight(m: nn.Module) -> None: + """Initializes network weights using Xavier normal initialization. + + Applies Xavier initialization for linear layers and normal initialization + for convolutional and batch normalization layers. 
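+
+    Args:
+        m (nn.Module): Module whose parameters are initialized in place; typically applied to a whole network via
+            ``module.apply(init_weight)``.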
+ """ + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + if isinstance(m, torch.nn.BatchNorm2d): + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + elif isinstance(m, torch.nn.Conv2d): + m.weight.data.normal_(0.0, 0.02) + + +def _deduce_dims( + feature_extractor: TimmFeatureExtractor, + input_size: tuple[int, int], + layers: list[str], +) -> list[int | tuple[int, int]]: + """Determines feature dimensions for each layer in the feature extractor. + + Args: + feature_extractor: The backbone feature extractor + input_size: Input image dimensions + layers: List of layer names to extract features from + """ + dimensions_mapping = dryrun_find_featuremap_dims( + feature_extractor, + input_size, + layers, + ) + + return [dimensions_mapping[layer]["num_features"] for layer in layers] + + +class Preprocessing(torch.nn.Module): + """Handles initial feature preprocessing across multiple input dimensions. + + Input: List of features from different backbone layers + Output: Processed features with consistent dimensionality + """ + + def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: + super().__init__() + self.input_dims = input_dims + self.output_dim = output_dim + + self.preprocessing_modules = torch.nn.ModuleList() + for _ in input_dims: + module = MeanMapper(output_dim) + self.preprocessing_modules.append(module) + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Applies preprocessing modules to a list of input feature tensors. + + Args: + features (list of torch.Tensor): List of feature maps from different + layers of the backbone network. Each tensor can have a different shape. + + Returns: + torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, + N is the number of feature maps, and D is the output dimension (`output_dim`). + """ + features_ = [] + for module, feature in zip(self.preprocessing_modules, features, strict=False): + features_.append(module(feature)) + return torch.stack(features_, dim=1) + + +class MeanMapper(torch.nn.Module): + """Maps input features to a fixed dimension using adaptive average pooling. + + Input: Variable-sized feature tensors + Output: Fixed-size feature representations + """ + + def __init__(self, preprocessing_dim: int) -> None: + super().__init__() + self.preprocessing_dim = preprocessing_dim + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Applies adaptive average pooling to reshape features to a fixed size. + + Args: + features (torch.Tensor): Input tensor of shape (B, *) where * denotes + any number of remaining dimensions. It is flattened before pooling. + + Returns: + torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. + """ + features = features.reshape(len(features), 1, -1) + return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) + + +class Aggregator(torch.nn.Module): + """Aggregates and reshapes features to a target dimension. 
+ + Input: Multi-dimensional feature tensors + Output: Reshaped and pooled features of specified target dimension + """ + + def __init__(self, target_dim: int) -> None: + super().__init__() + self.target_dim = target_dim + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Returns reshaped and average pooled features.""" + features = features.reshape(len(features), 1, -1) + features = f.adaptive_avg_pool1d(features, self.target_dim) + return features.reshape(len(features), -1) + + +class Projection(torch.nn.Module): + """Multi-layer projection network for feature adaptation. + + Args: + in_planes: Input feature dimension + out_planes: Output feature dimension + n_layers: Number of projection layers + layer_type: Type of intermediate layers + """ + + def __init__(self, in_planes: int, out_planes: int | None = None, n_layers: int = 1, layer_type: int = 0) -> None: + super().__init__() + + if out_planes is None: + out_planes = in_planes + self.layers = torch.nn.Sequential() + in_ = None + out = None + for i in range(n_layers): + in_ = in_planes if i == 0 else out + out = out_planes + self.layers.add_module(f"{i}fc", torch.nn.Linear(in_, out)) + if i < n_layers - 1 and layer_type > 1: + self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) + self.apply(init_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the projection network to the input features. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Transformed tensor of shape (B, out_planes). + """ + return self.layers(x) + + +class Discriminator(torch.nn.Module): + """Discriminator network for anomaly detection. + + Args: + in_planes: Input feature dimension + n_layers: Number of layers + hidden: Hidden layer dimensions + """ + + def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: + super().__init__() + + hidden_ = in_planes if hidden is None else hidden + self.body = torch.nn.Sequential() + for i in range(n_layers - 1): + in_ = in_planes if i == 0 else hidden_ + hidden_ = int(hidden_ // 1.5) if hidden is None else hidden + self.body.add_module( + f"block{i + 1}", + torch.nn.Sequential( + torch.nn.Linear(in_, hidden_), + torch.nn.BatchNorm1d(hidden_), + torch.nn.LeakyReLU(0.2), + ), + ) + self.tail = torch.nn.Sequential( + torch.nn.Linear(hidden_, 1, bias=False), + torch.nn.Sigmoid(), + ) + self.apply(init_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs a forward pass through the discriminator network. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Output tensor of shape (B, 1) containing probability scores. + """ + x = self.body(x) + return self.tail(x) + + +class PatchMaker: + """Handles patch-based processing of feature maps. + + This class provides utilities for converting feature maps into patches, + reshaping patch scores back to original dimensions, and computing global + anomaly scores from patch-wise predictions. + + Attributes: + patchsize (int): Size of each patch (patchsize x patchsize). + stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. + top_k (int): Number of top patch scores to consider. Used for score reduction. 
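+
+    Example:
+        Illustrative patch extraction on a random feature map; shapes are arbitrary:
+
+        >>> patch_maker = PatchMaker(patchsize=3, stride=1)
+        >>> features = torch.rand(2, 64, 32, 32)
+        >>> patches, spatial_info = patch_maker.patchify(features, return_spatial_info=True)
+        >>> patches.shape
+        torch.Size([2, 1024, 64, 3, 3])
+        >>> spatial_info
+        [32, 32]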
+ """ + + def __init__(self, patchsize: int, top_k: int = 0, stride: int | None = None) -> None: + self.patchsize = patchsize + self.stride = stride if stride is not None else patchsize + self.top_k = top_k + + def patchify( + self, + features: torch.Tensor, + return_spatial_info: bool = False, + ) -> tuple[torch.Tensor, list[int]] | torch.Tensor: + """Converts a batch of feature maps into patches. + + Args: + features (torch.Tensor): Input feature maps of shape (B, C, H, W). + return_spatial_info (bool): If True, also returns spatial patch count. Default is False. + + Returns: + torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. + list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. + """ + padding = int((self.patchsize - 1) / 2) + unfolder = torch.nn.Unfold( + kernel_size=self.patchsize, + stride=self.stride, + padding=padding, + dilation=1, + ) + unfolded_features = unfolder(features) + number_of_total_patches = [] + for s in features.shape[-2:]: + n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 + number_of_total_patches.append(int(n_patches)) + unfolded_features = unfolded_features.reshape( + *features.shape[:2], + self.patchsize, + self.patchsize, + -1, + ) + unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) + + if return_spatial_info: + return unfolded_features, number_of_total_patches + return unfolded_features + + @staticmethod + def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: + """Reshapes patch scores back into per-batch format. + + Args: + x (torch.Tensor): Input tensor of shape (B * N, ...). + batchsize (int): Original batch size. + + Returns: + torch.Tensor: Reshaped tensor of shape (B, N, ...). + """ + return x.reshape(batchsize, -1, *x.shape[1:]) + + @staticmethod + def score(x: torch.Tensor) -> torch.Tensor: + """Computes final anomaly scores from patch-wise predictions. + + Args: + x (torch.Tensor): Patch scores of shape (B, N, 1). + + Returns: + torch.Tensor: Final anomaly score per image, shape (B,). 
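+
+        Example:
+            Illustrative reduction of per-patch scores to one score per image, assuming scores of shape (B, N, 1):
+
+            >>> patch_scores = torch.rand(2, 10, 1)
+            >>> PatchMaker.score(patch_scores).shape
+            torch.Size([2])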
+        """
+        x = x[:, :, 0]  # remove last dimension if singleton
+        return torch.max(x, dim=1).values
+
+
+class GlassModel(nn.Module):
+    """PyTorch Implementation of the GLASS Model."""
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int],  # (H, W)
+        anomaly_source_path: str,
+        pretrain_embed_dim: int = 1024,
+        target_embed_dim: int = 1024,
+        backbone: str = "resnet18",
+        patchsize: int = 3,
+        patchstride: int = 1,
+        pre_trained: bool = True,
+        layers: list[str] | None = None,
+        pre_proj: int = 1,
+        dsc_layers: int = 2,
+        dsc_hidden: int = 1024,
+        dsc_margin: float = 0.5,
+        mining: int = 1,
+        noise: float = 0.015,
+        radius: float = 0.75,
+        p: float = 0.5,
+        lr: float = 0.0001,
+        step: int = 20,
+        svd: int = 0,
+    ) -> None:
+        super().__init__()
+
+        if layers is None:
+            layers = ["layer1", "layer2", "layer3"]
+
+        self.backbone = backbone
+        self.layers = layers
+        self.input_shape = input_shape
+        self.pre_trained = pre_trained
+
+        self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
+
+        self.focal_loss = FocalLoss()
+
+        self.forward_modules = torch.nn.ModuleDict({})
+        feature_aggregator = TimmFeatureExtractor(
+            backbone=self.backbone,
+            layers=self.layers,
+            pre_trained=self.pre_trained,
+        )
+        feature_dimensions = _deduce_dims(feature_aggregator, self.input_shape, layers)
+        self.forward_modules["feature_aggregator"] = feature_aggregator
+
+        preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dim)
+        self.forward_modules["preprocessing"] = preprocessing
+        self.target_embed_dimension = target_embed_dim
+        preadapt_aggregator = Aggregator(target_dim=target_embed_dim)
+        self.forward_modules["preadapt_aggregator"] = preadapt_aggregator
+
+        self.pre_proj = pre_proj
+        if self.pre_proj > 0:
+            self.pre_projection = Projection(
+                self.target_embed_dimension,
+                self.target_embed_dimension,
+                pre_proj,
+            )
+
+        self.dsc_layers = dsc_layers
+        self.dsc_hidden = dsc_hidden
+        self.dsc_margin = dsc_margin
+        self.discriminator = Discriminator(
+            self.target_embed_dimension,
+            n_layers=self.dsc_layers,
+            hidden=self.dsc_hidden,
+        )
+
+        self.p = p
+        self.radius = radius
+        self.mining = mining
+        self.noise = noise
+        self.distribution = 0
+        self.lr = lr
+        self.step = step
+        self.svd = svd
+
+        self.patch_maker = PatchMaker(patchsize, stride=patchstride)
+
+    def calculate_mean(self, images: torch.Tensor) -> torch.Tensor:
+        """Computes the mean feature embedding across a batch of images.
+
+        This method performs a forward pass through the model to extract feature embeddings
+        for a batch of input images, optionally passing them through a pre-projection module.
+        It then reshapes the output and calculates the mean across the batch dimension.
+
+        Args:
+            images (torch.Tensor): Input image tensor of shape (B, C, H, W), where:
+                - B is the batch size,
+                - C is the number of channels,
+                - H and W are height and width.
+
+        Returns:
+            torch.Tensor: Mean embedding tensor of shape (N, D), where:
+                - N is the number of patches or tokens per image,
+                - D is the feature dimension.
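+
+        Example:
+            Sketch of the call made in ``Glass.on_train_start`` to accumulate the center of the normal feature
+            distribution, assuming ``model`` is an initialized ``GlassModel`` and ``images`` is a batch of normal
+            images:
+
+            >>> center = model.calculate_mean(images)  # doctest: +SKIP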
+        """
+        self.forward_modules.eval()
+        with torch.no_grad():
+            if self.pre_proj > 0:
+                outputs = self.pre_projection(self.generate_embeddings(images)[0])
+                outputs = outputs[0] if len(outputs) == 2 else outputs
+            else:
+                outputs = self.generate_embeddings(images, evaluation=False)[0]
+
+            outputs = outputs[0] if len(outputs) == 2 else outputs
+            outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1])
+
+            return torch.mean(outputs, dim=0)
+
+    def calculate_features(
+        self,
+        img: torch.Tensor,
+        aug: torch.Tensor,
+        evaluation: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate and return feature embeddings for the input and augmented images.
+
+        Depending on whether a pre-projection module is used, this method optionally applies it to the
+        embeddings generated for both `img` and `aug`. If not, the embeddings are obtained directly and
+        `requires_grad` is enabled on them for gradient-based anomaly synthesis.
+
+        Args:
+            img (torch.Tensor): The original input image tensor.
+            aug (torch.Tensor): The augmented image tensor.
+            evaluation (bool, optional): Whether the model is in evaluation mode. Defaults to False.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: A tuple containing the feature embeddings for the original
+                image (`true_feats`) and the augmented image (`fake_feats`).
+        """
+        if self.pre_proj > 0:
+            fake_feats = self.pre_projection(
+                self.generate_embeddings(aug, evaluation=evaluation)[0],
+            )
+            fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats
+            true_feats = self.pre_projection(
+                self.generate_embeddings(img, evaluation=evaluation)[0],
+            )
+            true_feats = true_feats[0] if len(true_feats) == 2 else true_feats
+        else:
+            fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0]
+            fake_feats.requires_grad = True
+            true_feats = self.generate_embeddings(img, evaluation=evaluation)[0]
+            true_feats.requires_grad = True
+
+        return true_feats, fake_feats
+
+    def generate_embeddings(
+        self,
+        images: torch.Tensor,
+        evaluation: bool = False,
+    ) -> tuple[torch.Tensor, list[list[int]]]:
+        """Generates patch-wise feature embeddings for a batch of input images.
+
+        This method performs a forward pass through the model's feature extraction pipeline,
+        processes selected intermediate layers, reshapes them into patches, aligns their spatial sizes,
+        and passes them through preprocessing and aggregation modules.
+
+        Args:
+            images (torch.Tensor): Input images of shape (B, C, H, W), where:
+                - B is the batch size,
+                - C is the number of channels,
+                - H and W are the image height and width.
+            evaluation (bool, optional): Whether to run in evaluation mode (disabling gradients).
+                Default is False.
+
+        Returns:
+            tuple[torch.Tensor, list[list[int]]]:
+                - An aggregated patch-feature tensor of shape (B * N, D), where N is the number of patches per
+                  image and D is the target embedding dimension.
+                - A list of [height, width] patch counts for each feature level.
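+
+        Example:
+            Sketch of a call, assuming ``model`` is an initialized ``GlassModel`` with the default
+            ``target_embed_dim`` of 1024; the returned features are flattened over batch and patch dimensions:
+
+            >>> feats, shapes = model.generate_embeddings(torch.rand(8, 3, 256, 256))  # doctest: +SKIP
+            >>> feats.shape  # (8 * n_patches, 1024)  # doctest: +SKIP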
+        """
+        if not evaluation and not self.pre_trained:
+            self.forward_modules["feature_aggregator"].train()
+            features = self.forward_modules["feature_aggregator"](images)
+        else:
+            self.forward_modules["feature_aggregator"].eval()
+            with torch.no_grad():
+                features = self.forward_modules["feature_aggregator"](images)
+
+        features = [features[layer] for layer in self.layers]
+        for i, feat in enumerate(features):
+            if len(feat.shape) == 3:
+                B, L, C = feat.shape  # noqa: N806
+                features[i] = feat.reshape(
+                    B,
+                    int(math.sqrt(L)),
+                    int(math.sqrt(L)),
+                    C,
+                ).permute(0, 3, 1, 2)
+
+        features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features]
+        patch_shapes = [x[1] for x in features]
+        patch_features = [x[0] for x in features]
+        ref_num_patches = patch_shapes[0]
+
+        for i in range(1, len(patch_features)):
+            features_ = patch_features[i]
+            patch_dims = patch_shapes[i]
+
+            features_ = features_.reshape(
+                features_.shape[0],
+                patch_dims[0],
+                patch_dims[1],
+                *features_.shape[2:],
+            )
+            features_ = features_.permute(0, 3, 4, 5, 1, 2)
+            perm_base_shape = features_.shape
+            features_ = features_.reshape(-1, *features_.shape[-2:])
+            features_ = f.interpolate(
+                features_.unsqueeze(1),
+                size=(ref_num_patches[0], ref_num_patches[1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+            features_ = features_.squeeze(1)
+            features_ = features_.reshape(
+                *perm_base_shape[:-2],
+                ref_num_patches[0],
+                ref_num_patches[1],
+            )
+            features_ = features_.permute(0, 4, 5, 1, 2, 3)
+            features_ = features_.reshape(len(features_), -1, *features_.shape[-3:])
+            patch_features[i] = features_
+
+        patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features]
+        patch_features = self.forward_modules["preprocessing"](patch_features)
+        patch_features = self.forward_modules["preadapt_aggregator"](patch_features)
+
+        return patch_features, patch_shapes
+
+    def forward(
+        self,
+        img: torch.Tensor,
+        c: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass that synthesizes anomalies and computes the training losses.
+
+        Local anomaly images are synthesized from `img` with the Perlin-based augmentor, and feature embeddings are
+        extracted for both the normal and the augmented images (optionally through the pre-projection adaptor).
+        Normal features are perturbed with Gaussian noise and refined by gradient ascent with truncated projection,
+        and the shared discriminator is scored on normal, global, and local synthetic features.
+
+        Args:
+            img (torch.Tensor): Batch of input images of shape (B, C, H, W).
+            c (torch.Tensor | None, optional): Center (mean) embedding of the normal features, as computed by
+                :meth:`calculate_mean`. Defaults to None.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The discriminator loss on
+                normal features (`true_loss`), the loss on Gaussian-perturbed features (`gaus_loss`), their sum
+                (`bce_loss`), the focal loss on local synthetic features (`focal_loss`), and the total loss (`loss`).
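+
+        Example:
+            Sketch of a single training call, assuming ``model`` is an initialized ``GlassModel``, ``images`` is a
+            batch of normal images, and ``center`` was obtained from :meth:`calculate_mean`:
+
+            >>> true_loss, gaus_loss, bce_loss, focal_loss, loss = model(images, center)  # doctest: +SKIP
+            >>> loss.backward()  # doctest: +SKIP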
+ """ + device = img.device + aug, mask_s = self.augmentor(img) + if img is not None: + batch_size = img.shape[0] + + true_feats, fake_feats = self.calculate_features(img, aug) + + h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + + mask_s_resized = f.interpolate( + mask_s.float(), + size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio), + mode="nearest", + ) + mask_s_gt = mask_s_resized.reshape(-1, 1) + + noise = torch.normal(0, self.noise, true_feats.shape).to(device) + gaus_feats = true_feats + noise + + center = c.repeat(img.shape[0], 1, 1) + center = center.reshape(-1, center.shape[-1]) + true_points = torch.concat( + [fake_feats[mask_s_gt[:, 0] == 0], true_feats], + dim=0, + ) + c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) + dist_t = torch.norm(true_points - c_t_points, dim=1) + r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(device) + + for step in range(self.step + 1): + scores = self.discriminator(torch.cat([true_feats, gaus_feats])) + true_scores = scores[: len(true_feats)] + gaus_scores = scores[len(true_feats) :] + true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores)) + gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores)) + bce_loss = true_loss + gaus_loss + + if step == self.step: + break + + grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] + grad_norm = torch.norm(grad, dim=1) + grad_norm = grad_norm.view(-1, 1) + grad_normalized = grad / (grad_norm + 1e-10) + + with torch.no_grad(): + gaus_feats.add_(0.001 * grad_normalized) + + fake_points = fake_feats[mask_s_gt[:, 0] == 1] + true_points = true_feats[mask_s_gt[:, 0] == 1] + c_f_points = center[mask_s_gt[:, 0] == 1] + dist_f = torch.norm(fake_points - c_f_points, dim=1) + proj_feats = c_f_points if self.svd == 1 else true_points + r = r_t if self.svd == 1 else 1 + + if self.svd == 1: + h = fake_points - proj_feats + h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1) + alpha = torch.clamp(h_norm, 2 * r, 4 * r) + proj = (alpha / (h_norm + 1e-10)).view(-1, 1) + h = proj * h + fake_points = proj_feats + h + fake_feats[mask_s_gt[:, 0] == 1] = fake_points + + fake_scores = self.discriminator(fake_feats) + + if self.p > 0: + fake_dist = (fake_scores - mask_s_gt) ** 2 + d_hard = torch.quantile(fake_dist, q=self.p) + fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) + mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) + else: + fake_scores_ = fake_scores + mask_ = mask_s_gt + output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) + focal_loss = self.focal_loss(output, mask_) + + loss = bce_loss + focal_loss + return true_loss, gaus_loss, bce_loss, focal_loss, loss diff --git a/third-party-programs.txt b/third-party-programs.txt index 5eeaca8ea9..751477a0c9 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -46,3 +46,7 @@ terms are listed below. 8. AUPIMO metric implementation is based on the original code Copyright (c) 2023 @jpcbertoldo, https://github.com/jpcbertoldo/aupimo SPDX-License-Identifier: MIT + +9. GLASS Model implementation is based on the original code + Copyright (c) 2024 Qiyu Chen, https://github.com/cqylunlun/GLASS + SPDX-License-Identifier: MIT