Commit 60e261f

prepare for vision aided discriminator, make sure it can be saved and loaded, including optimizer and grad scaler
1 parent 11659dc commit 60e261f

2 files changed, +50 -10 lines changed

gigagan_pytorch/gigagan_pytorch.py

Lines changed: 49 additions & 9 deletions
@@ -952,7 +952,7 @@ def init_(self, m):
 
     @property
     def total_params(self):
-        return sum([p.numel() for p in self.parameters()])
+        return sum([p.numel() for p in self.parameters() if p.requires_grad])
 
     @property
     def device(self):
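
Side note on the change above: total_params now counts only trainable parameters, which matters once a frozen CLIP backbone lives inside the module tree. A standalone toy illustration of the difference (not from the repository):

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

    # freeze the first layer, analogous to the frozen CLIP backbone
    for p in net[0].parameters():
        p.requires_grad = False

    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    total = sum(p.numel() for p in net.parameters())
    print(trainable, total)  # prints: 10 30
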
@@ -1163,17 +1163,23 @@ class VisionAidedDiscriminator(nn.Module):
     def __init__(
         self,
         *,
-        clip: OpenClipAdapter,
         depth = 2,
         dim_head = 64,
         heads = 8,
+        clip: Optional[OpenClipAdapter] = None,
         layer_indices = (-1, -2, -3),
         conv_dim = None,
         text_dim = None,
         unconditional = False,
         num_conv_kernels = 2
     ):
         super().__init__()
+
+        if not exists(clip):
+            clip = OpenClipAdapter()
+
+        set_requires_grad_(clip, False)
+
         self.clip = clip
         dim = clip._dim_image_latent

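The constructor now builds a default OpenClipAdapter when none is passed, and freezes it so only the discriminator head trains. set_requires_grad_ is the repository's freezing helper; a minimal sketch of what it presumably does (assumed from the call site above, not copied from the repo):

    import torch.nn as nn

    def set_requires_grad_(module: nn.Module, requires_grad: bool):
        # toggle gradient tracking for every parameter in the module
        for param in module.parameters():
            param.requires_grad = requires_grad
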
@@ -1198,11 +1204,9 @@ def __init__(
                 )
             ]))
 
-    def parameters(self):
-        return [
-            *self.network.parameters(),
-            *self.to_pred.parameters()
-        ]
+    @property
+    def total_params(self):
+        return sum([p.numel() for p in self.parameters() if p.requires_grad])
 
     @beartype
     def forward(
@@ -1666,6 +1670,7 @@ def __init__(
         *,
         generator: Union[BaseGenerator, Dict],
         discriminator: Union[Discriminator, Dict],
+        vision_aided_discriminator: Optional[Union[VisionAidedDiscriminator, Dict]] = None,
         learning_rate = 2e-4,
         betas = (0.5, 0.9),
         weight_decay = 0.,
@@ -1730,12 +1735,16 @@ def __init__(
         if isinstance(discriminator, dict):
             discriminator = Discriminator(**discriminator)
 
+        if exists(vision_aided_discriminator) and isinstance(vision_aided_discriminator, dict):
+            vision_aided_discriminator = VisionAidedDiscriminator(**vision_aided_discriminator)
+
         assert isinstance(generator, generator_klass)
 
         # use _base to designate unwrapped models
 
         self.G = generator
         self.D = discriminator
+        self.VD = vision_aided_discriminator
 
         # ema

@@ -1746,8 +1755,13 @@ def __init__(
 
         # print number of parameters
 
-        self.print(f'Generator parameters: {numerize.numerize(generator.total_params)}')
-        self.print(f'Discriminator parameters: {numerize.numerize(discriminator.total_params)}')
+        self.print(f'Generator: {numerize.numerize(generator.total_params)}')
+        self.print(f'Discriminator: {numerize.numerize(discriminator.total_params)}')
+
+        if exists(self.VD):
+            self.print(f'Vision Discriminator: {numerize.numerize(vision_aided_discriminator.total_params)}')
+
+        self.print('\n')
 
         # text encoder

@@ -1764,6 +1778,12 @@ def __init__(
 
         self.G, self.D, self.G_opt, self.D_opt = self.accelerator.prepare(self.G, self.D, self.G_opt, self.D_opt)
 
+        # vision aided discriminator optimizer
+
+        if exists(self.VD):
+            self.VD_opt = get_optimizer(self.VD.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
+            self.VD_opt = self.accelerator.prepare(self.VD_opt)
+
         # loss related
 
         self.discr_aux_recon_loss_weight = discr_aux_recon_loss_weight
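
Since the custom parameters() override was removed above, self.VD.parameters() now also yields the frozen CLIP weights. get_optimizer presumably drops those before handing parameters to the optimizer; a hedged sketch of that pattern (the real helper lives elsewhere in the repo and may differ):

    from torch.optim import Adam, AdamW

    def get_optimizer(params, lr = 2e-4, betas = (0.5, 0.9), weight_decay = 0.):
        # keep only parameters that require gradients, so the frozen
        # CLIP backbone is excluded from optimization
        params = [p for p in params if p.requires_grad]

        if weight_decay > 0:
            return AdamW(params, lr = lr, betas = betas, weight_decay = weight_decay)

        return Adam(params, lr = lr, betas = betas)
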
@@ -1816,6 +1836,13 @@ def save(self, path, overwrite = True):
         if exists(self.D_opt.scaler):
             pkg['D_scaler'] = self.D_opt.scaler.state_dict()
 
+        if exists(self.VD):
+            pkg['VD'] = self.unwrapped_VD.state_dict()
+            pkg['VD_opt'] = self.VD_opt.state_dict()
+
+            if exists(self.VD_opt.scaler):
+                pkg['VD_scaler'] = self.VD_opt.scaler.state_dict()
+
         if self.has_ema_generator:
             pkg['G_ema'] = self.G_ema.state_dict()

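With this change the checkpoint carries the vision aided discriminator alongside everything else. Assuming the pkg dict is written out with torch.save (an assumption; only the key names are confirmed by the diff), it can be inspected like so:

    import torch

    pkg = torch.load('./checkpoint.pt', map_location = 'cpu')  # hypothetical path

    # keys touched by this commit (presence depends on configuration):
    # 'VD', 'VD_opt', 'VD_scaler', next to the existing
    # 'G', 'D', 'G_opt', 'D_opt', 'G_scaler', 'D_scaler', 'G_ema'
    print(sorted(pkg.keys()))
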
@@ -1833,6 +1860,9 @@ def load(self, path, strict = False):
         self.unwrapped_G.load_state_dict(pkg['G'], strict = strict)
         self.unwrapped_D.load_state_dict(pkg['D'], strict = strict)
 
+        if exists(self.VD):
+            self.unwrapped_VD.load_state_dict(pkg['VD'], strict = strict)
+
         if self.has_ema_generator:
             self.G_ema.load_state_dict(pkg['G_ema'])

@@ -1846,12 +1876,18 @@ def load(self, path, strict = False):
             self.G_opt.load_state_dict(pkg['G_opt'])
             self.D_opt.load_state_dict(pkg['D_opt'])
 
+            if exists(self.VD):
+                self.VD_opt.load_state_dict(pkg['VD_opt'])
+
             if 'G_scaler' in pkg and exists(self.G_opt.scaler):
                 self.G_opt.scaler.load_state_dict(pkg['G_scaler'])
 
             if 'D_scaler' in pkg and exists(self.D_opt.scaler):
                 self.D_opt.scaler.load_state_dict(pkg['D_scaler'])
 
+            if 'VD_scaler' in pkg and exists(self.VD_opt.scaler):
+                self.VD_opt.scaler.load_state_dict(pkg['VD_scaler'])
+
         except Exception as e:
             self.print(f'unable to load optimizers {e.msg}- optimizer states will be reset')
             pass
@@ -1870,6 +1906,10 @@ def unwrapped_G(self):
     def unwrapped_D(self):
         return self.accelerator.unwrap_model(self.D)
 
+    @property
+    def unwrapped_VD(self):
+        return self.accelerator.unwrap_model(self.VD)
+
     def print(self, msg):
         self.accelerator.print(msg)

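Taken together: the clip argument is now optional, and the trainer accepts either a constructed module or a config dict under the new vision_aided_discriminator keyword. A minimal usage sketch (constructor kwargs taken from the diff above; everything else about the setup is assumed):

    from gigagan_pytorch.gigagan_pytorch import VisionAidedDiscriminator

    # clip omitted, so a default OpenClipAdapter is created and frozen internally
    vd = VisionAidedDiscriminator(depth = 2, heads = 8)

    # counts only trainable parameters, excluding the frozen CLIP backbone
    print(vd.total_params)

Once attached to the trainer, save / load will round-trip the module together with VD_opt and its grad scaler, as the hunks above show.
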
gigagan_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.1.11'
+__version__ = '0.1.12'
