Skip to content

Commit b67727f

Browse files
committed
simplify subcomponent resolve in base module
1 parent 58453ef commit b67727f

File tree

1 file changed

+35
-88
lines changed

1 file changed

+35
-88
lines changed

src/anomalib/models/components/base/anomalib_module.py

Lines changed: 35 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import logging
4444
import warnings
4545
from abc import ABC, abstractmethod
46-
from collections.abc import Sequence
46+
from collections.abc import Callable, Sequence
4747
from pathlib import Path
4848
from typing import Any
4949

@@ -137,10 +137,10 @@ def __init__(
137137
self.loss: nn.Module
138138
self.callbacks: list[Callback]
139139

140-
self.pre_processor = self._resolve_pre_processor(pre_processor)
141-
self.post_processor = self._resolve_post_processor(post_processor)
142-
self.evaluator = self._resolve_evaluator(evaluator)
143-
self.visualizer = self._resolve_visualizer(visualizer)
140+
self.pre_processor = self._resolve_component(pre_processor, PreProcessor, self.configure_pre_processor)
141+
self.post_processor = self._resolve_component(post_processor, PostProcessor, self.configure_post_processor)
142+
self.evaluator = self._resolve_component(evaluator, Evaluator, self.configure_evaluator)
143+
self.visualizer = self._resolve_component(visualizer, Visualizer, self.configure_visualizer)
144144

145145
self._input_size: tuple[int, int] | None = None
146146
self._is_setup = False
@@ -299,34 +299,42 @@ def learning_type(self) -> LearningType:
299299
"""
300300
raise NotImplementedError
301301

302-
def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
303-
"""Resolve and validate the pre-processor configuration.
302+
@staticmethod
303+
def _resolve_component(component: nn.Module | None, component_type: type, default_callable: Callable) -> nn.Module:
304+
"""Resolve and validate the subcomponent configuration.
305+
306+
This method resolves the configuration for various subcomponents like
307+
pre-processor, post-processor, evaluator and visualizer. It validates
308+
the configuration and returns the configured component. If the component
309+
is a boolean, it uses the default callable to create the component. If
310+
the component is already an instance of the component type, it returns
311+
the component as is.
304312
305313
Args:
306-
pre_processor (PreProcessor | bool): Pre-processor configuration
307-
- ``True`` -> use default pre-processor
308-
- ``False`` -> no pre-processor
309-
- ``PreProcessor`` -> use provided pre-processor
314+
component (object): Component configuration
315+
component_type (Type): Type of the component
316+
default_callable (Callable): Callable to create default component
310317
311318
Returns:
312-
PreProcessor | None: Configured pre-processor
319+
Component | None: Configured component
313320
314321
Raises:
315-
TypeError: If pre_processor is invalid type
322+
TypeError: If component is invalid type
316323
"""
317-
if isinstance(pre_processor, PreProcessor):
318-
return pre_processor
319-
if isinstance(pre_processor, bool):
320-
return self.configure_pre_processor() if pre_processor else None
321-
msg = f"Invalid pre-processor type: {type(pre_processor)}"
324+
if isinstance(component, component_type):
325+
return component
326+
if isinstance(component, bool):
327+
return default_callable() if component else None
328+
msg = f"Passed object should be {component_type} or bool, got: {type(component)}"
322329
raise TypeError(msg)
323330

324-
@classmethod
325-
def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
331+
@staticmethod
332+
def configure_pre_processor(image_size: tuple[int, int] | None = None) -> PreProcessor:
326333
"""Configure the default pre-processor.
327334
328335
The default pre-processor resizes images and normalizes using ImageNet
329-
statistics.
336+
statistics. Override this method to provide a custom pre-processor for
337+
the model.
330338
331339
Args:
332340
image_size (tuple[int, int] | None, optional): Target size for
@@ -348,31 +356,12 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P
348356
]),
349357
)
350358

351-
def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostProcessor | None:
352-
"""Resolve and validate the post-processor configuration.
353-
354-
Args:
355-
post_processor (PostProcessor | bool): Post-processor configuration
356-
- ``True`` -> use default post-processor
357-
- ``False`` -> no post-processor
358-
- ``PostProcessor`` -> use provided post-processor
359-
360-
Returns:
361-
PostProcessor | None: Configured post-processor
362-
363-
Raises:
364-
TypeError: If post_processor is invalid type
365-
"""
366-
if isinstance(post_processor, PostProcessor):
367-
return post_processor
368-
if isinstance(post_processor, bool):
369-
return self.configure_post_processor() if post_processor else None
370-
msg = f"Invalid post-processor type: {type(post_processor)}"
371-
raise TypeError(msg)
372-
373359
def configure_post_processor(self) -> PostProcessor | None:
374360
"""Configure the default post-processor.
375361
362+
The default post-processor is based on the model's learning type. Override
363+
this method to provide a custom post-processor for the model.
364+
376365
Returns:
377366
PostProcessor | None: Configured post-processor based on learning type
378367
@@ -394,34 +383,12 @@ def configure_post_processor(self) -> PostProcessor | None:
394383
)
395384
raise NotImplementedError(msg)
396385

397-
def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None:
398-
"""Resolve and validate the evaluator configuration.
399-
400-
Args:
401-
evaluator (Evaluator | bool): Evaluator configuration
402-
- ``True`` -> use default evaluator
403-
- ``False`` -> no evaluator
404-
- ``Evaluator`` -> use provided evaluator
405-
406-
Returns:
407-
Evaluator | None: Configured evaluator
408-
409-
Raises:
410-
TypeError: If evaluator is invalid type
411-
"""
412-
if isinstance(evaluator, Evaluator):
413-
return evaluator
414-
if isinstance(evaluator, bool):
415-
return self.configure_evaluator() if evaluator else None
416-
msg = f"evaluator must be of type Evaluator or bool, got {type(evaluator)}"
417-
raise TypeError(msg)
418-
419386
@staticmethod
420387
def configure_evaluator() -> Evaluator:
421388
"""Configure the default evaluator.
422389
423390
The default evaluator includes metrics for both image-level and
424-
pixel-level evaluation.
391+
pixel-level evaluation. Override this method to provide custom metrics for the model.
425392
426393
Returns:
427394
Evaluator: Configured evaluator with default metrics
@@ -438,32 +405,12 @@ def configure_evaluator() -> Evaluator:
438405
test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score]
439406
return Evaluator(test_metrics=test_metrics)
440407

441-
def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None:
442-
"""Resolve and validate the visualizer configuration.
443-
444-
Args:
445-
visualizer (Visualizer | bool): Visualizer configuration
446-
- ``True`` -> use default visualizer
447-
- ``False`` -> no visualizer
448-
- ``Visualizer`` -> use provided visualizer
449-
450-
Returns:
451-
Visualizer | None: Configured visualizer
452-
453-
Raises:
454-
TypeError: If visualizer is invalid type
455-
"""
456-
if isinstance(visualizer, Visualizer):
457-
return visualizer
458-
if isinstance(visualizer, bool):
459-
return self.configure_visualizer() if visualizer else None
460-
msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}"
461-
raise TypeError(msg)
462-
463408
@classmethod
464409
def configure_visualizer(cls) -> ImageVisualizer:
465410
"""Configure the default visualizer.
466411
412+
Override this method to provide a custom visualizer for the model.
413+
467414
Returns:
468415
ImageVisualizer: Default image visualizer instance
469416

0 commit comments

Comments
 (0)