Commit 7a3bf3a

🐞 fix: remove custom state dict handling in WinClip to fix inference tensor issues (#2630)
Apply ruff changes

Signed-off-by: Samet Akcay <[email protected]>

1 parent: c026ed7

File tree: 1 file changed, +0 −50 lines

src/anomalib/models/image/winclip/lightning_model.py

@@ -29,12 +29,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
-from collections import OrderedDict
 from pathlib import Path
-from typing import Any
 
 import torch
-from lightning.pytorch.trainer.states import TrainerFn
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision.transforms.v2 import Compose, InterpolationMode, Normalize, Resize
@@ -93,8 +90,6 @@ class WinClip(AnomalibModule):
         >>> model = WinClip(class_name="transistor")  # doctest: +SKIP
 
     Notes:
-        - The model automatically excludes CLIP backbone parameters from checkpoints to
-          reduce size
         - Input image size is fixed at 240x240 and cannot be modified
         - Uses a custom normalization transform specific to CLIP
@@ -103,8 +98,6 @@ class WinClip(AnomalibModule):
         - :class:`PostProcessor`: Default post-processor used by WinCLIP
     """
 
-    EXCLUDE_FROM_STATE_DICT = frozenset({"model.clip"})
-
     def __init__(
         self,
         class_name: str | None = None,
@@ -249,49 +242,6 @@ def learning_type(self) -> LearningType:
         """
         return LearningType.FEW_SHOT if self.k_shot else LearningType.ZERO_SHOT
 
-    def state_dict(self, **kwargs) -> OrderedDict[str, Any]:
-        """Get the state dict of the model.
-
-        Removes parameters of the frozen backbone to reduce checkpoint size.
-
-        Args:
-            **kwargs: Additional arguments to pass to parent's state_dict
-
-        Returns:
-            OrderedDict[str, Any]: State dict with backbone parameters removed
-        """
-        state_dict = super().state_dict(**kwargs)
-        if self._trainer is not None and self.trainer.state.fn in {
-            TrainerFn.FITTING,
-            TrainerFn.VALIDATING,
-        }:  # Keep backbone weights if exporting the model
-            for pattern in self.EXCLUDE_FROM_STATE_DICT:
-                remove_keys = [key for key in state_dict if key.startswith(pattern)]
-                for key in remove_keys:
-                    state_dict.pop(key)
-        return state_dict
-
-    def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any:  # noqa: ANN401
-        """Load the state dict of the model.
-
-        Restores backbone parameters before loading to ensure correct model initialization.
-
-        Args:
-            state_dict (OrderedDict[str, Any]): State dict to load
-            strict (bool, optional): Whether to strictly enforce that the keys in
-                ``state_dict`` match the keys returned by this module's
-                ``state_dict()`` function. Defaults to ``True``.
-
-        Returns:
-            Any: Return value from parent's load_state_dict
-        """
-        # restore the parameters of the excluded modules, if any
-        full_dict = super().state_dict()
-        for pattern in self.EXCLUDE_FROM_STATE_DICT:
-            restore_dict = {key: value for key, value in full_dict.items() if key.startswith(pattern)}
-            state_dict.update(restore_dict)
-        return super().load_state_dict(state_dict, strict)
-
     @classmethod
     def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
         """Configure the default pre-processor used by the model.

Comments (0)