@@ -25,6 +25,7 @@ class BlockType(Enum):
25
25
VANILLA_BLOCK = auto ()
26
26
RES_BASIC_BLOCK = auto ()
27
27
RES_BOTTLENECK_BLOCK = auto ()
28
+ RES_BOTTLENECK_LINEAR_BLOCK = auto ()
28
29
29
30
30
31
# The different possible Stems
@@ -206,8 +207,8 @@ def __init__(
206
207
bn_epsilon : float ,
207
208
bn_momentum : float ,
208
209
activation : nn .Module ,
209
- bot_mul : float ,
210
210
group_width : int ,
211
+ bot_mul : float ,
211
212
se_ratio : Optional [float ],
212
213
):
213
214
super ().__init__ ()
@@ -253,8 +254,8 @@ def __init__(
253
254
bn_epsilon : float ,
254
255
bn_momentum : float ,
255
256
activation : nn .Module ,
256
- bot_mul : float = 1.0 ,
257
257
group_width : int = 1 ,
258
+ bot_mul : float = 1.0 ,
258
259
se_ratio : Optional [float ] = None ,
259
260
):
260
261
super ().__init__ ()
@@ -273,8 +274,8 @@ def __init__(
273
274
bn_epsilon ,
274
275
bn_momentum ,
275
276
activation ,
276
- bot_mul ,
277
277
group_width ,
278
+ bot_mul ,
278
279
se_ratio ,
279
280
)
280
281
self .activation = activation
@@ -291,6 +292,41 @@ def forward(self, x, *args):
291
292
return self .activation (x )
292
293
293
294
295
class ResBottleneckLinearBlock(nn.Module):
    """Residual linear bottleneck block: x + F(x), F = bottleneck transform.

    "Linear" means no activation is applied after the residual addition
    (unlike ResBottleneckBlock, which activates the sum).
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        bn_epsilon: float,
        bn_momentum: float,
        activation: nn.Module,
        group_width: int = 1,
        bot_mul: float = 4.0,
        se_ratio: Optional[float] = None,
    ):
        super().__init__()
        # The identity shortcut is only valid when the transform preserves
        # both the spatial resolution (stride 1) and the channel count.
        self.has_skip = (width_in == width_out) and (stride == 1)
        self.f = BottleneckTransform(
            width_in,
            width_out,
            stride,
            bn_epsilon,
            bn_momentum,
            activation,
            group_width,
            bot_mul,
            se_ratio,
        )
        # Expose the transform's depth so stage/trunk depth bookkeeping works.
        self.depth = self.f.depth

    def forward(self, x):
        # No post-addition activation: return the raw sum (or F(x) alone
        # when the shapes do not allow a skip connection).
        if self.has_skip:
            return x + self.f(x)
        return self.f(x)
329
+
294
330
class AnyStage (nn .Sequential ):
295
331
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
296
332
@@ -302,8 +338,8 @@ def __init__(
302
338
depth : int ,
303
339
block_constructor : nn .Module ,
304
340
activation : nn .Module ,
305
- bot_mul : float ,
306
341
group_width : int ,
342
+ bot_mul : float ,
307
343
params : "RegNetParams" ,
308
344
stage_index : int = 0 ,
309
345
):
@@ -318,8 +354,8 @@ def __init__(
318
354
params .bn_epsilon ,
319
355
params .bn_momentum ,
320
356
activation ,
321
- bot_mul ,
322
357
group_width ,
358
+ bot_mul ,
323
359
params .se_ratio ,
324
360
)
325
361
@@ -354,10 +390,11 @@ def __init__(
354
390
w_a : float ,
355
391
w_m : float ,
356
392
group_w : int ,
357
- stem_type : StemType = "SIMPLE_STEM_IN" ,
393
+ bot_mul : float = 1.0 ,
394
+ stem_type : StemType = StemType .SIMPLE_STEM_IN ,
358
395
stem_width : int = 32 ,
359
- block_type : BlockType = " RES_BOTTLENECK_BLOCK" ,
360
- activation_type : ActivationType = " RELU" ,
396
+ block_type : BlockType = BlockType . RES_BOTTLENECK_BLOCK ,
397
+ activation : ActivationType = ActivationType . RELU ,
361
398
use_se : bool = True ,
362
399
se_ratio : float = 0.25 ,
363
400
bn_epsilon : float = 1e-05 ,
@@ -371,9 +408,10 @@ def __init__(
371
408
self .w_a = w_a
372
409
self .w_m = w_m
373
410
self .group_w = group_w
374
- self .stem_type = StemType [stem_type ]
375
- self .block_type = BlockType [block_type ]
376
- self .activation_type = ActivationType [activation_type ]
411
+ self .bot_mul = bot_mul
412
+ self .stem_type = stem_type
413
+ self .block_type = block_type
414
+ self .activation = activation
377
415
self .stem_width = stem_width
378
416
self .use_se = use_se
379
417
self .se_ratio = se_ratio if use_se else None
@@ -403,7 +441,6 @@ def get_expanded_params(self):
403
441
404
442
QUANT = 8
405
443
STRIDE = 2
406
- BOT_MUL = 1.0
407
444
408
445
# Compute the block widths. Each stage has one unique block width
409
446
widths_cont = np .arange (self .depth ) * self .w_a + self .w_0
@@ -428,21 +465,26 @@ def get_expanded_params(self):
428
465
stage_depths = np .diff ([d for d , t in enumerate (splits ) if t ]).tolist ()
429
466
430
467
strides = [STRIDE ] * num_stages
431
- bot_muls = [BOT_MUL ] * num_stages
468
+ bot_muls = [self . bot_mul ] * num_stages
432
469
group_widths = [self .group_w ] * num_stages
433
470
434
471
# Adjust the compatibility of stage widths and group widths
435
472
stage_widths , group_widths = _adjust_widths_groups_compatibilty (
436
473
stage_widths , bot_muls , group_widths
437
474
)
438
475
439
- return zip (stage_widths , strides , stage_depths , bot_muls , group_widths )
476
+ return zip (stage_widths , strides , stage_depths , group_widths , bot_muls )
440
477
441
478
442
479
@register_model ("regnet" )
443
480
class RegNet (ClassyModel ):
444
- """Implementation of RegNet, a particular form of AnyNets
445
- See https://arxiv.org/abs/2003.13678v1"""
481
+ """Implementation of RegNet, a particular form of AnyNets.
482
+
483
+ See https://arxiv.org/abs/2003.13678 for introduction to RegNets, and details about
484
+ RegNetX and RegNetY models.
485
+
486
+ See https://arxiv.org/abs/2103.06877 for details about RegNetZ models.
487
+ """
446
488
447
489
def __init__ (self , params : RegNetParams ):
448
490
super ().__init__ ()
@@ -474,14 +516,15 @@ def __init__(self, params: RegNetParams):
474
516
BlockType .VANILLA_BLOCK : VanillaBlock ,
475
517
BlockType .RES_BASIC_BLOCK : ResBasicBlock ,
476
518
BlockType .RES_BOTTLENECK_BLOCK : ResBottleneckBlock ,
519
+ BlockType .RES_BOTTLENECK_LINEAR_BLOCK : ResBottleneckLinearBlock ,
477
520
}[params .block_type ]
478
521
479
522
current_width = params .stem_width
480
523
481
524
self .trunk_depth = 0
482
525
483
526
blocks = []
484
- for i , (width_out , stride , depth , bot_mul , group_width ) in enumerate (
527
+ for i , (width_out , stride , depth , group_width , bot_mul ) in enumerate (
485
528
params .get_expanded_params ()
486
529
):
487
530
blocks .append (
@@ -494,8 +537,8 @@ def __init__(self, params: RegNetParams):
494
537
depth ,
495
538
block_fun ,
496
539
activation ,
497
- bot_mul ,
498
540
group_width ,
541
+ bot_mul ,
499
542
params ,
500
543
stage_index = i + 1 ,
501
544
),
@@ -529,10 +572,13 @@ def from_config(cls, config: Dict[str, Any]) -> "RegNet":
529
572
w_a = config ["w_a" ],
530
573
w_m = config ["w_m" ],
531
574
group_w = config ["group_width" ],
532
- stem_type = config .get ("stem_type" , "simple_stem_in" ).upper (),
575
+ bot_mul = config .get ("bot_mul" , 1.0 ),
576
+ stem_type = StemType [config .get ("stem_type" , "simple_stem_in" ).upper ()],
533
577
stem_width = config .get ("stem_width" , 32 ),
534
- block_type = config .get ("block_type" , "res_bottleneck_block" ).upper (),
535
- activation_type = config .get ("activation_type" , "relu" ).upper (),
578
+ block_type = BlockType [
579
+ config .get ("block_type" , "res_bottleneck_block" ).upper ()
580
+ ],
581
+ activation = ActivationType [config .get ("activation" , "relu" ).upper ()],
536
582
use_se = config .get ("use_se" , True ),
537
583
se_ratio = config .get ("se_ratio" , 0.25 ),
538
584
bn_epsilon = config .get ("bn_epsilon" , 1e-05 ),
@@ -751,6 +797,45 @@ def __init__(self, **kwargs):
751
797
)
752
798
753
799
800
# note that RegNetZ models are trained with a convolutional head, i.e. the
# fully_connected ClassyHead with conv_planes > 0.
@register_model("regnet_z_500mf")
class RegNetZ500mf(_RegNet):
    """RegNetZ-500MF: linear-bottleneck RegNet variant using SiLU activation."""

    def __init__(self, **kwargs):
        # Fixed architecture hyperparameters for the 500MF-flop RegNetZ model;
        # callers may still override the remaining RegNetParams via kwargs.
        params = RegNetParams(
            depth=21,
            w_0=16,
            w_a=10.7,
            w_m=2.51,
            group_w=4,
            bot_mul=4.0,
            block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
            activation=ActivationType.SILU,
            **kwargs,
        )
        super().__init__(params)
819
+
820
# this is supposed to be trained with a resolution of 256x256
@register_model("regnet_z_4gf")
class RegNetZ4gf(_RegNet):
    """RegNetZ-4GF: linear-bottleneck RegNet variant using SiLU activation."""

    def __init__(self, **kwargs):
        # Fixed architecture hyperparameters for the 4GF-flop RegNetZ model;
        # callers may still override the remaining RegNetParams via kwargs.
        params = RegNetParams(
            depth=28,
            w_0=48,
            w_a=14.5,
            w_m=2.226,
            group_w=8,
            bot_mul=4.0,
            block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
            activation=ActivationType.SILU,
            **kwargs,
        )
        super().__init__(params)
838
+
754
839
# -----------------------------------------------------------------------------------
755
840
# The following models were not part of the original publication,
756
841
# (https://arxiv.org/abs/2003.13678v1), but are larger versions of the
0 commit comments