Commit 161005c

Fixed bugs
Signed-off-by: Devansh Agarwal <[email protected]>
1 parent 9b1c51a commit 161005c

File tree: 3 files changed, +24 -28 lines

src/anomalib/models/components/feature_extractors/network_feature_extractor.py (+1 -2)

@@ -29,7 +29,6 @@ def __init__(self, backbone, layers_to_extract_from, pre_trained=False):
         for extract_layer in layers_to_extract_from:
             self.register_hook(extract_layer)

-        self.to(self.device)

     def forward(self, images, eval=True):
         self.outputs.clear()
@@ -45,7 +44,7 @@ def forward(self, images, eval=True):

     def feature_dimensions(self, input_shape):
         """Computes the feature dimensions for all layers given input_shape."""
-        _input = torch.ones([1] + list(input_shape)).to(self.device)
+        _input = torch.ones([1] + list(input_shape))
         _output = self(_input)
         return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

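
The commit simply drops the explicit `.to(self.device)` calls, so the probe tensor in `feature_dimensions` is no longer forced onto a cached device. Below is a minimal sketch of the same probing idea; the class and layer names are illustrative, not the repository's code, and it shows one common variant that derives the device from the module's own parameters instead of caching one.

import torch
from torch import nn


class TinyAggregator(nn.Module):
    """Illustrative stand-in for a feature aggregator; names are assumptions."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.layers_to_extract_from = ["layer1"]

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"layer1": self.conv(images)}

    def feature_dimensions(self, input_shape) -> list[int]:
        """Probe output channel counts with a dummy batch of ones."""
        # Follow the module's current device rather than a stored self.device.
        device = next(self.parameters()).device
        _input = torch.ones([1, *input_shape], device=device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]


print(TinyAggregator().feature_dimensions((3, 32, 32)))  # -> [8]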

src/anomalib/models/image/glass/lightning_model.py (+21 -24)

@@ -16,9 +16,10 @@
 from anomalib.visualization import Visualizer

 from .loss import FocalLoss
-from .perlin import PerlinNoise
 from .torch_model import GlassModel

+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
+

 class Glass(AnomalibModule):
     def __init__(
@@ -54,7 +55,7 @@ def __init__(
             visualizer=visualizer,
         )

-        self.perlin = PerlinNoise(anomaly_source_path)
+        self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)

         self.model = GlassModel(
             input_shape=input_shape,
@@ -82,45 +83,41 @@ def __init__(

         self.focal_loss = FocalLoss()

-    def configure_optimizers(self) -> list[optim.Optimizer]:
-        optimizers = []
-        if not self.model.pre_trained:
-            backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr)
-            optimizers.append(backbone_opt)
+        if pre_proj > 0:
+            self.proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5)
         else:
-            optimizers.append(None)
+            self.proj_opt = None

-        if self.model.pre_proj > 0:
-            proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5)
-            optimizers.append(proj_opt)
+        if not pre_trained:
+            self.backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr)
         else:
-            optimizers.append(None)
+            self.backbone_opt = None

+    def configure_optimizers(self) -> list[optim.Optimizer]:
         dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
-        optimizers.append(dsc_opt)

-        return optimizers
+        return dsc_opt

     def training_step(
         self,
         batch: Batch,
         batch_idx: int,
     ) -> STEP_OUTPUT:
-        backbone_opt, proj_opt, dsc_opt = self.optimizers()
+        dsc_opt = self.optimizers()

         self.model.forward_modules.eval()
         if self.model.pre_proj > 0:
             self.pre_projection.train()
         self.model.discriminator.train()

         dsc_opt.zero_grad()
-        if proj_opt is not None:
-            proj_opt.zero_grad()
-        if backbone_opt is not None:
-            backbone_opt.zero_grad()
+        if self.proj_opt is not None:
+            self.proj_opt.zero_grad()
+        if self.backbone_opt is not None:
+            self.backbone_opt.zero_grad()

         img = batch.image
-        aug, mask_s = self.perlin(img)
+        aug, mask_s = self.augmentor(img)

         true_feats, fake_feats = self.model(img, aug)

@@ -191,10 +188,10 @@ def training_step(
         loss = bce_loss + focal_loss
         loss.backward()

-        if proj_opt is not None:
-            proj_opt.step()
-        if backbone_opt is not None:
-            backbone_opt.step()
+        if self.proj_opt is not None:
+            self.proj_opt.step()
+        if self.backbone_opt is not None:
+            self.backbone_opt.step()
         dsc_opt.step()

     def on_train_start(self) -> None:
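
The commit moves the projection and backbone optimizers out of `configure_optimizers` and into `__init__`, so `configure_optimizers` now returns only the discriminator optimizer and `training_step` drives the others through `self.proj_opt` / `self.backbone_opt`. Below is a minimal sketch of that layout; it is assumed, not the repository's code, and it presumes manual optimization in a Lightning 2.x module.

import torch
from torch import nn, optim
from lightning.pytorch import LightningModule


class SketchGlass(LightningModule):
    """Illustrative module: auxiliary optimizer held as an attribute, one optimizer returned."""

    def __init__(self, lr: float = 1e-4) -> None:
        super().__init__()
        self.automatic_optimization = False  # assumption: the module steps optimizers itself
        self.pre_projection = nn.Linear(8, 8)
        self.discriminator = nn.Linear(8, 1)
        self.lr = lr
        # Kept as a plain attribute, mirroring self.proj_opt in the commit.
        self.proj_opt = optim.AdamW(self.pre_projection.parameters(), lr, weight_decay=1e-5)

    def configure_optimizers(self) -> optim.Optimizer:
        # Only the discriminator optimizer is registered with Lightning.
        return optim.AdamW(self.discriminator.parameters(), lr=self.lr * 2)

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
        dsc_opt = self.optimizers()  # the single registered optimizer
        dsc_opt.zero_grad()
        self.proj_opt.zero_grad()
        loss = self.discriminator(self.pre_projection(batch)).mean()
        self.manual_backward(loss)
        self.proj_opt.step()
        dsc_opt.step()

One consequence of this layout worth noting: only the optimizer returned from `configure_optimizers` is checkpointed by Lightning, while attribute-held optimizers live outside its bookkeeping.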

src/anomalib/models/image/glass/torch_model.py (+2 -2)

@@ -198,7 +198,7 @@ def __init__(
         self.layers = layers
         self.input_shape = input_shape

-        self.forward_modules = torch.ModuleDict({})
+        self.forward_modules = torch.nn.ModuleDict({})
         feature_aggregator = NetworkFeatureAggregator(
             self.backbone,
             self.layers,
@@ -257,7 +257,7 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False):
         with torch.no_grad():
            features = self.forward_modules["feature_aggregator"](images)

-        features = [features[layer] for layer in self.layers_to_extract_from]
+        features = [features[layer] for layer in self.layers]
         for i, feat in enumerate(features):
             if len(feat.shape) == 3:
                 B, L, C = feat.shape
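
`ModuleDict` is exported from `torch.nn`, not the top-level `torch` namespace, so the one-word change is what lets the container be constructed at all. A small illustration of the corrected usage follows; the example entry is mine, not the repository's.

from torch import nn

# torch.nn.ModuleDict is the correct container; plain torch.ModuleDict does not exist.
forward_modules = nn.ModuleDict({})
forward_modules["feature_aggregator"] = nn.Conv2d(3, 8, kernel_size=3)

# Entries are registered submodules: they move with .to(...) and appear in .parameters().
print(list(forward_modules.keys()))
print(sum(p.numel() for p in forward_modules.parameters()))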
