Skip to content

Commit 5b3d4e1

Browse files
committed
Enhance documentation for the MDTATransformerBlock, OverlapPatchEmbed and Restormer classes.
1 parent 39d1edf commit 5b3d4e1

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

monai/networks/nets/restormer.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,17 @@
2323
class MDTATransformerBlock(nn.Module):
2424
"""Basic transformer unit combining MDTA and GDFN with skip connections.
2525
Unlike standard transformers that use LayerNorm, this block uses Instance Norm
26-
for better adaptation to image restoration tasks."""
26+
for better adaptation to image restoration tasks.
27+
28+
Args:
29+
spatial_dims: Number of spatial dimensions (2D or 3D)
30+
dim: Number of input channels
31+
num_heads: Number of attention heads
32+
ffn_expansion_factor: Expansion factor for feed-forward network
33+
bias: Whether to use bias in attention layers
34+
layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False.
35+
flash_attention: Whether to use flash attention optimization. Defaults to False.
36+
"""
2737

2838
def __init__(
2939
self,
@@ -50,7 +60,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5060
class OverlapPatchEmbed(nn.Module):
5161
"""Initial feature extraction using overlapped convolutions.
5262
Unlike standard patch embeddings that use non-overlapping patches,
53-
this approach maintains spatial continuity through 3x3 convolutions."""
63+
this approach maintains spatial continuity through 3x3 convolutions.
64+
65+
Args:
66+
spatial_dims: Number of spatial dimensions (2D or 3D)
67+
in_channels: Number of input channels
68+
embed_dim: Dimension of embedded features. Defaults to 48.
69+
bias: Whether to use bias in convolution layer. Defaults to False.
70+
"""
5471

5572
def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):
5673
super().__init__()
@@ -104,17 +121,23 @@ def __init__(
104121
"""Initialize Restormer model.
105122
106123
Args:
124+
spatial_dims: Number of spatial dimensions (2D or 3D)
107125
in_channels: Number of input image channels
108126
out_channels: Number of output image channels
109-
dim: Base feature dimension
110-
num_blocks: Number of transformer blocks at each scale
111-
num_refinement_blocks: Number of final refinement blocks
112-
heads: Number of attention heads at each scale
113-
ffn_expansion_factor: Expansion factor for feed-forward network
114-
bias: Whether to use bias in convolutions
115-
layer_norm_use_bias: Whether to use bias in layer normalization. Default is True.
116-
dual_pixel_task: Enable dual-pixel specific processing
117-
flash_attention: Use flash attention if available
127+
dim: Base feature dimension. Defaults to 48.
128+
num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1).
129+
heads: Number of attention heads at each scale. Defaults to (1,1,1,1).
130+
num_refinement_blocks: Number of final refinement blocks. Defaults to 4.
131+
ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66.
132+
bias: Whether to use bias in convolutions. Defaults to False.
133+
layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True.
134+
dual_pixel_task: Enable dual-pixel specific processing. Defaults to False.
135+
flash_attention: Use flash attention if available. Defaults to False.
136+
137+
Note:
138+
The length of num_blocks must be greater than 1
139+
The lengths of num_blocks and heads must be equal
140+
All values in num_blocks must be greater than 0
118141
"""
119142
# Check input parameters
120143
assert len(num_blocks) > 1, "Number of blocks must be greater than 1"

0 commit comments

Comments
 (0)