@@ -23,7 +23,17 @@
 class MDTATransformerBlock(nn.Module):
     """Basic transformer unit combining MDTA and GDFN with skip connections.
     Unlike standard transformers that use LayerNorm, this block uses Instance Norm
-    for better adaptation to image restoration tasks."""
+    for better adaptation to image restoration tasks.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        num_heads: Number of attention heads
+        ffn_expansion_factor: Expansion factor for feed-forward network
+        bias: Whether to use bias in attention layers
+        layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False.
+        flash_attention: Whether to use flash attention optimization. Defaults to False.
+    """
 
     def __init__(
         self,
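The docstring describes the block as MDTA and GDFN joined by two skip connections, with Instance Norm in place of the usual LayerNorm. A minimal sketch of that wiring, assuming 2D inputs and substituting stock PyTorch layers (`nn.MultiheadAttention` and a 1x1-conv feed-forward) for the real MDTA and GDFN; the class name and layer choices here are illustrative, not the actual implementation:

```python
import torch
import torch.nn as nn


class MDTABlockSketch(nn.Module):
    """Illustrative only: norm -> attention -> skip, then norm -> FFN -> skip."""

    def __init__(self, dim: int = 48, num_heads: int = 4, ffn_expansion_factor: float = 2.66):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        # Instance Norm stands in for the LayerNorm of standard transformers.
        self.norm1 = nn.InstanceNorm2d(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)  # placeholder for MDTA
        self.norm2 = nn.InstanceNorm2d(dim)
        self.ffn = nn.Sequential(  # placeholder for GDFN
            nn.Conv2d(dim, hidden, kernel_size=1), nn.GELU(), nn.Conv2d(hidden, dim, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        y = self.norm1(x).flatten(2).transpose(1, 2)   # (B, H*W, C) token layout for attention
        y, _ = self.attn(y, y, y)
        x = x + y.transpose(1, 2).reshape(b, c, h, w)  # first skip connection
        x = x + self.ffn(self.norm2(x))                # second skip connection
        return x


block = MDTABlockSketch(dim=48, num_heads=4)
x = torch.randn(1, 48, 32, 32)
print(block(x).shape)  # torch.Size([1, 48, 32, 32]) -- shape is preserved
```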
@@ -50,7 +60,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class OverlapPatchEmbed(nn.Module):
     """Initial feature extraction using overlapped convolutions.
     Unlike standard patch embeddings that use non-overlapping patches,
-    this approach maintains spatial continuity through 3x3 convolutions."""
+    this approach maintains spatial continuity through 3x3 convolutions.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        in_channels: Number of input channels
+        embed_dim: Dimension of embedded features. Defaults to 48.
+        bias: Whether to use bias in convolution layer. Defaults to False.
+    """
 
     def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):
         super().__init__()
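The overlap comes from using a 3x3 kernel with stride 1 and padding 1, so each output feature sees its neighbours and the spatial size is preserved, unlike ViT-style patchify where stride equals patch size. A sketch for the 2D case, using the documented defaults (`embed_dim=48`, `bias=False`):

```python
import torch
import torch.nn as nn

# Overlapped patch embedding in one line: stride 1 + padding 1 keeps H and W,
# so adjacent "patches" share pixels instead of being disjoint tiles.
patch_embed = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=3, stride=1, padding=1, bias=False)

x = torch.randn(1, 3, 64, 64)
feat = patch_embed(x)
print(feat.shape)  # torch.Size([1, 48, 64, 64]) -- same H and W, embed_dim channels
```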
@@ -104,17 +121,23 @@ def __init__(
         """Initialize Restormer model.
 
         Args:
+            spatial_dims: Number of spatial dimensions (2D or 3D)
             in_channels: Number of input image channels
             out_channels: Number of output image channels
-            dim: Base feature dimension
-            num_blocks: Number of transformer blocks at each scale
-            num_refinement_blocks: Number of final refinement blocks
-            heads: Number of attention heads at each scale
-            ffn_expansion_factor: Expansion factor for feed-forward network
-            bias: Whether to use bias in convolutions
-            layer_norm_use_bias: Whether to use bias in layer normalization. Default is True.
-            dual_pixel_task: Enable dual-pixel specific processing
-            flash_attention: Use flash attention if available
+            dim: Base feature dimension. Defaults to 48.
+            num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1).
+            heads: Number of attention heads at each scale. Defaults to (1,1,1,1).
+            num_refinement_blocks: Number of final refinement blocks. Defaults to 4.
+            ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66.
+            bias: Whether to use bias in convolutions. Defaults to False.
+            layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True.
+            dual_pixel_task: Enable dual-pixel specific processing. Defaults to False.
+            flash_attention: Use flash attention if available. Defaults to False.
+
+        Note:
+            The length of num_blocks must be greater than 1.
+            num_blocks and heads must have the same length.
+            All values in num_blocks must be greater than 0.
         """
         # Check input parameters
         assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
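The Note maps onto simple checks at the top of the constructor; only the first assert is visible in this hunk. A sketch of all three constraints together (the helper name and the last two message strings are assumptions, not the actual source):

```python
from collections.abc import Sequence


def _check_restormer_args(num_blocks: Sequence[int], heads: Sequence[int]) -> None:
    # Visible in the diff above:
    assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
    # Implied by the docstring Note:
    assert len(num_blocks) == len(heads), "num_blocks and heads must have the same length"
    assert all(n > 0 for n in num_blocks), "All values in num_blocks must be greater than 0"


_check_restormer_args(num_blocks=(1, 1, 1, 1), heads=(1, 1, 1, 1))  # passes with the documented defaults
```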