29
29
# SPDX-License-Identifier: Apache-2.0
30
30
31
31
import logging
32
- from collections import OrderedDict
33
32
from pathlib import Path
34
- from typing import Any
35
33
36
34
import torch
37
- from lightning .pytorch .trainer .states import TrainerFn
38
35
from torch import nn
39
36
from torch .utils .data import DataLoader
40
37
from torchvision .transforms .v2 import Compose , InterpolationMode , Normalize , Resize
@@ -93,8 +90,6 @@ class WinClip(AnomalibModule):
93
90
>>> model = WinClip(class_name="transistor") # doctest: +SKIP
94
91
95
92
Notes:
96
- - The model automatically excludes CLIP backbone parameters from checkpoints to
97
- reduce size
98
93
- Input image size is fixed at 240x240 and cannot be modified
99
94
- Uses a custom normalization transform specific to CLIP
100
95
@@ -103,8 +98,6 @@ class WinClip(AnomalibModule):
103
98
- :class:`PostProcessor`: Default post-processor used by WinCLIP
104
99
"""
105
100
106
- EXCLUDE_FROM_STATE_DICT = frozenset ({"model.clip" })
107
-
108
101
def __init__ (
109
102
self ,
110
103
class_name : str | None = None ,
@@ -249,49 +242,6 @@ def learning_type(self) -> LearningType:
249
242
"""
250
243
return LearningType .FEW_SHOT if self .k_shot else LearningType .ZERO_SHOT
251
244
252
- def state_dict (self , ** kwargs ) -> OrderedDict [str , Any ]:
253
- """Get the state dict of the model.
254
-
255
- Removes parameters of the frozen backbone to reduce checkpoint size.
256
-
257
- Args:
258
- **kwargs: Additional arguments to pass to parent's state_dict
259
-
260
- Returns:
261
- OrderedDict[str, Any]: State dict with backbone parameters removed
262
- """
263
- state_dict = super ().state_dict (** kwargs )
264
- if self ._trainer is not None and self .trainer .state .fn in {
265
- TrainerFn .FITTING ,
266
- TrainerFn .VALIDATING ,
267
- }: # Keep backbone weights if exporting the model
268
- for pattern in self .EXCLUDE_FROM_STATE_DICT :
269
- remove_keys = [key for key in state_dict if key .startswith (pattern )]
270
- for key in remove_keys :
271
- state_dict .pop (key )
272
- return state_dict
273
-
274
- def load_state_dict (self , state_dict : OrderedDict [str , Any ], strict : bool = True ) -> Any : # noqa: ANN401
275
- """Load the state dict of the model.
276
-
277
- Restores backbone parameters before loading to ensure correct model initialization.
278
-
279
- Args:
280
- state_dict (OrderedDict[str, Any]): State dict to load
281
- strict (bool, optional): Whether to strictly enforce that the keys in
282
- ``state_dict`` match the keys returned by this module's
283
- ``state_dict()`` function. Defaults to ``True``.
284
-
285
- Returns:
286
- Any: Return value from parent's load_state_dict
287
- """
288
- # restore the parameters of the excluded modules, if any
289
- full_dict = super ().state_dict ()
290
- for pattern in self .EXCLUDE_FROM_STATE_DICT :
291
- restore_dict = {key : value for key , value in full_dict .items () if key .startswith (pattern )}
292
- state_dict .update (restore_dict )
293
- return super ().load_state_dict (state_dict , strict )
294
-
295
245
@classmethod
296
246
def configure_pre_processor (cls , image_size : tuple [int , int ] | None = None ) -> PreProcessor :
297
247
"""Configure the default pre-processor used by the model.
0 commit comments