1 parent 154da30 commit 333d731
gigagan_pytorch/gigagan_pytorch.py
@@ -119,7 +119,8 @@ def gradient_penalty(
     outputs,
     grad_output_weights = None,
     weight = 10,
-    scaler: Optional[GradScaler] = None
+    scaler: Optional[GradScaler] = None,
+    eps = 1e-4
 ):
     if not isinstance(outputs, (list, tuple)):
         outputs = [outputs]
@@ -143,7 +144,7 @@ def gradient_penalty(
 
     if exists(scaler):
         scale = scaler.get_scale()
-        inv_scale = 1. / max(scale, 1e-6)
+        inv_scale = 1. / max(scale, eps)
         gradients = maybe_scaled_gradients * inv_scale
 
     gradients = rearrange(gradients, 'b ... -> b (...)')
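
The change exposes the previously hardcoded floor (1e-6) as an `eps` parameter (default 1e-4) used when inverting the AMP GradScaler's scale inside the gradient penalty. Below is a minimal sketch of just that unscaling pattern, using a hypothetical `unscale_gradients` helper and a toy penalty computation that are not part of the library; the rest of `gradient_penalty` is not reproduced here.

import torch
from typing import Optional
from torch.cuda.amp import GradScaler

def unscale_gradients(maybe_scaled_gradients, scaler: Optional[GradScaler] = None, eps = 1e-4):
    # hypothetical helper mirroring only the changed lines: undo the AMP loss
    # scale before the penalty is taken, flooring the scale at eps so a zero
    # or vanishing scale cannot blow up the division
    if scaler is None:
        return maybe_scaled_gradients

    scale = scaler.get_scale()
    inv_scale = 1. / max(scale, eps)
    return maybe_scaled_gradients * inv_scale

# toy usage without AMP (pass a GradScaler and the scaled gradients when one is in use)
images = torch.randn(4, 3, 8, 8, requires_grad = True)
logits = images.mean(dim = (1, 2, 3))   # stand-in for discriminator outputs
grads, = torch.autograd.grad(outputs = logits.sum(), inputs = images, create_graph = True)
grads = unscale_gradients(grads)
penalty = 10 * ((grads.reshape(4, -1).norm(2, dim = 1) - 1) ** 2).mean()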
gigagan_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.2.14'
+__version__ = '0.2.15'