
Commit f336d66

mannatsingh authored and facebook-github-bot committed

Implement RegNetZ Model (#713)

Summary:
Pull Request resolved: #713

Adds an implementation of RegNetZ models, as per https://arxiv.org/abs/2103.06877. RegNetZ models are trained with a convolutional fully connected head.

Differential Revision: D27028613

fbshipit-source-id: 3dfffaf06b621ed41b665aae3db3b14185c22783

1 parent 330820f · commit f336d66
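The "convolutional fully connected head" refers to the fully_connected ClassyHead configured with conv_planes > 0 (see the comment added in classy_vision/models/regnet.py below). As a rough, hypothetical sketch of the idea, assuming nothing about the actual ClassyHead implementation, a 1x1 convolution expands the trunk features before pooling and classification:

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a "convolutional fully connected" head; the real head
# lives in classy_vision.heads (fully_connected with conv_planes > 0) and may
# differ in structure.
class ConvFCHead(nn.Module):
    def __init__(self, in_planes: int, conv_planes: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, conv_planes, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(conv_planes)
        self.act = nn.SiLU()  # RegNetZ models use SiLU activations
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(conv_planes, num_classes)

    def forward(self, x):
        x = self.pool(self.act(self.bn(self.conv(x))))
        return self.fc(torch.flatten(x, 1))
```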

File tree: 4 files changed, +123 -22 lines

- README.md
- classy_vision/generic/profiler.py
- classy_vision/models/regnet.py
- test/models_regnet_test.py

README.md (+1)

```diff
@@ -13,6 +13,7 @@
 
 ## What's New:
 
+- March 2021: Added [RegNetZ models](https://arxiv.org/abs/2103.06877)
 - November 2020: [Vision Transformers](https://openreview.net/forum?id=YicbFdNTTy) now available, with training [recipes](https://github.com/facebookresearch/ClassyVision/tree/master/examples/vit)!
 
 <details>
```

classy_vision/generic/profiler.py (+1, -1)

```diff
@@ -172,7 +172,7 @@ def flops(self, x):
         flops = count1 + count2
 
     # non-linearities:
-    elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]:
+    elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
         flops = x.numel()
 
     # 2D pooling layers:
```
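The profiler charges elementwise non-linearities one FLOP per element, so adding "SiLU" to this list makes SiLU (the activation used by the RegNetZ models below) countable like ReLU and friends. A quick illustration of the convention:

```python
import torch
import torch.nn as nn

# Elementwise activations are counted as one FLOP per element under the
# profiler's convention; the true cost of SiLU (x * sigmoid(x)) is higher,
# but all listed non-linearities are treated uniformly.
x = torch.randn(2, 64, 56, 56)
y = nn.SiLU()(x)
flops = x.numel()  # 2 * 64 * 56 * 56 = 401,408
```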

classy_vision/models/regnet.py (+106, -21)

```diff
@@ -25,6 +25,7 @@ class BlockType(Enum):
     VANILLA_BLOCK = auto()
     RES_BASIC_BLOCK = auto()
     RES_BOTTLENECK_BLOCK = auto()
+    RES_BOTTLENECK_LINEAR_BLOCK = auto()
 
 
 # The different possible Stems
@@ -206,8 +207,8 @@ def __init__(
         bn_epsilon: float,
         bn_momentum: float,
         activation: nn.Module,
-        bot_mul: float,
         group_width: int,
+        bot_mul: float,
         se_ratio: Optional[float],
     ):
         super().__init__()
@@ -253,8 +254,8 @@ def __init__(
         bn_epsilon: float,
         bn_momentum: float,
         activation: nn.Module,
-        bot_mul: float = 1.0,
         group_width: int = 1,
+        bot_mul: float = 1.0,
         se_ratio: Optional[float] = None,
     ):
         super().__init__()
@@ -273,8 +274,8 @@ def __init__(
             bn_epsilon,
             bn_momentum,
             activation,
-            bot_mul,
             group_width,
+            bot_mul,
             se_ratio,
         )
         self.activation = activation
@@ -291,6 +292,41 @@ def forward(self, x, *args):
         return self.activation(x)
 
 
+class ResBottleneckLinearBlock(nn.Module):
+    """Residual linear bottleneck block: x + F(x), F = bottleneck transform."""
+
+    def __init__(
+        self,
+        width_in: int,
+        width_out: int,
+        stride: int,
+        bn_epsilon: float,
+        bn_momentum: float,
+        activation: nn.Module,
+        group_width: int = 1,
+        bot_mul: float = 4.0,
+        se_ratio: Optional[float] = None,
+    ):
+        super().__init__()
+        self.has_skip = (width_in == width_out) and (stride == 1)
+        self.f = BottleneckTransform(
+            width_in,
+            width_out,
+            stride,
+            bn_epsilon,
+            bn_momentum,
+            activation,
+            group_width,
+            bot_mul,
+            se_ratio,
+        )
+
+        self.depth = self.f.depth
+
+    def forward(self, x):
+        return x + self.f(x) if self.has_skip else self.f(x)
+
+
 class AnyStage(nn.Sequential):
     """AnyNet stage (sequence of blocks w/ the same output shape)."""
```

```diff
@@ -302,8 +338,8 @@ def __init__(
         depth: int,
         block_constructor: nn.Module,
         activation: nn.Module,
-        bot_mul: float,
         group_width: int,
+        bot_mul: float,
         params: "RegNetParams",
         stage_index: int = 0,
     ):
@@ -318,8 +354,8 @@ def __init__(
             params.bn_epsilon,
             params.bn_momentum,
             activation,
-            bot_mul,
             group_width,
+            bot_mul,
             params.se_ratio,
         )
 
@@ -354,10 +390,11 @@ def __init__(
         w_a: float,
         w_m: float,
         group_w: int,
-        stem_type: StemType = "SIMPLE_STEM_IN",
+        bot_mul: float = 1.0,
+        stem_type: StemType = StemType.SIMPLE_STEM_IN,
         stem_width: int = 32,
-        block_type: BlockType = "RES_BOTTLENECK_BLOCK",
-        activation_type: ActivationType = "RELU",
+        block_type: BlockType = BlockType.RES_BOTTLENECK_BLOCK,
+        activation: ActivationType = ActivationType.RELU,
         use_se: bool = True,
         se_ratio: float = 0.25,
         bn_epsilon: float = 1e-05,
@@ -371,9 +408,10 @@ def __init__(
         self.w_a = w_a
         self.w_m = w_m
         self.group_w = group_w
-        self.stem_type = StemType[stem_type]
-        self.block_type = BlockType[block_type]
-        self.activation_type = ActivationType[activation_type]
+        self.bot_mul = bot_mul
+        self.stem_type = stem_type
+        self.block_type = block_type
+        self.activation = activation
         self.stem_width = stem_width
         self.use_se = use_se
         self.se_ratio = se_ratio if use_se else None
@@ -403,7 +441,6 @@ def get_expanded_params(self):
 
         QUANT = 8
         STRIDE = 2
-        BOT_MUL = 1.0
 
         # Compute the block widths. Each stage has one unique block width
         widths_cont = np.arange(self.depth) * self.w_a + self.w_0
```
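For context, get_expanded_params derives per-stage widths with the schedule from the RegNet paper: a linear ramp u_j = w_0 + w_a * j, quantized to powers of w_m and rounded to multiples of QUANT. A sketch using the regnet_z_500mf parameters added below; only the widths_cont line appears verbatim in this diff, and the quantization step follows arXiv:2003.13678:

```python
import numpy as np

# RegNet width schedule (arXiv:2003.13678), with regnet_z_500mf parameters.
depth, w_0, w_a, w_m, QUANT = 21, 16, 10.7, 2.51, 8
widths_cont = np.arange(depth) * w_a + w_0               # u_j = w_0 + w_a * j
ks = np.round(np.log(widths_cont / w_0) / np.log(w_m))   # quantization exponents
widths = (np.round(w_0 * np.power(w_m, ks) / QUANT) * QUANT).astype(int)
num_stages = len(np.unique(widths))                      # one unique width per stage
```

The next hunks change the tuple order returned by get_expanded_params to match, and expand the RegNet docstring and from_config: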
```diff
@@ -428,21 +465,26 @@ def get_expanded_params(self):
         stage_depths = np.diff([d for d, t in enumerate(splits) if t]).tolist()
 
         strides = [STRIDE] * num_stages
-        bot_muls = [BOT_MUL] * num_stages
+        bot_muls = [self.bot_mul] * num_stages
         group_widths = [self.group_w] * num_stages
 
         # Adjust the compatibility of stage widths and group widths
         stage_widths, group_widths = _adjust_widths_groups_compatibilty(
             stage_widths, bot_muls, group_widths
         )
 
-        return zip(stage_widths, strides, stage_depths, bot_muls, group_widths)
+        return zip(stage_widths, strides, stage_depths, group_widths, bot_muls)
 
 
 @register_model("regnet")
 class RegNet(ClassyModel):
-    """Implementation of RegNet, a particular form of AnyNets
-    See https://arxiv.org/abs/2003.13678v1"""
+    """Implementation of RegNet, a particular form of AnyNets.
+
+    See https://arxiv.org/abs/2003.13678 for introduction to RegNets, and details about
+    RegNetX and RegNetY models.
+
+    See https://arxiv.org/abs/2103.06877 for details about RegNetZ models.
+    """
 
     def __init__(self, params: RegNetParams):
         super().__init__()
@@ -474,14 +516,15 @@ def __init__(self, params: RegNetParams):
             BlockType.VANILLA_BLOCK: VanillaBlock,
             BlockType.RES_BASIC_BLOCK: ResBasicBlock,
             BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
+            BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
         }[params.block_type]
 
         current_width = params.stem_width
 
         self.trunk_depth = 0
 
         blocks = []
-        for i, (width_out, stride, depth, bot_mul, group_width) in enumerate(
+        for i, (width_out, stride, depth, group_width, bot_mul) in enumerate(
             params.get_expanded_params()
         ):
             blocks.append(
@@ -494,8 +537,8 @@ def __init__(self, params: RegNetParams):
                     depth,
                     block_fun,
                     activation,
-                    bot_mul,
                     group_width,
+                    bot_mul,
                     params,
                     stage_index=i + 1,
                 ),
@@ -529,10 +572,13 @@ def from_config(cls, config: Dict[str, Any]) -> "RegNet":
             w_a=config["w_a"],
             w_m=config["w_m"],
             group_w=config["group_width"],
-            stem_type=config.get("stem_type", "simple_stem_in").upper(),
+            bot_mul=config.get("bot_mul", 1.0),
+            stem_type=StemType[config.get("stem_type", "simple_stem_in").upper()],
             stem_width=config.get("stem_width", 32),
-            block_type=config.get("block_type", "res_bottleneck_block").upper(),
-            activation_type=config.get("activation_type", "relu").upper(),
+            block_type=BlockType[
+                config.get("block_type", "res_bottleneck_block").upper()
+            ],
+            activation=ActivationType[config.get("activation", "relu").upper()],
             use_se=config.get("use_se", True),
             se_ratio=config.get("se_ratio", 0.25),
             bn_epsilon=config.get("bn_epsilon", 1e-05),
```
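With these changes a config can select the new block type, bottleneck multiplier, and activation by name. For example, using the preset values this commit adds to the tests (build_model dispatches to RegNet.from_config through the "regnet" registry entry):

```python
from classy_vision.models import build_model

# Config keys follow from_config above; values match the RegNetZ test preset
# added in this commit.
config = {
    "name": "regnet",
    "depth": 21,
    "w_0": 16,
    "w_a": 10.7,
    "w_m": 2.51,
    "group_width": 4,
    "bot_mul": 4.0,
    "block_type": "res_bottleneck_linear_block",
    "activation": "silu",
}
model = build_model(config)
```

The final regnet.py hunk registers the two RegNetZ presets: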
```diff
@@ -751,6 +797,45 @@ def __init__(self, **kwargs):
         )
 
 
+# note that RegNetZ models are trained with a convolutional head, i.e. the
+# fully_connected ClassyHead with conv_planes > 0.
+@register_model("regnet_z_500mf")
+class RegNetZ500mf(_RegNet):
+    def __init__(self, **kwargs):
+        super().__init__(
+            RegNetParams(
+                depth=21,
+                w_0=16,
+                w_a=10.7,
+                w_m=2.51,
+                group_w=4,
+                bot_mul=4.0,
+                block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
+                activation=ActivationType.SILU,
+                **kwargs,
+            )
+        )
+
+
+# this is supposed to be trained with a resolution of 256x256
+@register_model("regnet_z_4gf")
+class RegNetZ4gf(_RegNet):
+    def __init__(self, **kwargs):
+        super().__init__(
+            RegNetParams(
+                depth=28,
+                w_0=48,
+                w_a=14.5,
+                w_m=2.226,
+                group_w=8,
+                bot_mul=4.0,
+                block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
+                activation=ActivationType.SILU,
+                **kwargs,
+            )
+        )
+
+
 # -----------------------------------------------------------------------------------
 # The following models were not part of the original publication,
 # (https://arxiv.org/abs/2003.13678v1), but are larger versions of the
```
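Once registered, the presets can be built directly through the model registry. A minimal sketch; the forward pass returns trunk features, and the convolutional fully_connected head mentioned in the comment above is attached separately via a ClassyHead config whose exact keys are not part of this diff:

```python
import torch
from classy_vision.models import build_model

# Build the new preset by its registered name.
model = build_model({"name": "regnet_z_500mf"})
features = model(torch.randn(1, 3, 224, 224))  # trunk output, pre-head
```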

test/models_regnet_test.py (+15)

```diff
@@ -107,6 +107,20 @@
             "group_width": 56,
         },
     ),
+    (
+        {
+            # RegNetZ
+            "name": "regnet",
+            "block_type": "res_bottleneck_linear_block",
+            "depth": 21,
+            "w_0": 16,
+            "w_a": 10.7,
+            "w_m": 2.51,
+            "group_width": 4,
+            "bot_mul": 4.0,
+            "activation": "silu",
+        },
+    ),
 ]
 
 
@@ -128,6 +142,7 @@
     "regnet_x_8gf",
     "regnet_x_16gf",
     "regnet_x_32gf",
+    "regnet_z_500mf",
 ]
 
 REGNET_TEST_PRESETS = [({"name": n},) for n in REGNET_TEST_PRESET_NAMES]
```
