1
- from .model import *
1
+ from .arch import HSDT
2
+ from .attention import GSSA , SMFFN , TransformerBlock
3
+ from .sepconv import S3Conv
4
+
5
+ def hsdt ():
6
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
7
+ net .use_2dconv = False
8
+ net .bandwise = False
9
+ return net
10
+
11
+
12
+ def hsdt_4 ():
13
+ net = HSDT (1 , 4 , 5 , [1 , 3 ])
14
+ net .use_2dconv = False
15
+ net .bandwise = False
16
+ return net
17
+
18
+
19
+ def hsdt_8 ():
20
+ net = HSDT (1 , 8 , 5 , [1 , 3 ])
21
+ net .use_2dconv = False
22
+ net .bandwise = False
23
+ return net
24
+
25
+
26
+ def hsdt_24 ():
27
+ net = HSDT (1 , 24 , 5 , [1 , 3 ])
28
+ net .use_2dconv = False
29
+ net .bandwise = False
30
+ return net
31
+
32
+
33
+ def hsdt_32 ():
34
+ net = HSDT (1 , 32 , 5 , [1 , 3 ])
35
+ net .use_2dconv = False
36
+ net .bandwise = False
37
+ return net
38
+
39
+
40
+ def hsdt_deep ():
41
+ net = HSDT (1 , 16 , 7 , [1 , 3 , 5 ])
42
+ net .use_2dconv = False
43
+ net .bandwise = False
44
+ return net
45
+
46
+
47
+ """ Extension
48
+ """
49
+
50
+
51
+ def hsdt_pnp ():
52
+ net = HSDT (2 , 16 , 5 , [1 , 3 ])
53
+ net .use_2dconv = False
54
+ net .bandwise = False
55
+ return net
56
+
57
+
58
+ def hsdt_ssr ():
59
+ from .arch import HSDTSSR
60
+ net = HSDTSSR (1 , 16 , 5 , [1 , 3 ])
61
+ net .use_2dconv = False
62
+ net .bandwise = False
63
+ return net
64
+
65
+
66
+ """ ablations
67
+ """
68
+
69
+
70
+ def hsdt_pixelwise ():
71
+ from . import arch
72
+ from .attention import PixelwiseTransformerBlock
73
+ arch .TransformerBlock = PixelwiseTransformerBlock
74
+
75
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
76
+ net .use_2dconv = False
77
+ net .bandwise = False
78
+ return net
79
+
80
+ # ablation of ffn
81
+
82
+
83
+ def hsdt_ffn ():
84
+ from . import arch
85
+ from .attention import FFNTransformerBlock
86
+ arch .TransformerBlock = FFNTransformerBlock
87
+
88
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
89
+ net .use_2dconv = False
90
+ net .bandwise = False
91
+ return net
92
+
93
+
94
+ def hsdt_ffn_flex ():
95
+ from . import arch
96
+ from .attention import FFNTransformerBlock
97
+ from functools import partial
98
+ arch .TransformerBlock = partial (FFNTransformerBlock , flex = True )
99
+
100
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
101
+ net .use_2dconv = False
102
+ net .bandwise = False
103
+ return net
104
+
105
+
106
+ def hsdt_gdfn ():
107
+ from . import arch
108
+ from .attention import GDFNTransformerBlock
109
+ arch .TransformerBlock = GDFNTransformerBlock
110
+
111
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
112
+ net .use_2dconv = False
113
+ net .bandwise = False
114
+ return net
115
+
116
+
117
+ def hsdt_smffn1 ():
118
+ from . import arch
119
+ from .attention import GFNTransformerBlock
120
+ arch .TransformerBlock = GFNTransformerBlock
121
+
122
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
123
+ net .use_2dconv = False
124
+ net .bandwise = False
125
+ return net
126
+
127
+ # ablation of ssa
128
+
129
+
130
+ def hsdt_ssa ():
131
+ from . import arch
132
+ from .attention import SSATransformerBlock
133
+ arch .TransformerBlock = SSATransformerBlock
134
+
135
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
136
+ net .use_2dconv = False
137
+ net .bandwise = False
138
+ return net
139
+
140
+ # ablation of s3conv
141
+
142
+
143
+ def hsdt_conv3d ():
144
+ from . import arch
145
+ import torch .nn as nn
146
+ arch .Conv3d = nn .Conv3d
147
+
148
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
149
+ net .use_2dconv = False
150
+ net .bandwise = False
151
+ return net
152
+
153
+
154
+ def hsdt_s3conv_sep ():
155
+ from . import arch
156
+ import torch .nn as nn
157
+ from .sepconv import SepConv_DP
158
+ arch .Conv3d = SepConv_DP .of (nn .Conv3d )
159
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
160
+ net .use_2dconv = False
161
+ net .bandwise = False
162
+ return net
163
+
164
+
165
+ def hsdt_s3conv_seq ():
166
+ from . import arch
167
+ import torch .nn as nn
168
+ from .sepconv import S3Conv_Seq
169
+ arch .Conv3d = S3Conv_Seq .of (nn .Conv3d )
170
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
171
+ net .use_2dconv = False
172
+ net .bandwise = False
173
+ return net
174
+
175
+
176
+ def hsdt_s3conv1 ():
177
+ from . import arch
178
+ import torch .nn as nn
179
+ from .sepconv import S3Conv1
180
+ arch .Conv3d = S3Conv1 .of (nn .Conv3d )
181
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
182
+ net .use_2dconv = False
183
+ net .bandwise = False
184
+ return net
185
+
186
+
187
+ """ Break down
188
+ """
189
+
190
+
191
+ def baseline_s3conv ():
192
+ from . import arch
193
+ from .attention import DummyTransformerBlock
194
+ arch .TransformerBlock = DummyTransformerBlock
195
+ arch .UseBN = False
196
+
197
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
198
+ net .use_2dconv = False
199
+ net .bandwise = False
200
+ return net
201
+
202
+
203
+ def baseline_conv3d ():
204
+ from . import arch
205
+ import torch .nn as nn
206
+ arch .Conv3d = nn .Conv3d
207
+ from .attention import DummyTransformerBlock
208
+ arch .TransformerBlock = DummyTransformerBlock
209
+ arch .UseBN = False
210
+
211
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
212
+ net .use_2dconv = False
213
+ net .bandwise = False
214
+ return net
215
+
216
+
217
+ def baseline_gssa ():
218
+ from . import arch
219
+ import torch .nn as nn
220
+ arch .Conv3d = nn .Conv3d
221
+ from .attention import FFNTransformerBlock
222
+ arch .TransformerBlock = FFNTransformerBlock
223
+ arch .UseBN = True
224
+
225
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
226
+ net .use_2dconv = False
227
+ net .bandwise = False
228
+ return net
229
+
230
+
231
+ def baseline_ssa ():
232
+ from . import arch
233
+ import torch .nn as nn
234
+ arch .Conv3d = nn .Conv3d
235
+ from .attention import SSAFFNTransformerBlock
236
+ arch .TransformerBlock = SSAFFNTransformerBlock
237
+ arch .UseBN = True
238
+
239
+ net = HSDT (1 , 16 , 5 , [1 , 3 ])
240
+ net .use_2dconv = False
241
+ net .bandwise = False
242
+ return net
0 commit comments