We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2868cdb commit 154da30Copy full SHA for 154da30
gigagan_pytorch/gigagan_pytorch.py
@@ -2135,10 +2135,14 @@ def train_discriminator_step(
2135
all_real_images = []
2136
2137
self.G.train()
2138
- self.D.train()
2139
+ self.D.train()
2140
self.D_opt.zero_grad()
2141
2142
+ if self.need_vision_aided_discriminator:
2143
+ self.VD.train()
2144
+ self.VD_opt.zero_grad()
2145
+
2146
for _ in range(grad_accum_every):
2147
2148
if self.unconditional:
@@ -2344,6 +2348,9 @@ def train_discriminator_step(
2344
2348
2345
2349
self.D_opt.step()
2346
2350
2351
2352
+ self.VD_opt.step()
2353
2347
2354
return TrainDiscrLosses(
2355
total_divergence,
2356
total_multiscale_divergence,
gigagan_pytorch/version.py
@@ -1 +1 @@
1
-__version__ = '0.2.12'
+__version__ = '0.2.14'
0 commit comments