@@ -24,7 +24,14 @@
 
 class FeedForward(nn.Module):
     """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using a gating mechanism.
-    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""
+    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        ffn_expansion_factor: Factor by which to expand the hidden feature dimension
+        bias: Whether to use bias in convolution layers
+    """
 
     def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
         super().__init__()
@@ -70,7 +77,19 @@ class CABlock(nn.Module):
     """Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
     by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
     convolutions for local mixing before attention, achieving linear complexity vs quadratic
-    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>"""
+    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        num_heads: Number of attention heads
+        bias: Whether to use bias in convolution layers
+        flash_attention: Whether to use flash attention optimization. Defaults to False.
+
+    Raises:
+        ValueError: If flash attention is not available in the current PyTorch version
+        ValueError: If spatial_dims is greater than 3
+    """
 
     def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
         super().__init__()