
Commit 0806433

incorporate zero-centered gradient penalty on real and fake images for stability, per the paper out of Cornell and Brown
1 parent 55b621e commit 0806433
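
For reference, the regularizer this commit adopts is the zero-centered gradient penalty, i.e. the R1/R2 terms from Mescheder et al. that the cited baseline paper applies to both real and generated samples. A loose sketch of the two terms (the repository folds the γ/2 factor into a single `weight` argument, so the constants here are only indicative):

```latex
% hedged sketch of the zero-centered penalties; gamma stands in for the penalty weight
R_1(\psi) = \tfrac{\gamma}{2}\, \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\, \lVert \nabla_x D_\psi(x) \rVert^2 \,\big]
R_2(\psi) = \tfrac{\gamma}{2}\, \mathbb{E}_{\hat{x} \sim p_{G}}\big[\, \lVert \nabla_{\hat{x}} D_\psi(\hat{x}) \rVert^2 \,\big]
```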

File tree: 3 files changed (+24 / -3 lines)

README.md

Lines changed: 9 additions & 0 deletions

@@ -278,3 +278,12 @@ $ accelerate launch train.py
     url = {https://api.semanticscholar.org/CorpusID:269214195}
 }
 ```
+
+```bibtex
+@inproceedings{Huang2025TheGI,
+    title  = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
+    author = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
+    year   = {2025},
+    url    = {https://api.semanticscholar.org/CorpusID:275405495}
+}
+```

gigagan_pytorch/gigagan_pytorch.py

Lines changed: 14 additions & 2 deletions

@@ -122,6 +122,7 @@ def gradient_penalty(
     outputs,
     grad_output_weights = None,
     weight = 10,
+    center = 0.,
     scaler: GradScaler | None = None,
     eps = 1e-4
 ):
@@ -151,7 +152,7 @@ def gradient_penalty(
     gradients = maybe_scaled_gradients * inv_scale

     gradients = rearrange(gradients, 'b ... -> b (...)')
-    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+    return weight * ((gradients.norm(2, dim = 1) - center) ** 2).mean()

 # hinge gan losses
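
The key change above is the `center` argument: the old penalty pulled gradient norms toward 1 (the WGAN-GP convention), while a center of 0 penalizes the squared gradient norm itself, which is the R1/R2-style regularizer. A minimal standalone sketch of that idea, independent of the repository's helper (the function name and signature here are illustrative, not from the codebase):

```python
import torch
from torch.autograd import grad

def zero_centered_gradient_penalty(images, logits, weight = 10.):
    # illustrative sketch: penalize the squared L2 norm of dD/dx,
    # centered at 0 rather than 1 (R1 on real images, R2 on fake images)
    gradients, = grad(
        outputs = logits.sum(),
        inputs = images,
        create_graph = True
    )
    gradients = gradients.flatten(start_dim = 1)
    return weight * (gradients.norm(2, dim = 1) ** 2).mean()
```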

@@ -2308,9 +2309,11 @@ def train_discriminator_step(
         # detach output of generator, as training discriminator only

         images.detach_()
+        images.requires_grad_()

         for rgb in rgbs:
             rgb.detach_()
+            rgb.requires_grad_()

         # main divergence loss
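
The detach-then-`requires_grad_()` pattern keeps the discriminator update from backpropagating into the generator while still letting autograd differentiate D with respect to the fake images, which the new fake-side penalty needs. A small self-contained sketch of that pattern (the tiny `generator`/`discriminator` modules below are stand-ins, not the repository's models):

```python
import torch
from torch import nn

generator     = nn.Linear(8, 16)   # stand-in for G
discriminator = nn.Linear(16, 1)   # stand-in for D

noise = torch.randn(4, 8)

# generator output carries a graph back to G's parameters
fake_images = generator(noise)

# cut the graph back to G, then re-enable grads on the detached leaf
# so autograd can give dD/dx for the zero-centered penalty on fakes
fake_images = fake_images.detach().requires_grad_()

fake_logits = discriminator(fake_images)

# gradient of D's output w.r.t. its input, as used by the penalty term
dDdx, = torch.autograd.grad(fake_logits.sum(), fake_images, create_graph = True)
```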

@@ -2352,13 +2355,22 @@ def train_discriminator_step(
         gp_loss = 0.

         if apply_gradient_penalty:
-            gp_loss = gradient_penalty(
+            real_gp_loss = gradient_penalty(
                 real_images,
                 outputs = [real_logits, *real_multiscale_logits],
                 grad_output_weights = [1., *(self.multiscale_divergence_loss_weight,) * len(real_multiscale_logits)],
                 scaler = self.D_opt.scaler
             )

+            fake_gp_loss = gradient_penalty(
+                images,
+                outputs = [fake_logits, *fake_multiscale_logits],
+                grad_output_weights = [1., *(self.multiscale_divergence_loss_weight,) * len(fake_multiscale_logits)],
+                scaler = self.D_opt.scaler
+            )
+
+            gp_loss = real_gp_loss + fake_gp_loss
+
         if not torch.isnan(gp_loss):
             total_gp_loss += (gp_loss.item() / grad_accum_every)
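
Taken together, the discriminator loss now carries a penalty on both the real and the fake side, roughly the R1 + R2 scheme the cited paper argues stabilizes training. A condensed sketch of such a combined step under toy stand-ins (the helper is re-declared inline so the snippet stands alone; none of these names come from the repository):

```python
import torch
from torch import nn

discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1))  # stand-in D

real_images = torch.randn(4, 3, 16, 16, requires_grad = True)
fake_images = torch.randn(4, 3, 16, 16, requires_grad = True)  # pretend detached generator output

real_logits = discriminator(real_images)
fake_logits = discriminator(fake_images)

def zero_centered_gp(images, logits, weight = 10.):
    # squared gradient norm of D's output w.r.t. its input, centered at zero
    grads, = torch.autograd.grad(logits.sum(), images, create_graph = True)
    return weight * (grads.flatten(1).norm(2, dim = 1) ** 2).mean()

# hinge loss on the logits plus R1 (real) and R2 (fake) penalties
d_loss = (
    (1. - real_logits).relu().mean()
    + (1. + fake_logits).relu().mean()
    + zero_centered_gp(real_images, real_logits)
    + zero_centered_gp(fake_images, fake_logits)
)
d_loss.backward()
```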

gigagan_pytorch/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '0.2.20'
+__version__ = '0.3.0'
