
Risk of NaN loss in Pi0's sample_beta implementation, especially when using bfloat16 #1096

YuhengZhi opened this issue May 11, 2025 · 0 comments
System Info

Not relevant to this issue.

Information

  • One of the scripts in the examples/ folder of LeRobot
  • My own task or dataset (give details below)

Reproduction

In the current implementation of sample_beta in modeling_pi0.py (link):

def sample_beta(alpha, beta, bsize, device):
    # uniform_(0, 1) can return exactly 0, and .pow() can underflow to 0
    gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
    gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
    # when both gamma1 and gamma2 are 0, this is 0 / 0 = NaN
    return gamma1 / (gamma1 + gamma2)

When gamma1 and gamma2 are both sampled as 0, the returned value is NaN, which makes the loss NaN as well. This happens very rarely in FP32, but when I attempted to train the Pi0 model fully in bf16, it became much more likely for gamma1 and gamma2 to both come out as 0.

Consequently, my loss became NaN after roughly the first 24k samples.
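
For reference, the failure mode itself is easy to reproduce in isolation (the zero tensors below are hand-constructed, not sampled):

import torch

# When both draws come out as exactly 0, the ratio is 0 / 0 = NaN.
gamma1 = torch.tensor([0.0], dtype=torch.bfloat16)
gamma2 = torch.tensor([0.0], dtype=torch.bfloat16)
print(gamma1 / (gamma1 + gamma2))  # tensor([nan], dtype=torch.bfloat16)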

This is a bug that a user may run into by either:
a) switching Pi0's precision fully to bfloat16 and training for ~24k samples, or
b) training Pi0 in the default mixed precision for a very large number of steps.

A similar issue was brought up in the C++ Stan Math Library. They moved the beta-distribution sampling into log space, which is numerically stable, as is similarly done by the official openpi code (openpi's code, jax.random.beta implementation).
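
A minimal log-space sketch of the same transformation (the name sample_beta_stable and the use of exponential_() to draw -log(U) are my own illustrative choices, not taken from either codebase):

import torch

def sample_beta_stable(alpha, beta, bsize, device):
    # -log(U) for U ~ Uniform(0, 1) is an Exponential(1) variate, so
    # exponential_() yields log-space uniform draws without ever
    # evaluating log(0).
    log_gamma1 = -torch.empty((bsize,), device=device).exponential_() / alpha
    log_gamma2 = -torch.empty((bsize,), device=device).exponential_() / beta
    # gamma1 / (gamma1 + gamma2), computed as
    # exp(log_gamma1 - log(gamma1 + gamma2)); logaddexp keeps the
    # denominator well-defined even when both gammas would underflow
    # to 0 in linear space.
    return torch.exp(log_gamma1 - torch.logaddexp(log_gamma1, log_gamma2))

Keeping these draws in FP32 and casting the result afterwards should keep the sampler stable even if the rest of the model runs in bf16.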

Expected behavior

sample_beta should never return NaN.
