22
22
import torch .nn as nn
23
23
24
24
from monai .networks .blocks .encoder import BaseEncoder
25
- from monai .networks .layers .factories import Conv , Norm , Pool
26
- from monai .networks .layers .utils import get_act_layer , get_pool_layer
25
+ from monai .networks .layers .factories import Conv , Pool
26
+ from monai .networks .layers .utils import get_act_layer , get_norm_layer , get_pool_layer
27
27
from monai .utils import ensure_tuple_rep
28
28
from monai .utils .module import look_up_option , optional_import
29
29
@@ -79,6 +79,7 @@ def __init__(
79
79
stride : int = 1 ,
80
80
downsample : nn .Module | partial | None = None ,
81
81
act : str | tuple = ("relu" , {"inplace" : True }),
82
+ norm : str | tuple = "batch" ,
82
83
) -> None :
83
84
"""
84
85
Args:
@@ -92,13 +93,13 @@ def __init__(
92
93
super ().__init__ ()
93
94
94
95
conv_type : Callable = Conv [Conv .CONV , spatial_dims ]
95
- norm_type : Callable = Norm [ Norm . BATCH , spatial_dims ]
96
+ norm_layer = get_norm_layer ( name = norm , spatial_dims = spatial_dims , channels = planes )
96
97
97
98
self .conv1 = conv_type (in_planes , planes , kernel_size = 3 , padding = 1 , stride = stride , bias = False )
98
- self .bn1 = norm_type ( planes )
99
+ self .bn1 = norm_layer
99
100
self .act = get_act_layer (name = act )
100
101
self .conv2 = conv_type (planes , planes , kernel_size = 3 , padding = 1 , bias = False )
101
- self .bn2 = norm_type ( planes )
102
+ self .bn2 = norm_layer
102
103
self .downsample = downsample
103
104
self .stride = stride
104
105
@@ -132,6 +133,7 @@ def __init__(
132
133
stride : int = 1 ,
133
134
downsample : nn .Module | partial | None = None ,
134
135
act : str | tuple = ("relu" , {"inplace" : True }),
136
+ norm : str | tuple = "batch" ,
135
137
) -> None :
136
138
"""
137
139
Args:
@@ -146,14 +148,14 @@ def __init__(
146
148
super ().__init__ ()
147
149
148
150
conv_type : Callable = Conv [Conv .CONV , spatial_dims ]
149
- norm_type : Callable = Norm [ Norm . BATCH , spatial_dims ]
151
+ norm_layer = partial ( get_norm_layer , name = norm , spatial_dims = spatial_dims )
150
152
151
153
self .conv1 = conv_type (in_planes , planes , kernel_size = 1 , bias = False )
152
- self .bn1 = norm_type ( planes )
154
+ self .bn1 = norm_layer ( channels = planes )
153
155
self .conv2 = conv_type (planes , planes , kernel_size = 3 , stride = stride , padding = 1 , bias = False )
154
- self .bn2 = norm_type ( planes )
156
+ self .bn2 = norm_layer ( channels = planes )
155
157
self .conv3 = conv_type (planes , planes * self .expansion , kernel_size = 1 , bias = False )
156
- self .bn3 = norm_type ( planes * self .expansion )
158
+ self .bn3 = norm_layer ( channels = planes * self .expansion )
157
159
self .act = get_act_layer (name = act )
158
160
self .downsample = downsample
159
161
self .stride = stride
@@ -226,6 +228,7 @@ def __init__(
226
228
feed_forward : bool = True ,
227
229
bias_downsample : bool = True , # for backwards compatibility (also see PR #5477)
228
230
act : str | tuple = ("relu" , {"inplace" : True }),
231
+ norm : str | tuple = "batch" ,
229
232
) -> None :
230
233
super ().__init__ ()
231
234
@@ -238,7 +241,6 @@ def __init__(
238
241
raise ValueError ("Unknown block '%s', use basic or bottleneck" % block )
239
242
240
243
conv_type : type [nn .Conv1d | nn .Conv2d | nn .Conv3d ] = Conv [Conv .CONV , spatial_dims ]
241
- norm_type : type [nn .BatchNorm1d | nn .BatchNorm2d | nn .BatchNorm3d ] = Norm [Norm .BATCH , spatial_dims ]
242
244
pool_type : type [nn .MaxPool1d | nn .MaxPool2d | nn .MaxPool3d ] = Pool [Pool .MAX , spatial_dims ]
243
245
avgp_type : type [nn .AdaptiveAvgPool1d | nn .AdaptiveAvgPool2d | nn .AdaptiveAvgPool3d ] = Pool [
244
246
Pool .ADAPTIVEAVG , spatial_dims
@@ -262,7 +264,9 @@ def __init__(
262
264
padding = tuple (k // 2 for k in conv1_kernel_size ),
263
265
bias = False ,
264
266
)
265
- self .bn1 = norm_type (self .in_planes )
267
+
268
+ norm_layer = get_norm_layer (name = norm , spatial_dims = spatial_dims , channels = self .in_planes )
269
+ self .bn1 = norm_layer
266
270
self .act = get_act_layer (name = act )
267
271
self .maxpool = pool_type (kernel_size = 3 , stride = 2 , padding = 1 )
268
272
self .layer1 = self ._make_layer (block , block_inplanes [0 ], layers [0 ], spatial_dims , shortcut_type )
@@ -275,7 +279,7 @@ def __init__(
275
279
for m in self .modules ():
276
280
if isinstance (m , conv_type ):
277
281
nn .init .kaiming_normal_ (torch .as_tensor (m .weight ), mode = "fan_out" , nonlinearity = "relu" )
278
- elif isinstance (m , norm_type ):
282
+ elif isinstance (m , type ( norm_layer ) ):
279
283
nn .init .constant_ (torch .as_tensor (m .weight ), 1 )
280
284
nn .init .constant_ (torch .as_tensor (m .bias ), 0 )
281
285
elif isinstance (m , nn .Linear ):
@@ -295,9 +299,9 @@ def _make_layer(
295
299
spatial_dims : int ,
296
300
shortcut_type : str ,
297
301
stride : int = 1 ,
302
+ norm : str | tuple = "batch" ,
298
303
) -> nn .Sequential :
299
304
conv_type : Callable = Conv [Conv .CONV , spatial_dims ]
300
- norm_type : Callable = Norm [Norm .BATCH , spatial_dims ]
301
305
302
306
downsample : nn .Module | partial | None = None
303
307
if stride != 1 or self .in_planes != planes * block .expansion :
@@ -317,18 +321,23 @@ def _make_layer(
317
321
stride = stride ,
318
322
bias = self .bias_downsample ,
319
323
),
320
- norm_type ( planes * block .expansion ),
324
+ get_norm_layer ( name = norm , spatial_dims = spatial_dims , channels = planes * block .expansion ),
321
325
)
322
326
323
327
layers = [
324
328
block (
325
- in_planes = self .in_planes , planes = planes , spatial_dims = spatial_dims , stride = stride , downsample = downsample
329
+ in_planes = self .in_planes ,
330
+ planes = planes ,
331
+ spatial_dims = spatial_dims ,
332
+ stride = stride ,
333
+ downsample = downsample ,
334
+ norm = norm ,
326
335
)
327
336
]
328
337
329
338
self .in_planes = planes * block .expansion
330
339
for _i in range (1 , blocks ):
331
- layers .append (block (self .in_planes , planes , spatial_dims = spatial_dims ))
340
+ layers .append (block (self .in_planes , planes , spatial_dims = spatial_dims , norm = norm ))
332
341
333
342
return nn .Sequential (* layers )
334
343
0 commit comments