@@ -361,8 +361,13 @@ def forward(
         if mod.shape[0] != b:
             mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])

-        if kernel_mod.shape[0] != b:
-            kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
+        if exists(kernel_mod):
+            kernel_mod_has_el = kernel_mod.numel() > 0
+
+            assert self.adaptive or not kernel_mod_has_el
+
+            if kernel_mod_has_el and kernel_mod.shape[0] != b:
+                kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])

         # prepare weights for modulation

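The hunk above relaxes the kernel-modulation handling: `kernel_mod` may now be absent or an empty tensor rather than always populated, and the batch tiling only runs when it actually carries elements. A minimal sketch of the new guard, assuming only `torch` and `einops`; `maybe_tile_kernel_mod` is a hypothetical standalone helper, not a function from the repo:

import torch
from einops import repeat

def maybe_tile_kernel_mod(kernel_mod, b, adaptive):
    # kernel_mod may be None, or an empty tensor signalling 'no kernel selection'
    if kernel_mod is None:
        return None
    has_el = kernel_mod.numel() > 0
    # kernel selection logits only make sense when the conv is adaptive
    assert adaptive or not has_el
    if has_el and kernel_mod.shape[0] != b:
        # tile the (smaller) batch dimension up to the full batch size b
        kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
    return kernel_mod

print(maybe_tile_kernel_mod(torch.randn(2, 4), b = 6, adaptive = True).shape)  # torch.Size([6, 4])
print(maybe_tile_kernel_mod(torch.empty(0), b = 6, adaptive = False).numel())  # 0, passes through untouched

An empty tensor passes through untouched, which is what lets a layer opt out of kernel selection entirely (see the last hunk below).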
@@ -373,7 +378,7 @@ def forward(

         # determine an adaptive weight and 'select' the kernel to use with softmax

-        assert exists(kernel_mod)
+        assert exists(kernel_mod) and kernel_mod.numel() > 0

         kernel_attn = kernel_mod.softmax(dim = -1)
         kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')
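The strengthened assert guards the softmax 'selection' step that follows it: now that empty tensors are a legal way to disable kernel selection, the adaptive path must check for elements, not just existence. A sketch of that step in isolation, with hypothetical shapes (batch of 2, a bank of 4 candidate kernels):

import torch
from einops import rearrange

kernel_mod = torch.randn(2, 4)                    # per-sample logits over a bank of 4 kernels
assert kernel_mod.numel() > 0                     # an empty tensor would make the softmax meaningless

kernel_attn = kernel_mod.softmax(dim = -1)        # soft 'selection' over the kernel bank
kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')

weights = torch.randn(4, 8, 3, 3, 3)              # hypothetical (kernels, out, in, kh, kw) bank
selected = (weights * kernel_attn).sum(dim = 1)   # blended per-sample kernel
print(selected.shape)                             # torch.Size([2, 8, 3, 3, 3])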
@@ -996,7 +1001,7 @@ def __init__(
             dim_out,         # second conv in resnet block
             dim_kernel_mod,  # second conv kernel selection
             dim_out,         # to RGB conv
-            dim_kernel_mod,  # RGB conv kernel selection
+            0,               # RGB conv kernel selection
         ])

         self.layers.append(nn.ModuleList([
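Passing `0` instead of `dim_kernel_mod` means the to-RGB conv requests zero kernel-selection features, so the slice carved out for it downstream is an empty tensor, exactly the `numel() == 0` case the first hunk now tolerates. A sketch of that mechanism, assuming the per-layer conditioning vector is split by this list of dims (all sizes hypothetical):

import torch

dims = [8, 4, 8, 0]                      # to-RGB kernel selection now asks for 0 features
mod = torch.randn(2, sum(dims))          # batch of pooled conditioning
chunks = mod.split(dims, dim = -1)

print([tuple(c.shape) for c in chunks])  # [(2, 8), (2, 4), (2, 8), (2, 0)] - last chunk has numel() == 0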