@@ -952,7 +952,7 @@ def init_(self, m):
952
952
953
953
@property
def total_params(self):
    """Number of trainable parameter elements in this module.

    Sums ``numel()`` over every tensor yielded by ``self.parameters()``
    whose ``requires_grad`` flag is truthy, so parameters with gradients
    switched off do not contribute to the count.
    """
    # generator expression: no intermediate list is needed just to sum
    return sum(p.numel() for p in self.parameters() if p.requires_grad)
956
956
957
957
@property
958
958
def device (self ):
@@ -1163,17 +1163,23 @@ class VisionAidedDiscriminator(nn.Module):
1163
1163
def __init__ (
1164
1164
self ,
1165
1165
* ,
1166
- clip : OpenClipAdapter ,
1167
1166
depth = 2 ,
1168
1167
dim_head = 64 ,
1169
1168
heads = 8 ,
1169
+ clip : Optional [OpenClipAdapter ] = None ,
1170
1170
layer_indices = (- 1 , - 2 , - 3 ),
1171
1171
conv_dim = None ,
1172
1172
text_dim = None ,
1173
1173
unconditional = False ,
1174
1174
num_conv_kernels = 2
1175
1175
):
1176
1176
super ().__init__ ()
1177
+
1178
+ if not exists (clip ):
1179
+ clip = OpenClipAdapter ()
1180
+
1181
+ set_requires_grad_ (clip , False )
1182
+
1177
1183
self .clip = clip
1178
1184
dim = clip ._dim_image_latent
1179
1185
@@ -1198,11 +1204,9 @@ def __init__(
1198
1204
)
1199
1205
]))
1200
1206
1201
- def parameters (self ):
1202
- return [
1203
- * self .network .parameters (),
1204
- * self .to_pred .parameters ()
1205
- ]
1207
@property
def total_params(self):
    """Number of trainable parameter elements in this discriminator.

    Only tensors from ``self.parameters()`` with ``requires_grad`` set are
    counted; anything whose gradients were disabled is left out
    (presumably this covers the CLIP adapter passed through
    ``set_requires_grad_(clip, False)`` in ``__init__`` - confirm).
    """
    # generator expression: no intermediate list is needed just to sum
    return sum(p.numel() for p in self.parameters() if p.requires_grad)
