feat(model): add GLASS model into Anomalib #2629
Open
code-dev05 wants to merge 16 commits into open-edge-platform:main from code-dev05:feature/model/glass
Changes from 8 commits
Commits (16):
5b4931b  Initial Implementation of GLASS Model (code-dev05)
4789f49  Created the trainer class for glass model (code-dev05)
050fd4c  Added suggested changes (code-dev05)
cdd0984  Modified forward method for model (code-dev05)
381eec6  Fixed backbone loading logic (code-dev05)
9b1c51a  Added type for input shape (code-dev05)
161005c  Fixed bugs (code-dev05)
3d78beb  Merge branch 'main' into feature/model/glass (samet-akcay)
617cf49  Changed files as needed (code-dev05)
f9d3207  Merge remote-tracking branch 'origin/feature/model/glass' into featur… (code-dev05)
7fea20f  Matched code to the original implementation (code-dev05)
1beedf5  Added support for gpu (code-dev05)
838bc50  Refactored code from lightning model to torch model (code-dev05)
1baa0b7  GPU bug fixed (code-dev05)
f066b3c  used image device in torch model (code-dev05)
6e780b0  fixed bug (code-dev05)
90 changes: 90 additions & 0 deletions
src/anomalib/models/components/feature_extractors/network_feature_extractor.py
@@ -0,0 +1,90 @@
import copy

import torch


class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features.

    Runs the backbone only up to the last layer in the list of layers from
    which network features should be extracted.
    """

    def __init__(self, backbone, layers_to_extract_from, pre_trained=False):
        """Register forward hooks on the requested layers of the backbone.

        Args:
            backbone: torchvision model to extract features from.
            layers_to_extract_from: list of layer names (str) to hook.
            pre_trained: whether the backbone is used as a frozen, pre-trained
                feature extractor.
        """
        super().__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.pre_trained = pre_trained
        # Remove any hooks left behind by a previous aggregator on the same backbone.
        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        for handle in self.backbone.hook_handles:
            handle.remove()
        self.outputs = {}

        for extract_layer in layers_to_extract_from:
            self.register_hook(extract_layer)

    def forward(self, images, eval=True):
        self.outputs.clear()
        if not self.pre_trained and not eval:
            # Keep gradients when the backbone is being fine-tuned.
            self.backbone(images)
        else:
            with torch.no_grad():
                try:
                    _ = self.backbone(images)
                except LastLayerToExtractReachedException:
                    pass
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the feature dimensions for all layers given input_shape."""
        _input = torch.ones([1] + list(input_shape))
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

    def register_hook(self, layer_name):
        module = self.find_module(self.backbone, layer_name)
        if module is not None:
            forward_hook = ForwardHook(
                self.outputs, layer_name, self.layers_to_extract_from[-1]
            )
            if isinstance(module, torch.nn.Sequential):
                hook = module[-1].register_forward_hook(forward_hook)
            else:
                hook = module.register_forward_hook(forward_hook)
            self.backbone.hook_handles.append(hook)
        else:
            raise ValueError(f"Module {layer_name} not found in the model")

    def find_module(self, model, module_name):
        for name, module in model.named_modules():
            if name == module_name:
                return module
            elif "." in module_name:
                # Recurse into submodules for dotted layer names such as "layer2.3".
                father, child = module_name.split(".", 1)
                if name == father:
                    return self.find_module(module, child)
        return None


class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = copy.deepcopy(
            layer_name == last_layer_to_extract
        )

    def __call__(self, module, input, output):
        # Store the layer output so the aggregator can return it from forward().
        self.hook_dict[self.layer_name] = output
        return None


class LastLayerToExtractReachedException(Exception):
    pass
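For orientation, here is a minimal, hypothetical usage sketch of the aggregator above with a torchvision ResNet-18. The backbone choice, layer names, and input size are illustrative assumptions, not taken from this PR; the import path simply follows the file location shown in the diff.

# Hypothetical example, not part of the diff.
import torch
from torchvision import models

from anomalib.models.components.feature_extractors.network_feature_extractor import (
    NetworkFeatureAggregator,
)

backbone = models.resnet18(pretrained=True)
aggregator = NetworkFeatureAggregator(
    backbone, layers_to_extract_from=["layer2", "layer3"], pre_trained=True
)

# forward() returns a dict mapping each hooked layer name to its feature map.
features = aggregator(torch.rand(2, 3, 224, 224))
print({name: tuple(t.shape) for name, t in features.items()})
# expected: {'layer2': (2, 128, 28, 28), 'layer3': (2, 256, 14, 14)}

# feature_dimensions() reports the channel count of each hooked layer.
print(aggregator.feature_dimensions((3, 224, 224)))  # expected: [128, 256]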
@@ -0,0 +1,4 @@
# Copyright (C) 2022-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import Glass as Glass
@@ -0,0 +1,56 @@
# Copyright (C) 2022-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Backbone registry, copied from https://github.com/cqylunlun/GLASS/blob/main/backbones.py.

Provides a mechanism to construct any of the listed backbones from its name.
"""

import timm
import torchvision.models as models

# Each entry maps a backbone name to the expression that constructs it. The
# expressions are kept as strings and evaluated lazily in load(), so only the
# requested backbone is ever instantiated.
_BACKBONES = {
    "alexnet": "models.alexnet(pretrained=True)",
    "resnet18": "models.resnet18(pretrained=True)",
    "resnet50": "models.resnet50(pretrained=True)",
    "resnet101": "models.resnet101(pretrained=True)",
    "resnext101": "models.resnext101_32x8d(pretrained=True)",
    "resnet200": 'timm.create_model("resnet200", pretrained=True)',
    "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)',
    "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)',
    "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)',
    "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)',
    "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)',
    "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)',
    "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)',
    "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)',
    "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)',
    "vgg11": "models.vgg11(pretrained=True)",
    "vgg19": "models.vgg19(pretrained=True)",
    "vgg19_bn": "models.vgg19_bn(pretrained=True)",
    "wideresnet50": "models.wide_resnet50_2(pretrained=True)",
    "wideresnet101": "models.wide_resnet101_2(pretrained=True)",
    "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)',
    "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)',
    "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)',
    "densenet121": 'timm.create_model("densenet121", pretrained=True)',
    "densenet201": 'timm.create_model("densenet201", pretrained=True)',
    "inception_v4": 'timm.create_model("inception_v4", pretrained=True)',
    "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)',
    "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)',
    "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)',
    "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)',
    "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)',
    "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)',
    "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)',
    "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)',
    "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)',
    "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)',
    "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)',
    "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)',
    "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)',
    "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)',
    "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)',
}


def load(name):
    """Instantiate the backbone registered under the given name."""
    return eval(_BACKBONES[name])
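For completeness, a small, hypothetical sketch of how this registry could be combined with the NetworkFeatureAggregator from the diff above, assuming it is run next to the load() helper in this file. The backbone name, layer names, and input size are illustrative choices, not taken from the PR.

# Hypothetical example, not part of the diff.
from anomalib.models.components.feature_extractors.network_feature_extractor import (
    NetworkFeatureAggregator,
)

backbone = load("wideresnet50")  # evaluates the registered constructor string
aggregator = NetworkFeatureAggregator(
    backbone, layers_to_extract_from=["layer2", "layer3"], pre_trained=True
)
print(aggregator.feature_dimensions((3, 288, 288)))  # expected: [512, 1024]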