Skip to content

Commit b07047a

Browse files
committed
update
1 parent d2f1d4a commit b07047a

File tree

18 files changed

+294
-857
lines changed

18 files changed

+294
-857
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
*.mat
2+
*.png
23
checkpoints
34
results
45
data

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Official PyTorch Implementation of [Hybrid Spectral Denoising Transformer with L
44

55
*Zeqiang Lai, [Ying Fu](https://ying-fu.github.io/)*.
66

7-
<img src="asset/arch.png" width="550px"/>
7+
<img src="asset/arch.png" width="600px"/>
88

99
🌟 **Hightlights**
1010

asset/arch.png

117 KB
Loading

hsdt/__init__.py

Lines changed: 242 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,242 @@
1-
from .model import *
1+
from .arch import HSDT
2+
from .attention import GSSA, SMFFN, TransformerBlock
3+
from .sepconv import S3Conv
4+
5+
def hsdt():
6+
net = HSDT(1, 16, 5, [1, 3])
7+
net.use_2dconv = False
8+
net.bandwise = False
9+
return net
10+
11+
12+
def hsdt_4():
13+
net = HSDT(1, 4, 5, [1, 3])
14+
net.use_2dconv = False
15+
net.bandwise = False
16+
return net
17+
18+
19+
def hsdt_8():
20+
net = HSDT(1, 8, 5, [1, 3])
21+
net.use_2dconv = False
22+
net.bandwise = False
23+
return net
24+
25+
26+
def hsdt_24():
27+
net = HSDT(1, 24, 5, [1, 3])
28+
net.use_2dconv = False
29+
net.bandwise = False
30+
return net
31+
32+
33+
def hsdt_32():
34+
net = HSDT(1, 32, 5, [1, 3])
35+
net.use_2dconv = False
36+
net.bandwise = False
37+
return net
38+
39+
40+
def hsdt_deep():
41+
net = HSDT(1, 16, 7, [1, 3, 5])
42+
net.use_2dconv = False
43+
net.bandwise = False
44+
return net
45+
46+
47+
""" Extension
48+
"""
49+
50+
51+
def hsdt_pnp():
52+
net = HSDT(2, 16, 5, [1, 3])
53+
net.use_2dconv = False
54+
net.bandwise = False
55+
return net
56+
57+
58+
def hsdt_ssr():
59+
from .arch import HSDTSSR
60+
net = HSDTSSR(1, 16, 5, [1, 3])
61+
net.use_2dconv = False
62+
net.bandwise = False
63+
return net
64+
65+
66+
""" ablations
67+
"""
68+
69+
70+
def hsdt_pixelwise():
71+
from . import arch
72+
from .attention import PixelwiseTransformerBlock
73+
arch.TransformerBlock = PixelwiseTransformerBlock
74+
75+
net = HSDT(1, 16, 5, [1, 3])
76+
net.use_2dconv = False
77+
net.bandwise = False
78+
return net
79+
80+
# ablation of ffn
81+
82+
83+
def hsdt_ffn():
84+
from . import arch
85+
from .attention import FFNTransformerBlock
86+
arch.TransformerBlock = FFNTransformerBlock
87+
88+
net = HSDT(1, 16, 5, [1, 3])
89+
net.use_2dconv = False
90+
net.bandwise = False
91+
return net
92+
93+
94+
def hsdt_ffn_flex():
95+
from . import arch
96+
from .attention import FFNTransformerBlock
97+
from functools import partial
98+
arch.TransformerBlock = partial(FFNTransformerBlock, flex=True)
99+
100+
net = HSDT(1, 16, 5, [1, 3])
101+
net.use_2dconv = False
102+
net.bandwise = False
103+
return net
104+
105+
106+
def hsdt_gdfn():
107+
from . import arch
108+
from .attention import GDFNTransformerBlock
109+
arch.TransformerBlock = GDFNTransformerBlock
110+
111+
net = HSDT(1, 16, 5, [1, 3])
112+
net.use_2dconv = False
113+
net.bandwise = False
114+
return net
115+
116+
117+
def hsdt_smffn1():
118+
from . import arch
119+
from .attention import GFNTransformerBlock
120+
arch.TransformerBlock = GFNTransformerBlock
121+
122+
net = HSDT(1, 16, 5, [1, 3])
123+
net.use_2dconv = False
124+
net.bandwise = False
125+
return net
126+
127+
# ablation of ssa
128+
129+
130+
def hsdt_ssa():
131+
from . import arch
132+
from .attention import SSATransformerBlock
133+
arch.TransformerBlock = SSATransformerBlock
134+
135+
net = HSDT(1, 16, 5, [1, 3])
136+
net.use_2dconv = False
137+
net.bandwise = False
138+
return net
139+
140+
# ablation of s3conv
141+
142+
143+
def hsdt_conv3d():
144+
from . import arch
145+
import torch.nn as nn
146+
arch.Conv3d = nn.Conv3d
147+
148+
net = HSDT(1, 16, 5, [1, 3])
149+
net.use_2dconv = False
150+
net.bandwise = False
151+
return net
152+
153+
154+
def hsdt_s3conv_sep():
155+
from . import arch
156+
import torch.nn as nn
157+
from .sepconv import SepConv_DP
158+
arch.Conv3d = SepConv_DP.of(nn.Conv3d)
159+
net = HSDT(1, 16, 5, [1, 3])
160+
net.use_2dconv = False
161+
net.bandwise = False
162+
return net
163+
164+
165+
def hsdt_s3conv_seq():
166+
from . import arch
167+
import torch.nn as nn
168+
from .sepconv import S3Conv_Seq
169+
arch.Conv3d = S3Conv_Seq.of(nn.Conv3d)
170+
net = HSDT(1, 16, 5, [1, 3])
171+
net.use_2dconv = False
172+
net.bandwise = False
173+
return net
174+
175+
176+
def hsdt_s3conv1():
177+
from . import arch
178+
import torch.nn as nn
179+
from .sepconv import S3Conv1
180+
arch.Conv3d = S3Conv1.of(nn.Conv3d)
181+
net = HSDT(1, 16, 5, [1, 3])
182+
net.use_2dconv = False
183+
net.bandwise = False
184+
return net
185+
186+
187+
""" Break down
188+
"""
189+
190+
191+
def baseline_s3conv():
192+
from . import arch
193+
from .attention import DummyTransformerBlock
194+
arch.TransformerBlock = DummyTransformerBlock
195+
arch.UseBN = False
196+
197+
net = HSDT(1, 16, 5, [1, 3])
198+
net.use_2dconv = False
199+
net.bandwise = False
200+
return net
201+
202+
203+
def baseline_conv3d():
204+
from . import arch
205+
import torch.nn as nn
206+
arch.Conv3d = nn.Conv3d
207+
from .attention import DummyTransformerBlock
208+
arch.TransformerBlock = DummyTransformerBlock
209+
arch.UseBN = False
210+
211+
net = HSDT(1, 16, 5, [1, 3])
212+
net.use_2dconv = False
213+
net.bandwise = False
214+
return net
215+
216+
217+
def baseline_gssa():
218+
from . import arch
219+
import torch.nn as nn
220+
arch.Conv3d = nn.Conv3d
221+
from .attention import FFNTransformerBlock
222+
arch.TransformerBlock = FFNTransformerBlock
223+
arch.UseBN = True
224+
225+
net = HSDT(1, 16, 5, [1, 3])
226+
net.use_2dconv = False
227+
net.bandwise = False
228+
return net
229+
230+
231+
def baseline_ssa():
232+
from . import arch
233+
import torch.nn as nn
234+
arch.Conv3d = nn.Conv3d
235+
from .attention import SSAFFNTransformerBlock
236+
arch.TransformerBlock = SSAFFNTransformerBlock
237+
arch.UseBN = True
238+
239+
net = HSDT(1, 16, 5, [1, 3])
240+
net.use_2dconv = False
241+
net.bandwise = False
242+
return net

hsdt/model/arch.py renamed to hsdt/arch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import torch
44
import torch.nn as nn
55
import torch.nn.functional as F
6-
from sync_batchnorm import SynchronizedBatchNorm3d
76

87
from .attention import TransformerBlock
98
from .sepconv import SepConv_DP, SepConv_DP_CA, S3Conv
109

11-
BatchNorm3d = SynchronizedBatchNorm3d
10+
BatchNorm3d = nn.BatchNorm3d
1211
Conv3d = S3Conv.of(nn.Conv3d)
1312
TransformerBlock = TransformerBlock
1413
IsConvImpl = False
@@ -62,6 +61,7 @@ def forward(self, x, xs):
6261

6362

6463
class Decoder(nn.Module):
64+
count = 1
6565
def __init__(self, channels, num_half_layer, sample_idx, Fusion=None):
6666
super(Decoder, self).__init__()
6767
# Decoder

0 commit comments

Comments
 (0)