
Allow clipping the dynamic loss scale to a minimum value #3051


Merged: 1 commit, May 1, 2023
7 changes: 7 additions & 0 deletions flax/training/dynamic_scale.py
@@ -78,12 +78,17 @@ def loss_fn(p):
be increased (default: 2000).
fin_steps: indicates how many gradient steps in a row have been finite.
scale: the current scale by which the loss is multiplied.
minimum_scale: the minimum value that the scale can take (default: the
smallest positive number representable in floating point).
"""
growth_factor: float = struct.field(pytree_node=False, default=2.0)
backoff_factor: float = struct.field(pytree_node=False, default=0.5)
growth_interval: int = struct.field(pytree_node=False, default=2000)
fin_steps: Array = 0
scale: Array = 65536.0
minimum_scale: Optional[float] = struct.field(
pytree_node=False, default=jnp.finfo(jnp.float32).tiny
)

def value_and_grad(self, fun: Callable[..., Any],
argnums: Union[int, Sequence[int]] = 0,
@@ -137,6 +142,8 @@ def grad_fn_wrapper(*args):
jnp.minimum(self.scale * self.growth_factor, jnp.finfo(jnp.float32).max),
self.scale)
inf_scale = self.scale * self.backoff_factor
if self.minimum_scale is not None:
inf_scale = jnp.maximum(inf_scale, self.minimum_scale)
new_scale = jnp.where(finite, fin_scale, inf_scale)
new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1)

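For context, here is a minimal usage sketch of the new parameter, assuming a Flax version that includes this change; `loss_fn`, `params`, the loop length, and the learning rate are made-up placeholders, and the structure follows the example in the `DynamicScale` docstring. With the new default, a backoff after non-finite gradients cannot push the scale below `jnp.finfo(jnp.float32).tiny`; passing `minimum_scale=None` keeps the previous behaviour, where repeated backoffs can drive the scale toward zero.

```python
import jax
import jax.numpy as jnp
from flax.training.dynamic_scale import DynamicScale

# Placeholder half-precision loss that can produce non-finite scaled gradients.
def loss_fn(p):
  return jnp.sum(jnp.asarray(p, jnp.float16) ** 2)

params = jnp.ones((4,), jnp.float32)

# Keep the loss scale from collapsing below an explicit floor.
dyn_scale = DynamicScale(minimum_scale=jnp.finfo(jnp.float32).tiny)

grad_fn = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
for _ in range(100):
  dyn_scale, is_finite, loss, grads = grad_fn(dyn_scale, params)
  # Only apply the update when the scaled gradients were finite.
  params = jnp.where(is_finite, params - 0.01 * grads, params)
```

The clipping only changes the backoff path: growth on finite steps is unchanged, so training that never overflows behaves exactly as before.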