
Commit 8ae05a3

remove subset-specific transforms in preprocessor
1 parent f15f0f7 commit 8ae05a3

File tree: 16 files changed, +13 −125 lines

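In short, this commit collapses the PreProcessor's stage-specific train_transform/val_transform/test_transform arguments into a single transform, and deletes the now-unused metrics: keys from the model configs. A minimal before/after sketch of the constructor call (the transform pipeline is illustrative, borrowed from the repository's own tests):

import torch
from torchvision.transforms.v2 import Compose, Resize, ToDtype, ToImage

from anomalib.pre_processing import PreProcessor

transform = Compose([Resize((256, 256)), ToImage(), ToDtype(torch.float32, scale=True)])

# Before this commit: per-stage transforms, optionally different per subset.
# pre_processor = PreProcessor(train_transform=..., val_transform=..., test_transform=...)

# After this commit: one transform, applied identically at train/val/test/predict time.
pre_processor = PreProcessor(transform=transform)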

configs/model/cfa.yaml

Lines changed: 0 additions & 3 deletions
@@ -8,9 +8,6 @@ model:
     num_hard_negative_features: 3
     radius: 1.0e-05

-metrics:
-  pixel: AUROC
-
 trainer:
   max_epochs: 30
   callbacks:

configs/model/cflow.yaml

Lines changed: 0 additions & 4 deletions
@@ -15,10 +15,6 @@ model:
     permute_soft: false
     lr: 0.0001

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   max_epochs: 50
   callbacks:

configs/model/csflow.yaml

Lines changed: 0 additions & 4 deletions
@@ -6,10 +6,6 @@ model:
     clamp: 3
     num_channels: 3

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   max_epochs: 240
   callbacks:

configs/model/draem.yaml

Lines changed: 0 additions & 4 deletions
@@ -6,10 +6,6 @@ model:
     sspcab_lambda: 0.1
     anomaly_source_path: null

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   max_epochs: 700
   callbacks:

configs/model/dsr.yaml

Lines changed: 0 additions & 4 deletions
@@ -4,10 +4,6 @@ model:
     latent_anomaly_strength: 0.2
     upsampling_train_ratio: 0.7

-metrics:
-  pixel:
-    - AUROC
-
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   max_epochs: 700

configs/model/efficient_ad.yaml

Lines changed: 0 additions & 4 deletions
@@ -8,10 +8,6 @@ model:
     padding: false
     pad_maps: true

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   max_epochs: 1000
   max_steps: 70000

configs/model/fastflow.yaml

Lines changed: 0 additions & 4 deletions
@@ -7,10 +7,6 @@ model:
     conv3x3_only: false
     hidden_ratio: 1.0

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   max_epochs: 500
   callbacks:

configs/model/padim.yaml

Lines changed: 0 additions & 3 deletions
@@ -8,6 +8,3 @@ model:
     backbone: resnet18
     pre_trained: true
     n_features: null
-
-metrics:
-  pixel: AUROC

configs/model/reverse_distillation.yaml

Lines changed: 0 additions & 4 deletions
@@ -9,10 +9,6 @@ model:
     anomaly_map_mode: ADD
     pre_trained: true

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   callbacks:
     - class_path: lightning.pytorch.callbacks.EarlyStopping

configs/model/stfpm.yaml

Lines changed: 0 additions & 4 deletions
@@ -7,10 +7,6 @@ model:
       - layer2
       - layer3

-metrics:
-  pixel:
-    - AUROC
-
 trainer:
   max_epochs: 100
   callbacks:

configs/model/uflow.yaml

Lines changed: 0 additions & 4 deletions
@@ -7,10 +7,6 @@ model:
     affine_subnet_channels_ratio: 1.0
     backbone: mcait # official: mcait, other extractors tested: resnet18, wide_resnet50_2. Could use others...

-metrics:
-  pixel:
-    - AUROC
-
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   max_epochs: 200
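All eleven config edits above are the same mechanical change: the per-model metrics: block is deleted. Presumably pixel-level metrics are now supplied through the model's Evaluator instead (the AUROC, F1Score, and Evaluator imports survive in anomalib_module.py below). The exact wiring is an assumption on my part and is not shown in this diff; a hedged sketch:

# Hedged sketch, not part of this commit: configuring a pixel AUROC metric via
# the Evaluator rather than the removed `metrics:` config keys. The field names
# and keyword arguments are assumptions based on anomalib's batch fields.
from anomalib.metrics import AUROC
from anomalib.metrics.evaluator import Evaluator

evaluator = Evaluator(test_metrics=[AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_")])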

src/anomalib/data/datamodules/base/image.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def _update_augmentations(self) -> None:
         for subset_name in ["train", "val", "test"]:
             subset = getattr(self, f"{subset_name}_data", None)
             augmentations = getattr(self, f"{subset_name}_augmentations", None)
-            model_transform = self.get_nested_attr(self, f"trainer.model.pre_processor.{subset_name}_transform")
+            model_transform = self.get_nested_attr(self, "trainer.model.pre_processor.transform")
             if subset and augmentations:
                 self._update_subset_augmentations(subset, augmentations, model_transform)
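The datamodule now looks up the single transform on the model's pre-processor instead of a per-subset attribute. get_nested_attr itself is not shown in this diff; a hypothetical equivalent of the dotted-path lookup it appears to perform:

# Hypothetical stand-in for `get_nested_attr` (the real helper is not in this diff):
# walk a dot-separated attribute path, returning a default if any hop is missing.
from typing import Any


def get_nested_attr(obj: Any, path: str, default: Any = None) -> Any:
    for name in path.split("."):
        obj = getattr(obj, name, default)
        if obj is None:
            return default
    return obj


# e.g. get_nested_attr(datamodule, "trainer.model.pre_processor.transform")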

src/anomalib/models/components/base/anomalib_module.py

Lines changed: 1 addition & 5 deletions
@@ -22,7 +22,6 @@
 from anomalib.data import Batch, InferenceBatch
 from anomalib.metrics import AUROC, F1Score
 from anomalib.metrics.evaluator import Evaluator
-from anomalib.metrics.threshold import Threshold
 from anomalib.post_processing import OneClassPostProcessor, PostProcessor
 from anomalib.pre_processing import PreProcessor
 from anomalib.visualization import ImageVisualizer, Visualizer

@@ -368,7 +367,7 @@ def input_size(self) -> tuple[int, int] | None:
         The effective input size is the size of the input tensor after the transform has been applied. If the transform
         is not set, or if the transform does not change the shape of the input tensor, this method will return None.
         """
-        transform = self.pre_processor.predict_transform if self.pre_processor else None
+        transform = self.pre_processor.transform if self.pre_processor else None
         if transform is None:
             return None
         dummy_input = torch.zeros(1, 3, 1, 1)

@@ -418,9 +417,6 @@ def from_config(
             help="Path to a configuration file in json or yaml format.",
         )
         model_parser.add_subclass_arguments(AnomalibModule, "model", required=False, fail_untyped=False)
-        model_parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"])
-        model_parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
-        model_parser.add_argument("--metrics.threshold", type=Threshold | str, default="F1AdaptiveThreshold")
         model_parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True)
         args = ["--config", str(config_path)]
         for key, value in kwargs.items():
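Two independent changes here: the --metrics.* CLI arguments (and the Threshold import that backed them) are removed from from_config, and input_size now probes the single transform instead of the old predict_transform. The shape-probing trick is worth spelling out; this sketch mirrors the method body shown in the hunk:

# Sketch of the input_size logic: push a 1x1 dummy image through the transform
# and read off the output height/width, i.e. the model's effective input size.
import torch
from torchvision.transforms.v2 import Resize

transform = Resize((256, 256))  # stand-in for pre_processor.transform
dummy_input = torch.zeros(1, 3, 1, 1)
output_shape = transform(dummy_input).shape[-2:]
print(tuple(output_shape))  # (256, 256)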

src/anomalib/models/image/winclip/lightning_model.py

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
             Resize((240, 240), antialias=True, interpolation=InterpolationMode.BICUBIC),
             Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
         ])
-        return PreProcessor(val_transform=transform, test_transform=transform)
+        return PreProcessor(transform=transform)

     @staticmethod
     def configure_post_processor() -> OneClassPostProcessor:
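WinCLIP previously set only val_transform and test_transform (leaving training untouched, since WinCLIP is zero-/few-shot); with the single-transform PreProcessor, the CLIP resize and normalization now apply at every stage. A usage sketch, assuming WinClip is importable from anomalib.models as in recent releases:

# Sketch: the configured pre-processor now exposes one `transform` attribute
# carrying the 240x240 bicubic resize and CLIP normalization for all stages.
from anomalib.models import WinClip

pre_processor = WinClip.configure_pre_processor()
print(pre_processor.transform)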

src/anomalib/pre_processing/pre_processing.py

Lines changed: 10 additions & 25 deletions
@@ -77,33 +77,18 @@ class PreProcessor(nn.Module, Callback):

     def __init__(
         self,
-        train_transform: Transform | None = None,
-        val_transform: Transform | None = None,
-        test_transform: Transform | None = None,
         transform: Transform | None = None,
     ) -> None:
         super().__init__()

-        if transform and any([train_transform, val_transform, test_transform]):
-            msg = (
-                "`transforms` cannot be used together with `train_transform`, `val_transform`, `test_transform`.\n"
-                "If you want to apply the same transform to the training, validation and test data, "
-                "use only `transforms`. \n"
-                "Otherwise, specify transforms for training, validation and test individually."
-            )
-            raise ValueError(msg)
-
-        self.train_transform = train_transform or transform
-        self.val_transform = val_transform or transform
-        self.test_transform = test_transform or transform
-        self.predict_transform = self.test_transform
-        self.export_transform = get_exportable_transform(self.test_transform)
+        self.transform = transform
+        self.export_transform = get_exportable_transform(self.transform)

     def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Batch, batch_idx: int) -> None:
         """Apply transforms to the batch of tensors during training."""
         del trainer, pl_module, batch_idx  # Unused
-        if self.train_transform:
-            batch.image, batch.gt_mask = self.train_transform(batch.image, batch.gt_mask)
+        if self.transform:
+            batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask)

     def on_validation_batch_start(
         self,

@@ -114,8 +99,8 @@ def on_validation_batch_start(
     ) -> None:
         """Apply transforms to the batch of tensors during validation."""
         del trainer, pl_module, batch_idx  # Unused
-        if self.val_transform:
-            batch.image, batch.gt_mask = self.val_transform(batch.image, batch.gt_mask)
+        if self.transform:
+            batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask)

     def on_test_batch_start(
         self,

@@ -127,8 +112,8 @@ def on_test_batch_start(
     ) -> None:
         """Apply transforms to the batch of tensors during testing."""
         del trainer, pl_module, batch_idx, dataloader_idx  # Unused
-        if self.test_transform:
-            batch.image, batch.gt_mask = self.test_transform(batch.image, batch.gt_mask)
+        if self.transform:
+            batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask)

     def on_predict_batch_start(
         self,

@@ -140,8 +125,8 @@ def on_predict_batch_start(
     ) -> None:
         """Apply transforms to the batch of tensors during prediction."""
         del trainer, pl_module, batch_idx, dataloader_idx  # Unused
-        if self.predict_transform:
-            batch.image, batch.gt_mask = self.predict_transform(batch.image, batch.gt_mask)
+        if self.transform:
+            batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask)

     def forward(self, batch: torch.Tensor) -> torch.Tensor:
         """Apply transforms to the batch of tensors for inference.

tests/unit/pre_processing/test_pre_processing.py

Lines changed: 0 additions & 51 deletions
@@ -23,26 +23,6 @@ def setup(self) -> None:
         self.dummy_batch = ImageBatch(image=image, gt_mask=gt_mask)
         self.common_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)])

-    def test_init(self) -> None:
-        """Test the initialization of the PreProcessor class."""
-        # Test with stage-specific transforms
-        train_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)])
-        val_transform = Compose([Resize((256, 256)), ToImage(), ToDtype(torch.float32, scale=True)])
-        pre_processor = PreProcessor(train_transform=train_transform, val_transform=val_transform)
-        assert pre_processor.train_transform == train_transform
-        assert pre_processor.val_transform == val_transform
-        assert pre_processor.test_transform is None
-
-        # Test with single transform for all stages
-        pre_processor = PreProcessor(transform=self.common_transform)
-        assert pre_processor.train_transform == self.common_transform
-        assert pre_processor.val_transform == self.common_transform
-        assert pre_processor.test_transform == self.common_transform
-
-        # Test error case: both transform and stage-specific transform
-        with pytest.raises(ValueError, match="`transforms` cannot be used together with"):
-            PreProcessor(transform=self.common_transform, train_transform=train_transform)
-
     def test_forward(self) -> None:
         """Test the forward method of the PreProcessor class."""
         pre_processor = PreProcessor(transform=self.common_transform)

@@ -56,34 +36,3 @@ def test_no_transform(self) -> None:
         processed_batch = pre_processor(self.dummy_batch.image)
         assert isinstance(processed_batch, torch.Tensor)
         assert processed_batch.shape == (1, 3, 256, 256)
-
-    @staticmethod
-    def test_different_stage_transforms() -> None:
-        """Test different stage transforms."""
-        train_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)])
-        val_transform = Compose([Resize((256, 256)), ToImage(), ToDtype(torch.float32, scale=True)])
-        test_transform = Compose([Resize((288, 288)), ToImage(), ToDtype(torch.float32, scale=True)])
-
-        pre_processor = PreProcessor(
-            train_transform=train_transform,
-            val_transform=val_transform,
-            test_transform=test_transform,
-        )
-
-        # Test train transform
-        test_batch = ImageBatch(image=Image(torch.rand(3, 256, 256)), gt_mask=Mask(torch.zeros(256, 256)))
-        processed_batch = pre_processor.train_transform(test_batch.image)
-        assert isinstance(processed_batch, torch.Tensor)
-        assert processed_batch.shape == (1, 3, 224, 224)
-
-        # Test validation transform
-        test_batch = ImageBatch(image=Image(torch.rand(3, 256, 256)), gt_mask=Mask(torch.zeros(256, 256)))
-        processed_batch = pre_processor.val_transform(test_batch.image)
-        assert isinstance(processed_batch, torch.Tensor)
-        assert processed_batch.shape == (1, 3, 256, 256)
-
-        # Test test transform
-        test_batch = ImageBatch(image=Image(torch.rand(3, 256, 256)), gt_mask=Mask(torch.zeros(256, 256)))
-        processed_batch = pre_processor.test_transform(test_batch.image)
-        assert isinstance(processed_batch, torch.Tensor)
-        assert processed_batch.shape == (1, 3, 288, 288)
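With the stage-specific constructor gone, the deleted test_init and test_different_stage_transforms have nothing left to exercise; the surviving test_forward and test_no_transform cover the single-transform path. Roughly what the remaining coverage asserts, reconstructed from the visible context lines:

# Reconstructed from the surviving test context: forward() applies the one
# configured transform, and a PreProcessor without a transform passes tensors through.
import torch
from torchvision.transforms.v2 import Compose, Resize, ToDtype, ToImage

from anomalib.pre_processing import PreProcessor

common_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)])
processed = PreProcessor(transform=common_transform)(torch.rand(1, 3, 256, 256))
assert processed.shape == (1, 3, 224, 224)

passthrough = PreProcessor()(torch.rand(1, 3, 256, 256))
assert passthrough.shape == (1, 3, 256, 256)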
