Commit 5b4931b
Initial Implementation of GLASS Model
Signed-off-by: Devansh Agarwal <[email protected]>
1 parent 12c1e84 commit 5b4931b

File tree: 5 files changed, +331 −2 lines changed


src/anomalib/models/components/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@
 
 from .base import AnomalibModule, BufferListMixin, DynamicBufferMixin, MemoryBankMixin
 from .dimensionality_reduction import PCA, SparseRandomProjection
-from .feature_extractors import TimmFeatureExtractor
+from .feature_extractors import TimmFeatureExtractor, NetworkFeatureAggregator
 from .filters import GaussianBlur2d
 from .sampling import KCenterGreedy
 from .stats import GaussianKDE, MultiVariateGaussian

src/anomalib/models/components/feature_extractors/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 
 from .timm import TimmFeatureExtractor
 from .utils import dryrun_find_featuremap_dims
-
+from .network_feature_extractor import NetworkFeatureAggregator
 __all__ = [
     "dryrun_find_featuremap_dims",
     "TimmFeatureExtractor",

src/anomalib/models/components/feature_extractors/network_feature_extractor.py

Lines changed: 91 additions & 0 deletions

import torch
from torch import nn


class NetworkFeatureAggregator(nn.Module):
    """Efficient extraction of network features.

    Runs a network only to the last layer of the list of layers where
    network features should be extracted from.

    Args:
        backbone: torchvision model.
        layers_to_extract_from: list of layer names (str).
        train_backbone: whether the backbone is trained during the forward pass.
    """

    def __init__(self, backbone, layers_to_extract_from, train_backbone=False):
        super().__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.train_backbone = train_backbone
        # Remove hooks left behind by a previous aggregator on the same backbone.
        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        for handle in self.backbone.hook_handles:
            handle.remove()
        self.outputs = {}

        for extract_layer in layers_to_extract_from:
            self.register_hook(extract_layer)

        # nn.Module has no `device` attribute; derive it from the backbone parameters.
        self.device = next(backbone.parameters()).device
        self.to(self.device)

    def forward(self, images, eval=True):
        self.outputs.clear()
        if self.train_backbone and not eval:
            self.backbone(images)
        else:
            with torch.no_grad():
                try:
                    _ = self.backbone(images)
                except LastLayerToExtractReachedException:
                    pass
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the feature dimensions for all layers given input_shape."""
        _input = torch.ones([1] + list(input_shape)).to(self.device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

    def register_hook(self, layer_name):
        module = self.find_module(self.backbone, layer_name)
        if module is None:
            raise ValueError(f"Module {layer_name} not found in the model")
        forward_hook = ForwardHook(self.outputs, layer_name, self.layers_to_extract_from[-1])
        if isinstance(module, nn.Sequential):
            hook = module[-1].register_forward_hook(forward_hook)
        else:
            hook = module.register_forward_hook(forward_hook)
        self.backbone.hook_handles.append(hook)

    def find_module(self, model, module_name):
        for name, module in model.named_modules():
            if name == module_name:
                return module
            if "." in module_name:
                father, child = module_name.split(".", 1)
                if name == father:
                    return self.find_module(module, child)
        return None


class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = layer_name == last_layer_to_extract

    def __call__(self, module, input, output):
        # Store the layer output so the aggregator can collect it after the pass.
        self.hook_dict[self.layer_name] = output
        return None


class LastLayerToExtractReachedException(Exception):
    pass
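
For reference, a minimal usage sketch (not part of the commit; the torchvision ResNet-18 backbone and layer names are illustrative):

import torch
from torchvision.models import resnet18

backbone = resnet18(weights=None)
aggregator = NetworkFeatureAggregator(backbone, ["layer1", "layer2", "layer3"])

images = torch.randn(2, 3, 224, 224)
features = aggregator(images)  # dict mapping layer name -> feature map
# features["layer1"]: (2, 64, 56, 56) ... features["layer3"]: (2, 256, 14, 14)
print(aggregator.feature_dimensions((3, 224, 224)))  # [64, 128, 256]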

Lines changed: 1 addition & 0 deletions

from .torch_model import GlassModel

Lines changed: 237 additions & 0 deletions

import math

import torch
import torch.nn.functional as F
from torch import nn

from anomalib.models.components import NetworkFeatureAggregator


def init_weight(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
    if isinstance(m, torch.nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif isinstance(m, torch.nn.Conv2d):
        m.weight.data.normal_(0.0, 0.02)


class Preprocessing(torch.nn.Module):
    def __init__(self, input_dims, output_dim):
        super().__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim

        # One MeanMapper per backbone layer, all mapping to the same output_dim.
        self.preprocessing_modules = torch.nn.ModuleList()
        for _ in input_dims:
            module = MeanMapper(output_dim)
            self.preprocessing_modules.append(module)

    def forward(self, features):
        _features = []
        for module, feature in zip(self.preprocessing_modules, features):
            _features.append(module(feature))
        return torch.stack(_features, dim=1)


class MeanMapper(torch.nn.Module):
    def __init__(self, preprocessing_dim):
        super().__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        # Flatten each sample and average-pool it to a fixed width.
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)


class Aggregator(torch.nn.Module):
    def __init__(self, target_dim):
        super().__init__()
        self.target_dim = target_dim

    def forward(self, features):
        """Returns reshaped and average pooled features."""
        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.target_dim)
        return features.reshape(len(features), -1)
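
For intuition (illustrative only, not part of the commit): MeanMapper pools each flattened patch to a fixed width, Preprocessing stacks the per-layer results, and Aggregator fuses the stack into a single vector per patch:

patches_l2 = torch.randn(10, 512, 3, 3)    # 10 patches from a 512-channel layer
patches_l3 = torch.randn(10, 1024, 3, 3)   # 10 patches from a 1024-channel layer
mapped = MeanMapper(1536)(patches_l2)      # (10, 1536): 512*3*3 = 4608 pooled to 1536
stacked = Preprocessing([512, 1024], 1536)([patches_l2, patches_l3])  # (10, 2, 1536)
fused = Aggregator(target_dim=1536)(stacked)  # (10, 1536)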

class Projection(torch.nn.Module):
    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super().__init__()

        if out_planes is None:
            out_planes = in_planes
        self.layers = torch.nn.Sequential()
        _in = None
        _out = None
        for i in range(n_layers):
            _in = in_planes if i == 0 else _out
            _out = out_planes
            self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out))
            if i < n_layers - 1 and layer_type > 1:
                self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2))
        self.apply(init_weight)

    def forward(self, x):
        x = self.layers(x)
        return x


class Discriminator(torch.nn.Module):
    def __init__(self, in_planes, n_layers=2, hidden=None):
        super().__init__()

        _hidden = in_planes if hidden is None else hidden
        self.body = torch.nn.Sequential()
        for i in range(n_layers - 1):
            _in = in_planes if i == 0 else _hidden
            _hidden = int(_hidden // 1.5) if hidden is None else hidden
            self.body.add_module(
                f"block{i + 1}",
                torch.nn.Sequential(
                    torch.nn.Linear(_in, _hidden),
                    torch.nn.BatchNorm1d(_hidden),
                    torch.nn.LeakyReLU(0.2),
                ),
            )
        self.tail = torch.nn.Sequential(
            torch.nn.Linear(_hidden, 1, bias=False),
            torch.nn.Sigmoid(),
        )
        self.apply(init_weight)

    def forward(self, x):
        x = self.body(x)
        x = self.tail(x)
        return x
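
Illustrative only (not part of the commit): the discriminator maps flattened patch embeddings to per-patch scores in (0, 1):

disc = Discriminator(in_planes=1536, n_layers=2, hidden=1024)
embeddings = torch.randn(32, 1536)  # a batch of patch embeddings
scores = disc(embeddings)           # (32, 1), sigmoid outputs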

class PatchMaker:
    def __init__(self, patchsize, top_k=0, stride=None):
        self.patchsize = patchsize
        self.stride = stride
        self.top_k = top_k

    def patchify(self, features, return_spatial_info=False):
        """Convert a (B, C, H, W) tensor into overlapping (C, patchsize, patchsize) patches."""
        padding = int((self.patchsize - 1) / 2)
        unfolder = torch.nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1)
        unfolded_features = unfolder(features)
        number_of_total_patches = []
        for s in features.shape[-2:]:
            n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            number_of_total_patches.append(int(n_patches))
        unfolded_features = unfolded_features.reshape(
            *features.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)

        if return_spatial_info:
            return unfolded_features, number_of_total_patches
        return unfolded_features

    def unpatch_scores(self, x, batchsize):
        return x.reshape(batchsize, -1, *x.shape[1:])

    def score(self, x):
        x = x[:, :, 0]
        x = torch.max(x, dim=1).values
        return x
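
For intuition (illustrative only): with patchsize=3 and stride=1, patchify yields one 3x3 neighborhood per spatial location:

patch_maker = PatchMaker(patchsize=3, stride=1)
fmap = torch.randn(2, 64, 28, 28)
patches, spatial = patch_maker.patchify(fmap, return_spatial_info=True)
print(patches.shape)  # torch.Size([2, 784, 64, 3, 3]) -- 28*28 patches per image
print(spatial)        # [28, 28]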

class GlassModel(nn.Module):
    def __init__(
        self,
        input_shape,
        pretrain_embed_dim,
        target_embed_dim,
        backbone: nn.Module,
        patchsize: int = 3,
        patchstride: int = 1,
        pre_trained: bool = True,
        layers: list[str] = ["layer1", "layer2", "layer3"],
        pre_proj: int = 1,
        dsc_layers=2,
        dsc_hidden=1024,
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.layers = layers
        self.input_shape = input_shape

        self.forward_modules = nn.ModuleDict({})
        # Train the backbone only when it is not pre-trained.
        feature_aggregator = NetworkFeatureAggregator(
            self.backbone, self.layers, train_backbone=not pre_trained
        )
        feature_dimensions = feature_aggregator.feature_dimensions(input_shape)
        self.forward_modules["feature_aggregator"] = feature_aggregator

        preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dim)
        self.forward_modules["preprocessing"] = preprocessing
        self.target_embed_dimension = target_embed_dim
        preadapt_aggregator = Aggregator(target_dim=target_embed_dim)
        self.forward_modules["preadapt_aggregator"] = preadapt_aggregator

        self.pre_trained = pre_trained

        self.pre_proj = pre_proj
        if self.pre_proj > 0:
            self.pre_projection = Projection(
                self.target_embed_dimension, self.target_embed_dimension, pre_proj
            )

        self.discriminator = Discriminator(
            self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden
        )

        self.patch_maker = PatchMaker(patchsize, stride=patchstride)

    def generate_embeddings(self, images, provide_patch_shapes=False, eval=False):
        if not eval and not self.pre_trained:
            self.forward_modules["feature_aggregator"].train()
            features = self.forward_modules["feature_aggregator"](images, eval=eval)
        else:
            self.forward_modules["feature_aggregator"].eval()
            with torch.no_grad():
                features = self.forward_modules["feature_aggregator"](images)

        features = [features[layer] for layer in self.layers]
        for i, feat in enumerate(features):
            # Transformer backbones return (B, L, C) tokens; fold them back to (B, C, H, W).
            if len(feat.shape) == 3:
                B, L, C = feat.shape
                features[i] = feat.reshape(B, int(math.sqrt(L)), int(math.sqrt(L)), C).permute(0, 3, 1, 2)

        features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features]
        patch_shapes = [x[1] for x in features]
        patch_features = [x[0] for x in features]
        ref_num_patches = patch_shapes[0]

        # Upsample the coarser layers' patch grids to the first layer's resolution.
        for i in range(1, len(patch_features)):
            _features = patch_features[i]
            patch_dims = patch_shapes[i]

            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
            _features = _features.permute(0, 3, 4, 5, 1, 2)
            perm_base_shape = _features.shape
            _features = _features.reshape(-1, *_features.shape[-2:])
            _features = F.interpolate(
                _features.unsqueeze(1),
                size=(ref_num_patches[0], ref_num_patches[1]),
                mode="bilinear",
                align_corners=False,
            )
            _features = _features.squeeze(1)
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            _features = _features.permute(0, 4, 5, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            patch_features[i] = _features

        patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features]
        patch_features = self.forward_modules["preprocessing"](patch_features)
        patch_features = self.forward_modules["preadapt_aggregator"](patch_features)

        return patch_features, patch_shapes

    def forward(self, images, eval=False):
        self.forward_modules.eval()
        with torch.no_grad():
            if self.pre_proj > 0:
                outputs = self.pre_projection(self.generate_embeddings(images, eval=eval)[0])
            else:
                outputs = self.generate_embeddings(images, eval=eval)[0]
        outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1])
        return outputs
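
For reference, a minimal usage sketch (not part of the commit; the backbone, layers and embedding sizes are illustrative, loosely following common GLASS settings):

import torch
from torchvision.models import wide_resnet50_2

backbone = wide_resnet50_2(weights="IMAGENET1K_V1")
model = GlassModel(
    input_shape=(3, 288, 288),
    pretrain_embed_dim=1536,
    target_embed_dim=1536,
    backbone=backbone,
    layers=["layer2", "layer3"],
)
images = torch.randn(2, 3, 288, 288)
embeddings = model(images, eval=True)
print(embeddings.shape)  # (2, 1296, 1536): one embedding per 36x36 patch location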
