
when delaying optimizer creation only prepare the model #39152

Merged: 1 commit, Jul 3, 2025

Conversation

@winglian (Contributor) commented Jul 1, 2025

What does this PR do?

Axolotl's CI caught a regression when we tried to upgrade to the latest transformers: https://github.com/axolotl-ai-cloud/axolotl/actions/runs/15962262932/job/45016550543

PR #36132 introduced a regression breaking FSDP with Llama:

stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
stderr: [rank0]:     inputs_embeds = self.embed_tokens(input_ids)
stderr: [rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
stderr: [rank0]:     return self._call_impl(*args, **kwargs)
stderr: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
stderr: [rank0]:     return forward_call(*args, **kwargs)
stderr: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 190, in forward
stderr: [rank0]:     return F.embedding(
stderr: [rank0]:            ^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/functional.py", line 2551, in embedding
stderr: [rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
stderr: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]: RuntimeError: Output 0 of ViewBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

and FSDP+DPO with Qwen:

stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
stderr: [rank0]:     inputs_embeds = self.embed_tokens(input_ids)
stderr: [rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
stderr: [rank0]:     return self._call_impl(*args, **kwargs)
stderr: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
stderr: [rank0]:     return forward_call(*args, **kwargs)
stderr: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 190, in forward
stderr: [rank0]:     return F.embedding(
stderr: [rank0]:            ^^^^^^^^^^^^
stderr: [rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/functional.py", line 2551, in embedding
stderr: [rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
stderr: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: [rank0]: RuntimeError: Output 0 of ViewBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif added the for patch label (should be included in the next patch) on Jul 1, 2025.
@Cyrilvallez (Member) commented:

cc @SunMarc

@@ -2357,7 +2357,7 @@ def _inner_training_loop(
                    model = self.accelerator.prepare(self.model)
                else:
                    if delay_optimizer_creation:
A collaborator commented on the diff:

cc @IlyasMoutawwakil as you wanted to remove this 👀

@IlyasMoutawwakil (Member) commented Jul 1, 2025:

Yeah, this fixes it too! I honestly don't understand delay_optimizer_creation: delay until when, and why? 😅 It might make sense to explain it somewhere in the trainer.

@IlyasMoutawwakil (Member) commented Jul 1, 2025:

You see, the reason I removed it is that we currently do create the optimizer here, and we need to prepare the FSDP model as well (otherwise FSDP fails), so the two branches of the if statement become the same.
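
For context, here is a minimal, hypothetical sketch of the two code paths being discussed, using standalone accelerate rather than the actual _inner_training_loop; the prepare_for_training helper and the AdamW choice are illustrative assumptions, not Trainer internals:

# Minimal hypothetical sketch of the two paths discussed above; not the real Trainer code.
import torch
from accelerate import Accelerator

def prepare_for_training(model: torch.nn.Module, delay_optimizer_creation: bool):
    accelerator = Accelerator()
    if delay_optimizer_creation:
        # Wrap the model first (FSDP wrapping happens inside prepare), then build the
        # optimizer from the wrapped parameters and prepare only the optimizer.
        model = accelerator.prepare(model)
        optimizer = torch.optim.AdamW(model.parameters())
        optimizer = accelerator.prepare(optimizer)
    else:
        # The optimizer exists up front, so model and optimizer are prepared together.
        optimizer = torch.optim.AdamW(model.parameters())
        model, optimizer = accelerator.prepare(model, optimizer)
    # Either way the model has to go through accelerator.prepare, which is why the
    # two branches end up looking so similar.
    return model, optimizer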

@winglian mentioned this pull request on Jul 1, 2025.
@ArthurZucker merged commit 8178c43 into huggingface:main on Jul 3, 2025. 18 checks passed.
@SunMarc (Member) commented Jul 3, 2025:

cc @kmehant, if you can explain the change you tried to make, that would be helpful!

@kmehant (Contributor) commented Jul 3, 2025:

Hi @SunMarc thanks for looping me in! Appreciate it.

Ideally, this block of code

if delay_optimizer_creation:
    if use_accelerator_prepare:
        # configure fsdp plugin for qlora if any
        self._fsdp_qlora_plugin_updates()
        if self.accelerator.mixed_precision != "fp8":
            self.model = self.accelerator.prepare(self.model)
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)

should be doing model preparation through accelerate even for the FSDP and TP cases (if you remember, in the older version we TP-ized the model inside accelerate's prepare, which is no longer the case, so we are good), so that the model is wrapped first and the block then creates the optimizer from the accelerate-prepared model's parameters. After that, execution reaches the current block
if use_accelerator_prepare:
    self.model.train()
    if hasattr(self.lr_scheduler, "step"):
        if self.use_apex:
            model = self.accelerator.prepare(self.model)
        else:
            if delay_optimizer_creation:
                model = self.accelerator.prepare(self.model)
            else:
                model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

which is what is being modified here; a second prepare of the model is not needed, rather only of the optimizer, since the previously created optimizer didn't undergo accelerate's prepare. That was the rationale behind this change. Ideally, instead of the change made in this PR, I think we should have simply modified

self.model = self.accelerator.prepare(self.model)

to

model = self.accelerator.prepare(self.model)

here:

self.model = self.accelerator.prepare(self.model)

OR

We could also go back to the older code, since TP-izing the model has been removed from the accelerate prepare step, which works as well:

model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

for the FSDP case. I can help with a contribution if needed.
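
As a reading aid, here is a minimal sketch of the ordering described above (prepare the model, build the optimizer from the wrapped parameters, then prepare only the optimizer); the standalone Accelerator and the toy nn.Linear model are illustrative assumptions, not the actual Trainer code:

# Hypothetical sketch of the ordering described above; not the actual Trainer source.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)  # stand-in for the real model

# 1. With optimizer creation delayed, wrap the model first (this is where FSDP wrapping happens).
model = accelerator.prepare(model)

# 2. Build the optimizer from the wrapped model's parameters.
optimizer = torch.optim.AdamW(model.parameters())

# 3. Later, only the optimizer still needs preparing; preparing the model a second time
#    is the step argued against above.
optimizer = accelerator.prepare(optimizer)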

Nonetheless, I +1 @IlyasMoutawwakil's suggestion to remove this altogether, since it has always been a confusing parameter to me :)

cc: @ArthurZucker @winglian

@kmehant (Contributor) commented Jul 3, 2025:

@SunMarc @IlyasMoutawwakil @ArthurZucker

PR #39177 is a more correct fix for this bug. The current PR breaks TP training (since prepare is not needed for TP, and enforcing prepare leads to a DDP setup, which fails). That PR fixes both the FSDP and TP cases.
