43
43
import logging
44
44
import warnings
45
45
from abc import ABC , abstractmethod
46
- from collections .abc import Sequence
46
+ from collections .abc import Callable , Sequence
47
47
from pathlib import Path
48
48
from typing import Any
49
49
@@ -137,10 +137,10 @@ def __init__(
137
137
self .loss : nn .Module
138
138
self .callbacks : list [Callback ]
139
139
140
- self .pre_processor = self ._resolve_pre_processor (pre_processor )
141
- self .post_processor = self ._resolve_post_processor (post_processor )
142
- self .evaluator = self ._resolve_evaluator (evaluator )
143
- self .visualizer = self ._resolve_visualizer (visualizer )
140
+ self .pre_processor = self ._resolve_component (pre_processor , PreProcessor , self . configure_pre_processor )
141
+ self .post_processor = self ._resolve_component (post_processor , PostProcessor , self . configure_post_processor )
142
+ self .evaluator = self ._resolve_component (evaluator , Evaluator , self . configure_evaluator )
143
+ self .visualizer = self ._resolve_component (visualizer , Visualizer , self . configure_visualizer )
144
144
145
145
self ._input_size : tuple [int , int ] | None = None
146
146
self ._is_setup = False
@@ -299,34 +299,42 @@ def learning_type(self) -> LearningType:
299
299
"""
300
300
raise NotImplementedError
301
301
302
- def _resolve_pre_processor (self , pre_processor : PreProcessor | bool ) -> PreProcessor | None :
303
- """Resolve and validate the pre-processor configuration.
302
+ @staticmethod
303
+ def _resolve_component (component : nn .Module | None , component_type : type , default_callable : Callable ) -> nn .Module :
304
+ """Resolve and validate the subcomponent configuration.
305
+
306
+ This method resolves the configuration for various subcomponents like
307
+ pre-processor, post-processor, evaluator and visualizer. It validates
308
+ the configuration and returns the configured component. If the component
309
+ is a boolean, it uses the default callable to create the component. If
310
+ the component is already an instance of the component type, it returns
311
+ the component as is.
304
312
305
313
Args:
306
- pre_processor (PreProcessor | bool): Pre-processor configuration
307
- - ``True`` -> use default pre-processor
308
- - ``False`` -> no pre-processor
309
- - ``PreProcessor`` -> use provided pre-processor
314
+ component (object): Component configuration
315
+ component_type (Type): Type of the component
316
+ default_callable (Callable): Callable to create default component
310
317
311
318
Returns:
312
- PreProcessor | None: Configured pre-processor
319
+ Component | None: Configured component
313
320
314
321
Raises:
315
- TypeError: If pre_processor is invalid type
322
+ TypeError: If component is invalid type
316
323
"""
317
- if isinstance (pre_processor , PreProcessor ):
318
- return pre_processor
319
- if isinstance (pre_processor , bool ):
320
- return self . configure_pre_processor () if pre_processor else None
321
- msg = f"Invalid pre-processor type : { type (pre_processor )} "
324
+ if isinstance (component , component_type ):
325
+ return component
326
+ if isinstance (component , bool ):
327
+ return default_callable () if component else None
328
+ msg = f"Passed object should be { component_type } or bool, got : { type (component )} "
322
329
raise TypeError (msg )
323
330
324
- @classmethod
325
- def configure_pre_processor (cls , image_size : tuple [int , int ] | None = None ) -> PreProcessor :
331
+ @staticmethod
332
+ def configure_pre_processor (image_size : tuple [int , int ] | None = None ) -> PreProcessor :
326
333
"""Configure the default pre-processor.
327
334
328
335
The default pre-processor resizes images and normalizes using ImageNet
329
- statistics.
336
+ statistics. Override this method to provide a custom pre-processor for
337
+ the model.
330
338
331
339
Args:
332
340
image_size (tuple[int, int] | None, optional): Target size for
@@ -348,31 +356,12 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P
348
356
]),
349
357
)
350
358
351
- def _resolve_post_processor (self , post_processor : PostProcessor | bool ) -> PostProcessor | None :
352
- """Resolve and validate the post-processor configuration.
353
-
354
- Args:
355
- post_processor (PostProcessor | bool): Post-processor configuration
356
- - ``True`` -> use default post-processor
357
- - ``False`` -> no post-processor
358
- - ``PostProcessor`` -> use provided post-processor
359
-
360
- Returns:
361
- PostProcessor | None: Configured post-processor
362
-
363
- Raises:
364
- TypeError: If post_processor is invalid type
365
- """
366
- if isinstance (post_processor , PostProcessor ):
367
- return post_processor
368
- if isinstance (post_processor , bool ):
369
- return self .configure_post_processor () if post_processor else None
370
- msg = f"Invalid post-processor type: { type (post_processor )} "
371
- raise TypeError (msg )
372
-
373
359
def configure_post_processor (self ) -> PostProcessor | None :
374
360
"""Configure the default post-processor.
375
361
362
+ The default post-processor is based on the model's learning type. Override
363
+ this method to provide a custom post-processor for the model.
364
+
376
365
Returns:
377
366
PostProcessor | None: Configured post-processor based on learning type
378
367
@@ -394,34 +383,12 @@ def configure_post_processor(self) -> PostProcessor | None:
394
383
)
395
384
raise NotImplementedError (msg )
396
385
397
- def _resolve_evaluator (self , evaluator : Evaluator | bool ) -> Evaluator | None :
398
- """Resolve and validate the evaluator configuration.
399
-
400
- Args:
401
- evaluator (Evaluator | bool): Evaluator configuration
402
- - ``True`` -> use default evaluator
403
- - ``False`` -> no evaluator
404
- - ``Evaluator`` -> use provided evaluator
405
-
406
- Returns:
407
- Evaluator | None: Configured evaluator
408
-
409
- Raises:
410
- TypeError: If evaluator is invalid type
411
- """
412
- if isinstance (evaluator , Evaluator ):
413
- return evaluator
414
- if isinstance (evaluator , bool ):
415
- return self .configure_evaluator () if evaluator else None
416
- msg = f"evaluator must be of type Evaluator or bool, got { type (evaluator )} "
417
- raise TypeError (msg )
418
-
419
386
@staticmethod
420
387
def configure_evaluator () -> Evaluator :
421
388
"""Configure the default evaluator.
422
389
423
390
The default evaluator includes metrics for both image-level and
424
- pixel-level evaluation.
391
+ pixel-level evaluation. Override this method to provide custom metrics for the model.
425
392
426
393
Returns:
427
394
Evaluator: Configured evaluator with default metrics
@@ -438,32 +405,12 @@ def configure_evaluator() -> Evaluator:
438
405
test_metrics = [image_auroc , image_f1score , pixel_auroc , pixel_f1score ]
439
406
return Evaluator (test_metrics = test_metrics )
440
407
441
- def _resolve_visualizer (self , visualizer : Visualizer | bool ) -> Visualizer | None :
442
- """Resolve and validate the visualizer configuration.
443
-
444
- Args:
445
- visualizer (Visualizer | bool): Visualizer configuration
446
- - ``True`` -> use default visualizer
447
- - ``False`` -> no visualizer
448
- - ``Visualizer`` -> use provided visualizer
449
-
450
- Returns:
451
- Visualizer | None: Configured visualizer
452
-
453
- Raises:
454
- TypeError: If visualizer is invalid type
455
- """
456
- if isinstance (visualizer , Visualizer ):
457
- return visualizer
458
- if isinstance (visualizer , bool ):
459
- return self .configure_visualizer () if visualizer else None
460
- msg = f"Visualizer must be of type Visualizer or bool, got { type (visualizer )} "
461
- raise TypeError (msg )
462
-
463
408
@classmethod
464
409
def configure_visualizer (cls ) -> ImageVisualizer :
465
410
"""Configure the default visualizer.
466
411
412
+ Override this method to provide a custom visualizer for the model.
413
+
467
414
Returns:
468
415
ImageVisualizer: Default image visualizer instance
469
416
0 commit comments