Skip to content

Commit 333d731

Browse files
committed
tweak gradient penalty
1 parent 154da30 commit 333d731

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

gigagan_pytorch/gigagan_pytorch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def gradient_penalty(
119119
outputs,
120120
grad_output_weights = None,
121121
weight = 10,
122-
scaler: Optional[GradScaler] = None
122+
scaler: Optional[GradScaler] = None,
123+
eps = 1e-4
123124
):
124125
if not isinstance(outputs, (list, tuple)):
125126
outputs = [outputs]
@@ -143,7 +144,7 @@ def gradient_penalty(
143144

144145
if exists(scaler):
145146
scale = scaler.get_scale()
146-
inv_scale = 1. / max(scale, 1e-6)
147+
inv_scale = 1. / max(scale, eps)
147148
gradients = maybe_scaled_gradients * inv_scale
148149

149150
gradients = rearrange(gradients, 'b ... -> b (...)')

gigagan_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.2.14'
1+
__version__ = '0.2.15'

0 commit comments

Comments
 (0)