|
16 | 16 | from anomalib.visualization import Visualizer
|
17 | 17 |
|
18 | 18 | from .loss import FocalLoss
|
19 |
| -from .perlin import PerlinNoise |
20 | 19 | from .torch_model import GlassModel
|
21 | 20 |
|
| 21 | +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator |
| 22 | + |
22 | 23 |
|
23 | 24 | class Glass(AnomalibModule):
|
24 | 25 | def __init__(
|
@@ -54,7 +55,7 @@ def __init__(
|
54 | 55 | visualizer=visualizer,
|
55 | 56 | )
|
56 | 57 |
|
57 |
| - self.perlin = PerlinNoise(anomaly_source_path) |
| 58 | + self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) |
58 | 59 |
|
59 | 60 | self.model = GlassModel(
|
60 | 61 | input_shape=input_shape,
|
@@ -82,45 +83,41 @@ def __init__(
|
82 | 83 |
|
83 | 84 | self.focal_loss = FocalLoss()
|
84 | 85 |
|
85 |
| - def configure_optimizers(self) -> list[optim.Optimizer]: |
86 |
| - optimizers = [] |
87 |
| - if not self.model.pre_trained: |
88 |
| - backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr) |
89 |
| - optimizers.append(backbone_opt) |
| 86 | + if pre_proj > 0: |
| 87 | + self.proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5) |
90 | 88 | else:
|
91 |
| - optimizers.append(None) |
| 89 | + self.proj_opt = None |
92 | 90 |
|
93 |
| - if self.model.pre_proj > 0: |
94 |
| - proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5) |
95 |
| - optimizers.append(proj_opt) |
| 91 | + if not pre_trained: |
| 92 | + self.backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr) |
96 | 93 | else:
|
97 |
| - optimizers.append(None) |
| 94 | + self.backbone_opt = None |
98 | 95 |
|
| 96 | + def configure_optimizers(self) -> list[optim.Optimizer]: |
99 | 97 | dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
|
100 |
| - optimizers.append(dsc_opt) |
101 | 98 |
|
102 |
| - return optimizers |
| 99 | + return dsc_opt |
103 | 100 |
|
104 | 101 | def training_step(
|
105 | 102 | self,
|
106 | 103 | batch: Batch,
|
107 | 104 | batch_idx: int,
|
108 | 105 | ) -> STEP_OUTPUT:
|
109 |
| - backbone_opt, proj_opt, dsc_opt = self.optimizers() |
| 106 | + dsc_opt = self.optimizers() |
110 | 107 |
|
111 | 108 | self.model.forward_modules.eval()
|
112 | 109 | if self.model.pre_proj > 0:
|
113 | 110 | self.pre_projection.train()
|
114 | 111 | self.model.discriminator.train()
|
115 | 112 |
|
116 | 113 | dsc_opt.zero_grad()
|
117 |
| - if proj_opt is not None: |
118 |
| - proj_opt.zero_grad() |
119 |
| - if backbone_opt is not None: |
120 |
| - backbone_opt.zero_grad() |
| 114 | + if self.proj_opt is not None: |
| 115 | + self.proj_opt.zero_grad() |
| 116 | + if self.backbone_opt is not None: |
| 117 | + self.backbone_opt.zero_grad() |
121 | 118 |
|
122 | 119 | img = batch.image
|
123 |
| - aug, mask_s = self.perlin(img) |
| 120 | + aug, mask_s = self.augmentor(img) |
124 | 121 |
|
125 | 122 | true_feats, fake_feats = self.model(img, aug)
|
126 | 123 |
|
@@ -191,10 +188,10 @@ def training_step(
|
191 | 188 | loss = bce_loss + focal_loss
|
192 | 189 | loss.backward()
|
193 | 190 |
|
194 |
| - if proj_opt is not None: |
195 |
| - proj_opt.step() |
196 |
| - if backbone_opt is not None: |
197 |
| - backbone_opt.step() |
| 191 | + if self.proj_opt is not None: |
| 192 | + self.proj_opt.step() |
| 193 | + if self.backbone_opt is not None: |
| 194 | + self.backbone_opt.step() |
198 | 195 | dsc_opt.step()
|
199 | 196 |
|
200 | 197 | def on_train_start(self) -> None:
|
|
0 commit comments