
Commit d1df8e6

Enhance documentation for FeedForward and CABlock classes, adding argument descriptions and error handling details.

1 parent 232be1c

File tree

1 file changed: +21 -2 lines changed

monai/networks/blocks/cablock.py

Lines changed: 21 additions & 2 deletions
@@ -24,7 +24,14 @@
 
 class FeedForward(nn.Module):
     """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
-    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""
+    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        ffn_expansion_factor: Factor to expand hidden features dimension
+        bias: Whether to use bias in convolution layers
+    """
 
     def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
         super().__init__()
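
To make the documented signature concrete, here is a minimal usage sketch of FeedForward. It assumes only what the diff shows (the module path monai/networks/blocks/cablock.py and the constructor arguments); the channel count, the expansion factor, and the shape-preserving behavior of the GDFN block are illustrative assumptions, not guarantees from this commit.

    import torch

    from monai.networks.blocks.cablock import FeedForward

    # dim=48 and ffn_expansion_factor=2.66 follow the Restormer paper's
    # defaults; they are illustrative, not mandated by this commit.
    ffn = FeedForward(spatial_dims=2, dim=48, ffn_expansion_factor=2.66, bias=False)

    x = torch.randn(1, 48, 64, 64)  # (batch, channels, height, width)
    out = ffn(x)
    print(out.shape)  # assumed shape-preserving: torch.Size([1, 48, 64, 64])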
@@ -70,7 +77,19 @@ class CABlock(nn.Module):
     """Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
     by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
     convolutions for local mixing before attention, achieving linear complexity vs quadratic
-    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>"""
+    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        num_heads: Number of attention heads
+        bias: Whether to use bias in convolution layers
+        flash_attention: Whether to use flash attention optimization. Defaults to False.
+
+    Raises:
+        ValueError: If flash attention is not available in current PyTorch version
+        ValueError: If spatial_dims is greater than 3
+    """
 
     def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
         super().__init__()
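
Similarly, a hedged sketch of CABlock in 3D, based only on the constructor shown above. The requirement that dim be divisible by num_heads is assumed from standard multi-head attention practice rather than stated in this diff; per the new Raises section, a spatial_dims above 3 should raise a ValueError.

    import torch

    from monai.networks.blocks.cablock import CABlock

    # num_heads=8 assumes dim (48) is divisible by the head count, as in
    # standard multi-head attention; flash_attention keeps its documented
    # default of False.
    attn = CABlock(spatial_dims=3, dim=48, num_heads=8, bias=False)

    x = torch.randn(1, 48, 16, 16, 16)  # (batch, channels, depth, height, width)
    out = attn(x)
    print(out.shape)  # channel-wise attention leaves spatial shape intact: [1, 48, 16, 16, 16]

    # Per the docstring's Raises section, an unsupported dimensionality fails fast:
    # CABlock(spatial_dims=4, dim=48, num_heads=8, bias=False)  -> ValueError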

0 commit comments