1206
1210
1207
1211
@beartype
1208
1212
def forward (
@@ -1666,6 +1670,7 @@ def __init__(
1666
1670
* ,
1667
1671
generator : Union [BaseGenerator , Dict ],
1668
1672
discriminator : Union [Discriminator , Dict ],
1673
+ vision_aided_discriminator : Optional [Union [VisionAidedDiscriminator , Dict ]] = None ,
1669
1674
learning_rate = 2e-4 ,
1670
1675
betas = (0.5 , 0.9 ),
1671
1676
weight_decay = 0. ,
@@ -1730,12 +1735,16 @@ def __init__(
1730
1735
if isinstance (discriminator , dict ):
1731
1736
discriminator = Discriminator (** discriminator )
1732
1737
1738
+ if exists (vision_aided_discriminator ) and isinstance (vision_aided_discriminator , dict ):
1739
+ vision_aided_discriminator = VisionAidedDiscriminator (** vision_aided_discriminator )
1740
+
1733
1741
assert isinstance (generator , generator_klass )
1734
1742
1735
1743
# use _base to designate unwrapped models
1736
1744
1737
1745
self .G = generator
1738
1746
self .D = discriminator
1747
+ self .VD = vision_aided_discriminator
1739
1748
1740
1749
# ema
1741
1750
@@ -1746,8 +1755,13 @@ def __init__(
1746
1755
1747
1756
# print number of parameters
1748
1757
1749
- self .print (f'Generator parameters: { numerize .numerize (generator .total_params )} ' )
1750
- self .print (f'Discriminator parameters: { numerize .numerize (discriminator .total_params )} ' )
1758
+ self .print (f'Generator: { numerize .numerize (generator .total_params )} ' )
1759
+ self .print (f'Discriminator: { numerize .numerize (discriminator .total_params )} ' )
1760
+
1761
+ if exists (self .VD ):
1762
+ self .print (f'Vision Discriminator: { numerize .numerize (vision_aided_discriminator .total_params )} ' )
1763
+
1764
+ self .print ('\n ' )
1751
1765
1752
1766
# text encoder
1753
1767
@@ -1764,6 +1778,12 @@ def __init__(
1764
1778
1765
1779
self .G , self .D , self .G_opt , self .D_opt = self .accelerator .prepare (self .G , self .D , self .G_opt , self .D_opt )
1766
1780
1781
+ # vision aided discriminator optimizer
1782
+
1783
+ if exists (self .VD ):
1784
+ self .VD_opt = get_optimizer (self .VD .parameters (), lr = learning_rate , betas = betas , weight_decay = weight_decay )
1785
+ self .VD_opt = self .accelerator .prepare (self .VD_opt )
1786
+
1767
1787
# loss related
1768
1788
1769
1789
self .discr_aux_recon_loss_weight = discr_aux_recon_loss_weight
@@ -1816,6 +1836,13 @@ def save(self, path, overwrite = True):
1816
1836
if exists (self .D_opt .scaler ):
1817
1837
pkg ['D_scaler' ] = self .D_opt .scaler .state_dict ()
1818
1838
1839
+ if exists (self .VD ):
1840
+ pkg ['VD' ] = self .unwrapped_VD .state_dict ()
1841
+ pkg ['VD_opt' ] = self .VD_opt .state_dict ()
1842
+
1843
+ if exists (self .VD_opt .scaler ):
1844
+ pkg ['VD_scaler' ] = self .VD_opt .scaler .state_dict ()
1845
+
1819
1846
if self .has_ema_generator :
1820
1847
pkg ['G_ema' ] = self .G_ema .state_dict ()
1821
1848
@@ -1833,6 +1860,9 @@ def load(self, path, strict = False):
1833
1860
self .unwrapped_G .load_state_dict (pkg ['G' ], strict = strict )
1834
1861
self .unwrapped_D .load_state_dict (pkg ['D' ], strict = strict )
1835
1862
1863
+ if exists (self .VD ):
1864
+ self .unwrapped_VD .load_state_dict (pkg ['VD' ], strict = strict )
1865
+
1836
1866
if self .has_ema_generator :
1837
1867
self .G_ema .load_state_dict (pkg ['G_ema' ])
1838
1868
@@ -1846,12 +1876,18 @@ def load(self, path, strict = False):
1846
1876
self .G_opt .load_state_dict (pkg ['G_opt' ])
1847
1877
self .D_opt .load_state_dict (pkg ['D_opt' ])
1848
1878
1879
+ if exists (self .VD ):
1880
+ self .VD_opt .load_state_dict (pkg ['VD_opt' ])
1881
+
1849
1882
if 'G_scaler' in pkg and exists (self .G_opt .scaler ):
1850
1883
self .G_opt .scaler .load_state_dict (pkg ['G_scaler' ])
1851
1884
1852
1885
if 'D_scaler' in pkg and exists (self .D_opt .scaler ):
1853
1886
self .D_opt .scaler .load_state_dict (pkg ['D_scaler' ])
1854
1887
1888
+ if 'VD_scaler' in pkg and exists (self .VD_opt .scaler ):
1889
+ self .VD_opt .scaler .load_state_dict (pkg ['VD_scaler' ])
1890
+
1855
1891
except Exception as e :
1856
1892
self .print (f'unable to load optimizers { e .msg } - optimizer states will be reset' )
1857
1893
pass
@@ -1870,6 +1906,10 @@ def unwrapped_G(self):
1870
1906
def unwrapped_D(self):
    """Return ``self.D`` after passing it through ``self.accelerator.unwrap_model``."""
    return self.accelerator.unwrap_model(self.D)
1872
1908
1909
@property
def unwrapped_VD(self):
    """Return ``self.VD`` after passing it through ``self.accelerator.unwrap_model``.

    ``self.VD`` is forwarded as-is; this accessor does not check whether a
    vision-aided discriminator was actually configured.
    """
    return self.accelerator.unwrap_model(self.VD)
1912
+
1873
1913
def print(self, msg):
    """Forward ``msg`` to ``self.accelerator.print``.

    NOTE(review): presumably used so output is emitted once per run rather
    than once per process - confirm against the Accelerate documentation.
    """
    self.accelerator.print(msg)
1875
1915
0 commit comments