
Commit 93dcf46

Update model.eval based on the new lightning
Signed-off-by: Samet Akcay <[email protected]>
1 parent 09ffb65 commit 93dcf46
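
At a high level, the diffs below all apply the same pattern: the frozen backbone (feature extractor, encoder, or teacher network) is put into eval mode once, when the model is constructed, instead of calling .eval() on it inside every training_step. The snippet below is a minimal, generic sketch of that before/after pattern; the class and attribute names are placeholders, not the actual anomalib classes.

import torch
from torch import nn


class ExampleAnomalyModel(nn.Module):
    """Toy stand-in for the anomalib torch models touched in this commit."""

    def __init__(self) -> None:
        super().__init__()
        # New pattern: chain .eval() onto the constructor call so layers such as
        # BatchNorm and Dropout are already in inference mode after __init__.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
        ).eval()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return self.feature_extractor(images)


model = ExampleAnomalyModel()
# Old pattern (removed in the diffs below): each training_step started with
# an explicit `self.model.feature_extractor.eval()` before the forward pass.
print(model.feature_extractor.training)  # False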

14 files changed, +24 -39 lines changed

pyproject.toml (+2 -2)

@@ -34,8 +34,8 @@ core = [
     "opencv-python>=4.5.3.56",
     "pandas>=1.1.0",
     "timm>=0.5.4,<=0.9.16",
-    "lightning>2,<2.2.0",
-    "torch>=2,<2.3.0",
+    "lightning>=2.2",
+    "torch>=2",
     "torchmetrics>=1.3.2",
     "open-clip-torch>=2.23.0",
 ]
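
The pins above are relaxed rather than tightened: the upper bounds on lightning and torch are dropped and the lightning floor moves to 2.2. A throwaway check such as the following, which is not part of the repository, can confirm an existing environment still satisfies the new lower bounds.

# Illustrative only: verify installed versions against "lightning>=2.2" and "torch>=2".
import lightning
import torch
from packaging.version import Version

assert Version(lightning.__version__) >= Version("2.2")
assert Version(torch.__version__) >= Version("2")
print("lightning", lightning.__version__, "| torch", torch.__version__)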

src/anomalib/models/image/cflow/lightning_model.py (+1 -1)

@@ -84,6 +84,7 @@ def __init__(
         # TODO(ashwinvaidya17): LR should be part of optimizer in config.yaml since cflow has custom optimizer.
         # CVS-122670
         self.learning_rate = lr
+        self.model.encoder.eval()

     def configure_optimizers(self) -> Optimizer:
         """Configure optimizers for each decoder.
@@ -119,7 +120,6 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
         del args, kwargs  # These variables are not used.

         opt = self.optimizers()
-        self.model.encoder.eval()

         images: torch.Tensor = batch["image"]
         activation = self.model.encoder(images)

src/anomalib/models/image/csflow/lightning_model.py (+1 -2)

@@ -6,7 +6,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 import logging
 from typing import Any

@@ -68,6 +67,7 @@ def _setup(self) -> None:
             clamp=self.clamp,
             num_channels=self.num_channels,
         )
+        self.model.feature_extractor.eval()

     def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
         """Perform the training step of CS-Flow.
@@ -82,7 +82,6 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
         """
         del args, kwargs  # These variables are not used.

-        self.model.feature_extractor.eval()
         z_dist, jacobians = self.model(batch["image"])
         loss = self.loss(z_dist, jacobians)
         self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True)

src/anomalib/models/image/fastflow/lightning_model.py (+1 -3)

@@ -6,7 +6,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 from typing import Any

 import torch
@@ -52,9 +51,8 @@ def __init__(
         self.conv3x3_only = conv3x3_only
         self.hidden_ratio = hidden_ratio

-        self.loss = FastflowLoss()
-
         self.model: FastflowModel
+        self.loss = FastflowLoss()

     def _setup(self) -> None:
         if self.input_size is None:

src/anomalib/models/image/padim/lightning_model.py (+2 -4)

@@ -6,7 +6,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 import logging

 import torch
@@ -52,7 +51,8 @@ def __init__(
             pre_trained=pre_trained,
             layers=layers,
             n_features=n_features,
-        ).eval()
+        )
+        self.model.feature_extractor.eval()

         self.stats: list[torch.Tensor] = []
         self.embeddings: list[torch.Tensor] = []
@@ -75,9 +75,7 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
         """
         del args, kwargs  # These variables are not used.

-        self.model.feature_extractor.eval()
         embedding = self.model(batch["image"])
-
         self.embeddings.append(embedding.cpu())

     def fit(self) -> None:

src/anomalib/models/image/padim/torch_model.py (+6 -3)

@@ -3,7 +3,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 from random import sample
 from typing import TYPE_CHECKING

@@ -67,8 +66,8 @@ class PadimModel(nn.Module):

     def __init__(
         self,
-        layers: list[str],
         backbone: str = "resnet18",
+        layers: list[str] = ["layer1", "layer2", "layer3"],  # noqa: B006
         pre_trained: bool = True,
         n_features: int | None = None,
     ) -> None:
@@ -77,7 +76,11 @@ def __init__(

         self.backbone = backbone
         self.layers = layers
-        self.feature_extractor = TimmFeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
+        self.feature_extractor = TimmFeatureExtractor(
+            backbone=self.backbone,
+            layers=layers,
+            pre_trained=pre_trained,
+        ).eval()
         self.n_features_original = sum(self.feature_extractor.out_dims)
         self.n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)
         if self.n_features is None:
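
The new layers default is a mutable list, which is why the # noqa: B006 comment suppresses ruff's mutable-default-argument rule; presumably this is acceptable here because the list is only read, never mutated. The snippet below is a generic illustration of the pitfall B006 guards against, not anomalib code.

# Hypothetical example: the default list is created once and shared across calls.
def append_layer(layer: str, layers: list[str] = []) -> list[str]:  # noqa: B006
    layers.append(layer)
    return layers

print(append_layer("layer1"))  # ['layer1']
print(append_layer("layer2"))  # ['layer1', 'layer2'] -- the first call's state leaked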

src/anomalib/models/image/patchcore/lightning_model.py (-2)

@@ -78,9 +78,7 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
         """
         del args, kwargs  # These variables are not used.

-        self.model.feature_extractor.eval()
         embedding = self.model(batch["image"])
-
         self.embeddings.append(embedding)

     def fit(self) -> None:

src/anomalib/models/image/patchcore/torch_model.py (+1 -1)

@@ -49,7 +49,7 @@ def __init__(
             backbone=self.backbone,
             pre_trained=pre_trained,
             layers=self.layers,
-        )
+        ).eval()
         self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
         self.anomaly_map_generator = AnomalyMapGenerator()

src/anomalib/models/image/reverse_distillation/lightning_model.py (+1 -3)

@@ -6,7 +6,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 from collections.abc import Sequence
 from typing import Any

@@ -50,9 +49,8 @@ def __init__(
         self.layers = layers
         self.anomaly_map_mode = anomaly_map_mode

-        self.loss = ReverseDistillationLoss()
-
         self.model: ReverseDistillationModel
+        self.loss = ReverseDistillationLoss()

     def _setup(self) -> None:
         if self.input_size is None:

src/anomalib/models/image/stfpm/lightning_model.py (-2)

@@ -6,7 +6,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 from collections.abc import Sequence
 from typing import Any

@@ -61,7 +60,6 @@ def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -
         """
         del args, kwargs  # These variables are not used.

-        self.model.teacher_model.eval()
         teacher_features, student_features = self.model.forward(batch["image"])
         loss = self.loss(teacher_features, student_features)
         self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True)

src/anomalib/models/image/stfpm/torch_model.py (+1 -2)

@@ -3,7 +3,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

@@ -36,7 +35,7 @@ def __init__(
         self.tiler: Tiler | None = None

         self.backbone = backbone
-        self.teacher_model = TimmFeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers)
+        self.teacher_model = TimmFeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers).eval()
         self.student_model = TimmFeatureExtractor(
             backbone=self.backbone,
             pre_trained=False,

src/anomalib/models/image/uflow/feature_extraction.py (+5 -7)

@@ -32,15 +32,13 @@ def get_feature_extractor(backbone: str, input_size: tuple[int, int] = (256, 256
         msg = f"Feature extractor must be one of {AVAILABLE_EXTRACTORS}."
         raise ValueError(msg)

+    feature_extractor: nn.Module
     if backbone in ["resnet18", "wide_resnet50_2"]:
-        return FeatureExtractor(backbone, input_size, layers=("layer1", "layer2", "layer3"))
+        feature_extractor = FeatureExtractor(backbone, input_size, layers=("layer1", "layer2", "layer3")).eval()
     if backbone == "mcait":
-        return MCaitFeatureExtractor()
-    msg = (
-        "`backbone` must be one of `[mcait, resnet18, wide_resnet50_2]`. These are the only feature extractors tested. "
-        "It does not mean that other feature extractors will not work."
-    )
-    raise ValueError(msg)
+        feature_extractor = MCaitFeatureExtractor().eval()
+
+    return feature_extractor


 class FeatureExtractor(TimmFeatureExtractor):
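
After this refactor the helper assigns the extractor to a local feature_extractor variable, applies .eval() in both branches, and returns once at the end; the long fallback ValueError is dropped because the membership check at the top of the function already rejects unknown backbones. Below is a usage sketch based only on the signature visible in this hunk, with the import path inferred from the file location.

# Hedged usage sketch; not taken from the repository's own tests or docs.
from anomalib.models.image.uflow.feature_extraction import get_feature_extractor

extractor = get_feature_extractor("resnet18", input_size=(256, 256))
print(extractor.training)  # expected False, since .eval() is applied before returning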

src/anomalib/models/image/uflow/torch_model.py (-1)

@@ -3,7 +3,6 @@
 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 import torch
 from FrEIA import framework as ff
 from FrEIA import modules as fm

tests/integration/model/test_models.py (+3 -6)

@@ -6,7 +6,6 @@
 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-
 from pathlib import Path

 import pytest
@@ -144,12 +143,10 @@ def test_export(
             dataset_path (Path): Root to dataset from fixture.
             project_path (Path): Path to temporary project folder from fixture.
         """
-        if model_name == "reverse_distillation":
-            # TODO(ashwinvaidya17): Restore this test after fixing reverse distillation
+        if model_name in ("reverse_distillation", "rkde"):
+            # TODO(ashwinvaidya17): Restore this test after fixing the issue
             # https://github.com/openvinotoolkit/anomalib/issues/1513
-            pytest.skip("Reverse distillation fails to convert to ONNX")
-        elif model_name == "rkde" and export_type == ExportType.OPENVINO:
-            pytest.skip("RKDE fails to convert to OpenVINO")
+            pytest.skip(f"{model_name} fails to convert to ONNX and OpenVINO")

         model, dataset, engine = self._get_objects(
             model_name=model_name,
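
The two separate skips are folded into one: reverse_distillation and rkde are now both skipped for every export type, with a single f-string message, whereas before rkde was only skipped for the OpenVINO export. A minimal, hypothetical illustration of the same pytest pattern (test and parameter names are made up, not the repository's fixtures):

import pytest


@pytest.mark.parametrize("model_name", ["padim", "reverse_distillation", "rkde"])
def test_export_stub(model_name: str) -> None:
    # Skip the known-broken exporters with one message covering both models.
    if model_name in ("reverse_distillation", "rkde"):
        pytest.skip(f"{model_name} fails to convert to ONNX and OpenVINO")
    assert model_name == "padim"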
