System Info
Information
Reproduction
In the current implementation of `sample_beta` in `modeling_pi0.py` (link): when `gamma1` and `gamma2` are both sampled as 0, the returned value is `NaN`, which makes the loss `NaN` as well. This happens very rarely in FP32, but when I attempted to train the pi0 model fully in `bf16`, it became much easier for `gamma1` and `gamma2` to be 0. Consequently, my loss becomes `NaN` after roughly the first 24k samples.

This is a bug that may arise when a user tries either:
a) changing Pi0's precision fully to `bfloat16` and training for ~24k samples, or
b) training pi0 in the default mixed precision for a very large number of steps.

A similar issue was brought up in the C++ Stan Math Library. They modified their implementation of the beta distribution to work in log space, which is numerically stable; the official openpi code does something similar (openpi's code, `jax.random.beta` implementation).

Expected behavior
`sample_beta` should never return `NaN`.
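To make the log-space suggestion concrete, here is a minimal pure-Python sketch. It is not the lerobot or openpi code. It assumes `sample_beta` forms `gamma1 = u1 ** (1 / alpha)` and `gamma2 = u2 ** (1 / beta)` from uniform draws and returns `gamma1 / (gamma1 + gamma2)`, which is my reading of the function rather than something stated in this issue. The names `stable_ratio` and `sample_ratio_logspace` are hypothetical.

```python
import math
import random


def stable_ratio(log_g1: float, log_g2: float) -> float:
    """Return g1 / (g1 + g2) from log(g1) and log(g2) without forming g1 or g2."""
    # Subtracting the larger log makes one of the two terms exactly exp(0) = 1,
    # so the denominator is always >= 1 and 0 / 0 cannot occur.
    log_max = max(log_g1, log_g2)
    a = math.exp(log_g1 - log_max)
    b = math.exp(log_g2 - log_max)
    return a / (a + b)


def sample_ratio_logspace(alpha: float, beta: float, rng: random.Random) -> float:
    """Hypothetical log-space variant of the assumed u ** (1 / alpha) ratio scheme."""
    # 1 - rng.random() lies in (0, 1], so log() never receives 0.
    log_g1 = math.log(1.0 - rng.random()) / alpha  # log(u1 ** (1 / alpha))
    log_g2 = math.log(1.0 - rng.random()) / beta   # log(u2 ** (1 / beta))
    return stable_ratio(log_g1, log_g2)


if __name__ == "__main__":
    # Both gammas would underflow to 0 if exponentiated directly,
    # but the log-space ratio is still well defined.
    print(stable_ratio(-1e6, -1e6))
    print(sample_ratio_logspace(1.5, 1.0, random.Random(0)))
```

The same max-subtraction trick should carry over to tensors by replacing `max` and `math.exp` with their elementwise equivalents. I have not tested that against `modeling_pi0.py`.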