
Fix device assignment in get_device_name for distributed training #3303


Merged · 2 commits merged into UKPLab:master on Apr 3, 2025

Conversation

@uminaty (Contributor) commented on Apr 2, 2025

This PR updates the sentence_transformers.util.get_device_name utility to better support multi-GPU setups using tools like accelerate and torchrun.

Context

This issue in PyLate shows a small problem when a device is not explicitly provided for a SentenceTransformer model and a training run is launched with accelerate or torchrun: multiple unexpected processes with low VRAM usage remain on the same GPU.

It seems to happen because, in the SentenceTransformer constructor, the get_device_name function sets cuda as the default device for every rank, which causes multiple processes to remain on cuda:0 even after accelerate distributes the model across all GPUs.

Even though the script runs fine and performance doesn't seem to be impacted, this still ties up VRAM on GPU 0 that could be better utilized.
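
For reference, the pre-fix resolution boils down to something like the following (a simplified sketch, not the exact library code; the npu/hpu checks are omitted): every process that sees CUDA gets the bare "cuda" string, which PyTorch resolves to cuda:0.

```python
# Simplified sketch of the pre-fix behavior (not the exact library code).
import torch


def get_device_name() -> str:
    if torch.cuda.is_available():
        # Same answer for every rank, so every process ends up on cuda:0.
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    # npu / hpu checks omitted in this sketch
    return "cpu"
```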

This can be reproduced even when using a pure sentence_transformers script like this one: training_gooaq_lora.py. The same behavior happens when launching with:

torchrun --nproc_per_node=8 training_gooaq_lora.py

or

accelerate launch --num_processes 8 training_gooaq_lora.py

(even when the code is properly wrapped in a main() block.)

Proposed Fix

We can update the get_device_name() function to:

  • Use torch.distributed.get_rank() when distributed training is initialized.
  • Otherwise, check for LOCAL_RANK from the environment and resolve to cuda:{LOCAL_RANK}.
  • Fall back to "cpu", "mps", "npu", or "hpu" as before.

This ensures that, by default, the correct GPU device is used per process, even when a model is created with device=None.

Note: This shouldn't change the behavior when launching as usual with python script.py, since if no local rank is found, it will default to cuda:0.
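
For illustration, here is a minimal sketch of the proposed resolution order, following the bullet points above (the mps/npu/hpu fallbacks are abbreviated and the exact merged code may differ):

```python
import os

import torch


def get_device_name() -> str:
    """Resolve the default device, picking the per-process GPU in distributed runs."""
    if torch.cuda.is_available():
        # 1. If torch.distributed is already initialized, use the process rank.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return f"cuda:{torch.distributed.get_rank()}"
        # 2. Otherwise, honor the LOCAL_RANK set by torchrun / accelerate launch.
        local_rank = os.environ.get("LOCAL_RANK")
        if local_rank is not None:
            return f"cuda:{local_rank}"
        # 3. Plain `python script.py`: no rank information, default to the first GPU.
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    # npu / hpu checks omitted in this sketch
    return "cpu"
```

With torchrun --nproc_per_node=8, each worker would then resolve its own cuda:{LOCAL_RANK} instead of all of them piling onto cuda:0.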

cc @NohTow

Fix get_device_name for distributed setup.

Fix get_device_name for distributed setup.
@tomaarsen (Collaborator)

(Feel free to ignore the CI failures, those are unrelated)

@uminaty (Contributor, Author) commented on Apr 3, 2025

Thanks for the review! Let me know if anything else is needed

@tomaarsen (Collaborator)

I think we're all set! Thank you for tackling this, I think this is a really solid default that I definitely should have already implemented ages ago.

  • Tom Aarsen

@tomaarsen merged commit 07f53c5 into UKPLab:master on Apr 3, 2025 (1 of 9 checks passed)