
Commit 335ab0f

Peter Kaplinsky (Pkaps25) authored and committed
add layer norm to resnet
1 parent daf2e71 · commit 335ab0f

File tree

2 files changed (+43, -17 lines)


monai/networks/nets/resnet.py

Lines changed: 25 additions & 16 deletions
@@ -22,8 +22,8 @@
 import torch.nn as nn
 
 from monai.networks.blocks.encoder import BaseEncoder
-from monai.networks.layers.factories import Conv, Norm, Pool
-from monai.networks.layers.utils import get_act_layer, get_pool_layer
+from monai.networks.layers.factories import Conv, Pool
+from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
 from monai.utils import ensure_tuple_rep
 from monai.utils.module import look_up_option, optional_import
 
@@ -79,6 +79,7 @@ def __init__(
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         """
         Args:
@@ -92,13 +93,13 @@ def __init__(
         super().__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+        norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
-        self.bn1 = norm_type(planes)
+        self.bn1 = norm_layer
         self.act = get_act_layer(name=act)
         self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
-        self.bn2 = norm_type(planes)
+        self.bn2 = norm_layer
         self.downsample = downsample
         self.stride = stride
 
@@ -132,6 +133,7 @@ def __init__(
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         """
         Args:
@@ -146,14 +148,14 @@ def __init__(
         super().__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
-        self.bn1 = norm_type(planes)
+        self.bn1 = norm_layer(channels=planes)
         self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn2 = norm_type(planes)
+        self.bn2 = norm_layer(channels=planes)
         self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = norm_type(planes * self.expansion)
+        self.bn3 = norm_layer(channels=planes * self.expansion)
         self.act = get_act_layer(name=act)
         self.downsample = downsample
         self.stride = stride
@@ -226,6 +228,7 @@ def __init__(
         feed_forward: bool = True,
         bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         super().__init__()
 
@@ -238,7 +241,6 @@ def __init__(
                 raise ValueError("Unknown block '%s', use basic or bottleneck" % block)
 
         conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
-        norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
         pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
         avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
@@ -262,7 +264,9 @@ def __init__(
             padding=tuple(k // 2 for k in conv1_kernel_size),
             bias=False,
         )
-        self.bn1 = norm_type(self.in_planes)
+
+        norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
+        self.bn1 = norm_layer
         self.act = get_act_layer(name=act)
         self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
@@ -275,7 +279,7 @@ def __init__(
         for m in self.modules():
             if isinstance(m, conv_type):
                 nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
-            elif isinstance(m, norm_type):
+            elif isinstance(m, type(norm_layer)):
                 nn.init.constant_(torch.as_tensor(m.weight), 1)
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
             elif isinstance(m, nn.Linear):
@@ -295,9 +299,9 @@ def _make_layer(
         spatial_dims: int,
         shortcut_type: str,
         stride: int = 1,
+        norm: str | tuple = "batch",
     ) -> nn.Sequential:
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
 
         downsample: nn.Module | partial | None = None
         if stride != 1 or self.in_planes != planes * block.expansion:
@@ -317,18 +321,23 @@ def _make_layer(
                         stride=stride,
                         bias=self.bias_downsample,
                     ),
-                    norm_type(planes * block.expansion),
+                    get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
                 )
 
         layers = [
             block(
-                in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
+                in_planes=self.in_planes,
+                planes=planes,
+                spatial_dims=spatial_dims,
+                stride=stride,
+                downsample=downsample,
+                norm=norm,
             )
         ]
 
         self.in_planes = planes * block.expansion
         for _i in range(1, blocks):
-            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
+            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))
 
         return nn.Sequential(*layers)
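Taken together, these changes replace the hard-coded batch-norm factory with a configurable norm argument resolved through get_norm_layer. A minimal sketch of what this enables at the block level (assuming MONAI with this change installed and that the basic block class is ResNetBlock, as defined in this file; shapes are illustrative only):

import torch
from monai.networks.nets.resnet import ResNetBlock

# The block-level norm is now configurable instead of being fixed to batch norm;
# here an instance norm is requested through the new `norm` argument.
block = ResNetBlock(in_planes=16, planes=16, spatial_dims=3, norm="instance")
x = torch.randn(2, 16, 8, 8, 8)  # (batch, channels, D, H, W)
y = block(x)
print(y.shape)  # stride=1, so the shape is preserved: torch.Size([2, 16, 8, 8, 8])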

tests/test_resnet.py

Lines changed: 18 additions & 1 deletion
@@ -202,9 +202,26 @@
     (1, 3),
 ]
 
+TEST_CASE_9 = [
+    {
+        "block": "bottleneck",
+        "layers": [3, 4, 6, 3],
+        "block_inplanes": [64, 128, 256, 512],
+        "spatial_dims": 1,
+        "n_input_channels": 2,
+        "num_classes": 3,
+        "conv1_t_size": [3],
+        "conv1_t_stride": 1,
+        "act": ("relu", {"inplace": False}),
+        "norm": ("layer", {"normalized_shape": (64, 32)}),
+    },
+    (1, 2, 32),
+    (1, 3),
+]
+
 TEST_CASES = []
 PRETRAINED_TEST_CASES = []
-for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
+for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A, TEST_CASE_3]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])
         PRETRAINED_TEST_CASES.append([model, *case])
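For reference, the new test case corresponds roughly to the following standalone construction; this is a sketch that instantiates the ResNet class directly, with keyword arguments taken verbatim from the TEST_CASE_9 dictionary above and the input/output shapes from the two tuples that follow it:

import torch
from monai.networks.nets import ResNet

# 1D bottleneck ResNet whose stem normalization is a LayerNorm over the
# (channel, length) shape produced by conv1, mirroring TEST_CASE_9.
net = ResNet(
    block="bottleneck",
    layers=[3, 4, 6, 3],
    block_inplanes=[64, 128, 256, 512],
    spatial_dims=1,
    n_input_channels=2,
    num_classes=3,
    conv1_t_size=[3],
    conv1_t_stride=1,
    act=("relu", {"inplace": False}),
    norm=("layer", {"normalized_shape": (64, 32)}),
)
net.eval()  # switch to eval mode before the single-sample forward pass
out = net(torch.randn(1, 2, 32))  # input shape (1, 2, 32) from the test case
print(out.shape)  # expected output shape: torch.Size([1, 3])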
