-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Restormer Implementation #8312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Restormer Implementation #8312
Changes from 24 commits
Commits
Show all changes
75 commits
Select commit
Hold shift + click to select a range
3db93ce
Add new pixel unshuffle for SubPixelDownsample class
phisanti 9693e04
Add unit test for pixelunshuffle
phisanti a89f299
Add DownSample Modes
phisanti 450691f
expand pixelunshuffle for 3D
phisanti d0920d8
increase testing for pixelunshuffle
phisanti 1a48d4d
expand pixelunshuffle for 3D images
phisanti fe47807
add SubpixelDownsample and tests
phisanti 86155cd
Add DownSample Class
phisanti 137a7f2
Add tests for Downsample
phisanti fb17baf
add exports to __init__
phisanti 5ff0baa
Include test to compare with Conv + unshuffle from original restormer
phisanti 2566db1
remove relative imports
phisanti ac4047b
Create restormer with Downsampler/Upsampler using monai implementation
phisanti 2b74270
Add channel attention block
phisanti 9b74533
add assembled restormer with MONAI convs for 3D
phisanti 1ab34f6
restormer adapted for 2D/3D
phisanti 4f4c62c
Add unit test for CABlock and the FeedForward layers
phisanti 068688f
remove relative imports
phisanti e2e1070
rename restormer
phisanti 35c7ee4
add unit test restormer
phisanti d8cb6c1
Update documentation and imports for CABlock and FeedForward; add Dow…
phisanti 6d96816
Add licence to pixel_unshuffle test
phisanti 8a688fb
Refactor imports and clean up whitespace in utils and test files and …
phisanti acb818d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6352ba9
DCO Remediation Commit for tisalon <[email protected]>
phisanti c7b1af4
add optional_import to downsample block test
phisanti 8faa5da
rename args and fix imports
phisanti be89958
Using LocalStore in Zarr v3 (#8299)
KumoLiu c17938b
8267 fix normalize intensity (#8286)
advcu987 64613a7
Fix bundle download error from ngc source (#8307)
KumoLiu 5643d4a
Fix deprecated usage in zarr (#8313)
KumoLiu 595674a
update pydicom reader to enable gpu load (#8283)
yiheng-wang-nv c775393
Zarr compression tests only with versions before 3.0 (#8319)
ericspod 61efefb
Sync dev branch with upstream MONAI changes
phisanti 091887b
Clarify input tensor shape in pixelshuffle and pixelunshuffle functio…
phisanti 5d162d0
Refactor downsample mode checks to use enum values for clarity
phisanti f520e99
fix optiona import
phisanti 39d1edf
Refactor layer normalization parameters for consistency and clarity i…
phisanti 5b3d4e1
Enhance documentation for MDTATransformerBlock, OverlapPatchEmbed an…
phisanti 1683b14
run ./runtests.sh --autofix to check formatting
phisanti 232be1c
Refactor OverlapPatchEmbed to inherit from Convolution and streamline…
phisanti d1df8e6
Enhance documentation for FeedForward and CABlock classes, adding arg…
phisanti 78ce56b
code formatting
phisanti 64b203d
Update args naming in unit restormer test for consistency with sugges…
phisanti ce15886
Fix optional import
phisanti 30fad17
require einops for all tests
phisanti 1079d8c
require einops also for test_restormer
phisanti b2b3ddf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 174e968
remove relative impots
phisanti e15a815
fix capitalisation in DownSample documentation networks.rts
phisanti d53d97d
fix capitalisation in SubpixelDownsample documentation
phisanti cae7d96
formatting
phisanti a0afee5
update docstring to mention 2D and 3D cases
phisanti 529e90b
Update type annotations and doctring
phisanti c109029
remove problematic unit test
phisanti 19c30f7
remove problematic unit test
phisanti 0b0e4df
Merge remote-tracking branch 'upstream/dev' into dev
phisanti 55da640
relocate test in the correct place
phisanti 3c2dbc6
Add DownSampleBlock missing tests, Signed-off-by: Santiago Cano-Muniz…
phisanti da0a186
Merge branch 'dev' into dev
phisanti f17e06e
Re-order skipUnless in test_restormer.py, Signed-off-by: Cano-Muniz, …
phisanti 4573ec9
Clarify comments for RESTORMER_CONFIGS in test_restormer.py,
phisanti 8c564aa
Remove duplicated test_CABlock.py as part of codebase cleanup. In add…
phisanti 3e013fe
Refactor test cases in test_restormer.py to conditionally define clas…
phisanti 06be2ef
formatting error in line 237. Solved by updating black from 24.10.0 t…
phisanti c02d794
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7f46ac5
Remove duplicated tests and place the order of the decorators (skipUn…
phisanti baf7541
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] aeebc89
Remove debug print statement for einops availability in test_restorme…
phisanti 7342b84
Address mypy suggestions for type annotations in cablock.py, downsamp…
phisanti c608375
Merge branch 'dev' into dev
ericspod 395ce89
Merge branch 'dev' into dev
KumoLiu 22890d6
Merge branch 'dev' into dev
KumoLiu 5eaf79f
Merge branch 'dev' into dev
ericspod f2f8e34
Merge branch 'dev' into dev
ericspod File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from einops import rearrange | ||
|
||
from monai.networks.blocks.convolutions import Convolution | ||
|
||
__all__ = ["FeedForward", "CABlock"] | ||
|
||
|
||
class FeedForward(nn.Module): | ||
"""Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. | ||
Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.""" | ||
|
||
def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool): | ||
super().__init__() | ||
hidden_features = int(dim * ffn_expansion_factor) | ||
|
||
self.project_in = Convolution( | ||
spatial_dims=spatial_dims, | ||
in_channels=dim, | ||
out_channels=hidden_features * 2, | ||
kernel_size=1, | ||
bias=bias, | ||
conv_only=True, | ||
) | ||
|
||
self.dwconv = Convolution( | ||
spatial_dims=spatial_dims, | ||
in_channels=hidden_features * 2, | ||
out_channels=hidden_features * 2, | ||
kernel_size=3, | ||
strides=1, | ||
padding=1, | ||
groups=hidden_features * 2, | ||
bias=bias, | ||
conv_only=True, | ||
) | ||
|
||
self.project_out = Convolution( | ||
spatial_dims=spatial_dims, | ||
in_channels=hidden_features, | ||
out_channels=dim, | ||
kernel_size=1, | ||
bias=bias, | ||
conv_only=True, | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.project_in(x) | ||
x1, x2 = self.dwconv(x).chunk(2, dim=1) | ||
return self.project_out(F.gelu(x1) * x2) | ||
|
||
|
||
class CABlock(nn.Module): | ||
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention | ||
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise | ||
convolutions for local mixing before attention, achieving linear complexity vs quadratic | ||
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>""" | ||
|
||
def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): | ||
super().__init__() | ||
if flash_attention and not hasattr(F, "scaled_dot_product_attention"): | ||
raise ValueError("Flash attention not available") | ||
if spatial_dims > 3: | ||
raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}") | ||
self.spatial_dims = spatial_dims | ||
self.num_heads = num_heads | ||
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) | ||
self.flash_attention = flash_attention | ||
|
||
self.qkv = Convolution( | ||
spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True | ||
) | ||
|
||
self.qkv_dwconv = Convolution( | ||
spatial_dims=spatial_dims, | ||
in_channels=dim * 3, | ||
out_channels=dim * 3, | ||
kernel_size=3, | ||
strides=1, | ||
padding=1, | ||
groups=dim * 3, | ||
bias=bias, | ||
conv_only=True, | ||
) | ||
|
||
self.project_out = Convolution( | ||
spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True | ||
) | ||
|
||
self._attention_fn = self._get_attention_fn() | ||
|
||
def _get_attention_fn(self): | ||
if self.flash_attention: | ||
return self._flash_attention | ||
return self._normal_attention | ||
|
||
def _flash_attention(self, q, k, v): | ||
"""Flash attention implementation using scaled dot-product attention.""" | ||
scale = float(self.temperature.mean()) | ||
out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False) | ||
return out | ||
|
||
def _normal_attention(self, q, k, v): | ||
"""Attention matrix multiplication with depth-wise convolutions.""" | ||
attn = (q @ k.transpose(-2, -1)) * self.temperature | ||
attn = attn.softmax(dim=-1) | ||
return attn @ v | ||
|
||
def forward(self, x): | ||
"""Forward pass for MDTA attention. | ||
1. Apply depth-wise convolutions to Q, K, V | ||
2. Reshape Q, K, V for multi-head attention | ||
3. Compute attention matrix using flash or normal attention | ||
4. Reshape and project out attention output""" | ||
spatial_dims = x.shape[2:] | ||
|
||
# Project and mix | ||
qkv = self.qkv_dwconv(self.qkv(x)) | ||
q, k, v = qkv.chunk(3, dim=1) | ||
|
||
# Select attention | ||
if self.spatial_dims == 2: | ||
qkv_to_multihead = "b (head c) h w -> b head c (h w)" | ||
multihead_to_qkv = "b head c (h w) -> b (head c) h w" | ||
else: # dims == 3 | ||
qkv_to_multihead = "b (head c) d h w -> b head c (d h w)" | ||
multihead_to_qkv = "b head c (d h w) -> b (head c) d h w" | ||
|
||
# Reconstruct and project feature map | ||
q = rearrange(q, qkv_to_multihead, head=self.num_heads) | ||
k = rearrange(k, qkv_to_multihead, head=self.num_heads) | ||
v = rearrange(v, qkv_to_multihead, head=self.num_heads) | ||
|
||
q = torch.nn.functional.normalize(q, dim=-1) | ||
k = torch.nn.functional.normalize(k, dim=-1) | ||
|
||
out = self._attention_fn(q, k, v) | ||
out = rearrange( | ||
out, | ||
multihead_to_qkv, | ||
head=self.num_heads, | ||
**dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)), | ||
) | ||
|
||
return self.project_out(out) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.