Skip to content

Commit 1e7786a

Browse files
committed
address extraneous kernel-mod dimension for the RGB adaptive conv; thanks to @inspirit again (#45)
1 parent 3ce4c52 commit 1e7786a

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

gigagan_pytorch/gigagan_pytorch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,13 @@ def forward(
361361
if mod.shape[0] != b:
362362
mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])
363363

364-
if kernel_mod.shape[0] != b:
365-
kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
364+
if exists(kernel_mod):
365+
kernel_mod_has_el = kernel_mod.numel() > 0
366+
367+
assert self.adaptive or not kernel_mod_has_el
368+
369+
if kernel_mod_has_el and kernel_mod.shape[0] != b:
370+
kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
366371

367372
# prepare weights for modulation
368373

@@ -373,7 +378,7 @@ def forward(
373378

374379
# determine an adaptive weight and 'select' the kernel to use with softmax
375380

376-
assert exists(kernel_mod)
381+
assert exists(kernel_mod) and kernel_mod.numel() > 0
377382

378383
kernel_attn = kernel_mod.softmax(dim = -1)
379384
kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')
@@ -996,7 +1001,7 @@ def __init__(
9961001
dim_out, # second conv in resnet block
9971002
dim_kernel_mod, # second conv kernel selection
9981003
dim_out, # to RGB conv
999-
dim_kernel_mod, # RGB conv kernel selection
1004+
0, # RGB conv kernel selection
10001005
])
10011006

10021007
self.layers.append(nn.ModuleList([

gigagan_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.2.18'
1+
__version__ = '0.2.19'

0 commit comments

Comments (0)