
Remove all_reduce altogether and shard the optimizer (new WR) #102


Merged
merged 3 commits into KellerJordan:master on Jul 14, 2025

Conversation

vagrawal
Contributor

@vagrawal vagrawal commented May 30, 2025

This change replaces all_reduce with reduce_scatter and shards the optimizer parameters correspondingly, saving 2-2.5 ms/batch in runtime over the current WR (100.5 vs ~103 ms on my machine). It also reduces memory for the Adam parameters, which were previously replicated on all nodes.
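A minimal sketch of the pattern, assuming all parameters are viewed as a single flat buffer whose length divides evenly across ranks (names like `flat_params` are illustrative, not the exact submission code):

```python
import torch
import torch.distributed as dist

def sharded_adam_step(flat_params, flat_grads, exp_avg, exp_avg_sq,
                      lr, beta1, beta2, eps, step):
    """reduce_scatter the grads, update only this rank's shard, all_gather the params.

    exp_avg / exp_avg_sq hold Adam state for this rank's shard only, so the
    optimizer state is never replicated across nodes.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    shard_size = flat_params.numel() // world_size

    # Each rank receives the (averaged) gradients for its own shard only.
    grad_shard = torch.empty(shard_size, device=flat_grads.device, dtype=flat_grads.dtype)
    dist.reduce_scatter_tensor(grad_shard, flat_grads, op=dist.ReduceOp.AVG)

    # Local Adam update on the shard this rank owns.
    param_shard = flat_params[rank * shard_size:(rank + 1) * shard_size]
    exp_avg.mul_(beta1).add_(grad_shard, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad_shard, grad_shard, value=1 - beta2)
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    param_shard.addcdiv_(exp_avg / bias1, (exp_avg_sq / bias2).sqrt().add_(eps), value=-lr)

    # Broadcast the updated shards so every rank holds the full parameters again.
    dist.all_gather_into_tensor(flat_params, param_shard)
```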

I also experimented with not waiting for the parameter update to finish before starting the next batch, which seems to work fine and saves another 1 ms. Just comment out the TODO section to test it.

@vagrawal vagrawal changed the title from "Remove all_reduce altogether and shard the optimizer" to "Remove all_reduce altogether and shard the optimizer (new WR)" on May 30, 2025
@KellerJordan
Owner

Thank you very much for the record submission. I'll aim to reproduce it within the next week. It will take priority over any other submissions that come in later.

@vagrawal
Contributor Author

vagrawal commented Jun 2, 2025

Also, in my further testing, the loss on the master branch is greater than 3.28 with statistical significance (it's around 3.281). My guess is that this was caused by the change in constants in the "21st record with latest torch" change, as mentioned in 1.

Here are the losses for multiple runs. Both averages are greater than 3.28 with p < 0.05.

losses_upstream = [3.2836, 3.2801, 3.2798, 3.2796, 3.2785, 3.2811, 3.2806, 3.2807, 3.2815, 3.2822, 3.2808, 3.2813, 3.2806, 3.2801, 3.2799, 3.2828, 3.2831, 3.2794, 3.2806, 3.2799, 3.2794]

losses_noallreduce = [3.28, 3.2817, 3.2805, 3.2796, 3.2772, 3.2807, 3.2841, 3.2829, 3.2818, 3.2819, 3.2817, 3.2822, 3.2806]
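For reference, the p < 0.05 claim can be reproduced with a one-sided one-sample t-test against 3.28 (a quick sanity check using scipy, not part of the submission):

```python
from scipy import stats  # requires scipy >= 1.6 for the `alternative` argument

losses_upstream = [3.2836, 3.2801, 3.2798, 3.2796, 3.2785, 3.2811, 3.2806,
                   3.2807, 3.2815, 3.2822, 3.2808, 3.2813, 3.2806, 3.2801,
                   3.2799, 3.2828, 3.2831, 3.2794, 3.2806, 3.2799, 3.2794]
losses_noallreduce = [3.28, 3.2817, 3.2805, 3.2796, 3.2772, 3.2807, 3.2841,
                      3.2829, 3.2818, 3.2819, 3.2817, 3.2822, 3.2806]

for name, losses in [("upstream", losses_upstream), ("noallreduce", losses_noallreduce)]:
    # H1: the mean validation loss exceeds 3.28.
    t, p = stats.ttest_1samp(losses, 3.28, alternative="greater")
    print(f"{name}: mean={sum(losses) / len(losses):.4f}  t={t:.2f}  p={p:.4g}")
```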

@YouJiacheng
Contributor

Good job!
Btw, it seems that you didn't compile the DistAdam? (So it's not fused.)
In addition, is autocast necessary?

@YouJiacheng
Contributor

YouJiacheng commented Jun 3, 2025

I'm not sure whether autocast will make p.grad bf16, so you might need to use the custom mixed-precision implementation from the 2.92 track to further reduce communication?

@YouJiacheng
Contributor

YouJiacheng commented Jun 3, 2025

Btw, did you have any idea better than re-introducing grouping parameters by size?
I hesitated to implement reduce_scatter because it feels ugly to group parameters by size.
In addition, we should be able to use all_to_all + reduce if we group parameters by size -- that way we can achieve FP32 accumulation precision with BF16 traffic. And all_to_all can be done by the copy engine without using SMs.
see: pytorch/pytorch#130583
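Roughly, the all_to_all + local reduce idea would look like the sketch below (assuming each rank has already packed its gradients into world_size equal-size BF16 chunks, which is exactly where the grouping-by-size question comes in):

```python
import torch
import torch.distributed as dist

def all_to_all_then_reduce(grad_chunks):
    """BF16 on the wire, FP32 accumulation on the receiving rank.

    grad_chunks: list of world_size equal-size BF16 tensors, where chunk i is
    this rank's contribution to the shard owned by rank i.
    Returns this rank's fully reduced shard in FP32.
    """
    recv_chunks = [torch.empty_like(c) for c in grad_chunks]
    # all_to_all only moves data, so NCCL can service it from the copy engine
    # without occupying SMs (see pytorch/pytorch#130583).
    dist.all_to_all(recv_chunks, grad_chunks)
    # The reduction itself happens locally in FP32, so accumulation precision
    # is not limited by the BF16 traffic.
    return torch.stack([c.float() for c in recv_chunks]).sum(dim=0)
```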

@vagrawal
Contributor Author

vagrawal commented Jun 4, 2025

Autocast is not making p.grad bf16. It just lets us remove the type_as from F.linear(x, self.weight.type_as(x)), which to my eyes is a bit uglier-looking than autocast.
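To make the comparison concrete, the two formulations look roughly like this (illustrative shapes; the point is only where the cast lives):

```python
import torch
import torch.nn.functional as F

w = torch.randn(1024, 1024, device="cuda")  # FP32 master weight
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Option 1: cast explicitly at every call site.
y1 = F.linear(x, w.type_as(x))

# Option 2: let autocast insert the cast; the FP32 parameter (and its .grad)
# stays FP32, since autocast only affects the op's compute dtype.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y2 = F.linear(x, w)
```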

I did try compiling the Adam step, but it didn't change the time at all, since the runtime is dominated by data movement across GPUs and the computation happens in parallel with the reduce_scatter.

@vagrawal
Contributor Author

vagrawal commented Jun 4, 2025

For Adam, we don't need any grouping, since the parameters can be split arbitrarily and the implementation is simpler. In fact, this approach seems significantly better than ZeRO-1, since we don't need to gather the optimizer params at all. I can't find any other place that uses this idea.

For Muon, I can't think of anything other than grouping params by size.
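Concretely, the Adam update is elementwise, so all parameters can be flattened into one buffer and cut at arbitrary offsets, even through the middle of a tensor; Muon's orthogonalization needs whole matrices, which is what forces the grouping. A toy illustration with made-up shapes:

```python
import torch

# Elementwise updates don't care about tensor boundaries, so no grouping by
# size is needed: just flatten everything and split into near-equal shards.
params = [torch.randn(1024, 768), torch.randn(768), torch.randn(3072, 768)]
flat = torch.cat([p.reshape(-1) for p in params])
world_size = 8
shards = flat.chunk(world_size)  # shard boundaries may fall inside a tensor
print([s.numel() for s in shards])
```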

@YouJiacheng
Contributor

Yep, I don't expect the compiled Adam to be faster because of the overlapped communication, but it should save some memory haha.

@YouJiacheng
Contributor

wdym by "this approach seems significantly better than ZeRO-1 as we don't need to gather the optimizer params at all"? IIUC you perform the all_gather in DistAdam.
Oh, I guess you mean you flatten all parameters into one flat tensor?

@vagrawal
Contributor Author

vagrawal commented Jun 5, 2025

I said optimizer params (exp_avg and exp_avg_sq), not the model params. ZeRO-1 only partitions the optimizer params.

@vagrawal
Contributor Author

vagrawal commented Jun 5, 2025

I have removed the autocasts and added torch.compile to the Adam step, as per your comment.
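For reference, compiling the shard update amounts to wrapping the elementwise math in torch.compile (a sketch; the function name and signature are illustrative):

```python
import torch

@torch.compile
def adam_shard_update(param, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, bias1, bias2):
    # Compilation fuses these elementwise ops and avoids materializing the
    # intermediates, saving memory even when the wall-clock time is hidden
    # behind the overlapped reduce_scatter.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    param.addcdiv_(exp_avg / bias1, (exp_avg_sq / bias2).sqrt().add_(eps), value=-lr)
```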

@KellerJordan
Owner

I bet you're right that the change in constants would induce the extra 0.001 loss.

There's no reason I didn't accept this until now other than that I was dreading/procrastinating figuring out what caused the increase in loss. But now there's a new record that came later which lowers the loss, so I can just accept both.

Accepting record

@KellerJordan KellerJordan merged commit 3e121a6 into KellerJordan:master Jul 14, 2025
@KellerJordan
Owner

Not waiting for the optimizer to finish before the next step, as you mentioned, could potentially produce a new record. But it would require gathering statistical significance, because it could change the forward pass rather than being a pure systems win.

@KellerJordan
Owner

Hm... a pure systems win that wouldn't mess with the forward pass would be waiting on both optimizers (the Adam and the Muon) at the same time, rather than doing each one sequentially.
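Sketched out, that would mean launching both optimizers' collectives up front and waiting on all the work handles at a single point (the launch/finish split below is a hypothetical API, not the repo's actual optimizer interface):

```python
def overlapped_step(adam_opt, muon_opt):
    """Kick off both optimizers' collectives, then synchronize once.

    Assumes each optimizer exposes launch(), which enqueues its collectives
    with async_op=True and returns the work handles, and finish(), which
    applies any work that depends on the gathered results.
    """
    adam_handles = adam_opt.launch()
    muon_handles = muon_opt.launch()
    # One synchronization point instead of Adam-then-Muon sequentially;
    # the forward pass is unchanged, so this stays a pure systems win.
    for handle in adam_handles + muon_handles:
        handle.wait()
    adam_opt.finish()
    muon_opt.finish()
```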

@KellerJordan
Owner

@vagrawal any accounts you want me to plug in the X.com announcement?
