
[Feature] Support Tensor Parallelism and Weight Slicing for Lora #4274

Merged
merged 28 commits into sgl-project:main on Mar 19, 2025

Conversation

@aoshen524 (Contributor) commented Mar 10, 2025

Motivation

#3414 reports limited model support for LoRA compared to the models covered by test_generation_models.py. This PR introduces tensor parallelism and weight slicing for LoRA, along with additional improvements to testing and functionality.

Modifications

  • Implemented tensor parallelism support in LoRA, allowing computations to be distributed efficiently across multiple devices.
  • Introduced LoRA weight slicing and refactored the memory pool to support distributed inference, optimizing memory usage and performance (see the sketch after this list).
  • Removed unused code for CPU-GPU weight transmission.
  • Added error handling for the case where the number of available GPUs is smaller than the number required.
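
As a reference for the slicing scheme, here is a minimal sketch of the rule that tensor-parallel LoRA follows; the function name, shapes, and the `row_parallel` flag are illustrative assumptions, not this PR's exact code. The idea is to shard each LoRA weight along the same dimension that the base weight is sharded, so each rank only holds the slice it needs.

```python
# Minimal sketch (assumed names/shapes, not the PR's exact code) of the
# slicing rule for a LoRA pair A: [r, in_features], B: [out_features, r]
# applied to a base linear layer sharded across tp_size ranks.
import torch


def slice_lora_weights(lora_a: torch.Tensor, lora_b: torch.Tensor,
                       tp_rank: int, tp_size: int, row_parallel: bool):
    if row_parallel:
        # The base weight is split along the input dimension, so A is split
        # the same way; B stays whole and its partial outputs are summed by
        # the all-reduce the base layer already performs.
        shard = lora_a.shape[1] // tp_size
        lora_a = lora_a[:, tp_rank * shard:(tp_rank + 1) * shard]
    else:
        # Column-parallel: the base weight is split along the output
        # dimension, so B is split along its output dimension; A stays whole.
        shard = lora_b.shape[0] // tp_size
        lora_b = lora_b[tp_rank * shard:(tp_rank + 1) * shard, :]
    return lora_a, lora_b
```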

Checklist:

  • Remove tensor.contiguous() calls used on the GPU

Throughput for LLaMA 7B with the Triton backend on 4 x 3090-24GB

LoRA Config: wissdw/4r2a_llama_hf

|          | TP Size = 1 | TP Size = 2 | TP Size = 4 |
|----------|-------------|-------------|-------------|
| Use LoRA | 800 tok/s   | 725 tok/s   | 717 tok/s   |
| No LoRA  | 793 tok/s   | 783 tok/s   | 765 tok/s   |

Local CI test result for test/srt/models/lora/test_lora_tp.py

- Remove load_to_gpu and offload_from_gpu methods from LoRALayer and LoRAAdapter classes
- Simplify weight initialization and management for LoRA layers
- This change reduces code complexity and removes unnecessary functionality
- Implement tensor parallelism for LoRA weights in column-major format
- Add logic to slice LoRA weights for row-parallel modules
- Update memory pool initialization to handle row-parallel modules
- Modify weight loading to accommodate row-parallelism
- Implement slice_lora_a_weights and slice_lora_b_weights methods for various layers (see the sketch after this commit list)
- Add support for splitting LoRA weights across multiple GPUs
- Improve weight handling for VocabParallelEmbeddingWithLoRA
- Enhance ColumnParallelLinearWithLoRA and related classes for LoRA integration
- Update QKVParallelLinearWithLoRA for better weight management
- Modify RowParallelLinearWithLoRA for efficient weight slicing
…hardware without peer to peer communication.
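
The commit list above introduces per-layer slicing hooks (slice_lora_a_weights / slice_lora_b_weights). The sketch below shows, under assumed names and signatures rather than this PR's exact code, how a row-parallel wrapper could implement them; it also illustrates why no extra communication kernel is needed, which comes up in the discussion below.

```python
# Illustrative sketch (assumed class/method shapes, not the PR's exact code)
# of per-rank LoRA weight slicing for a row-parallel linear layer.
import torch


class RowParallelLinearWithLoRASketch:
    def __init__(self, tp_rank: int, tp_size: int):
        self.tp_rank = tp_rank
        self.tp_size = tp_size

    def slice_lora_a_weights(self, lora_a: torch.Tensor) -> torch.Tensor:
        # The base weight is sharded along the input dimension, so A
        # (shape [r, in_features]) is sliced along dim 1 to match.
        shard = lora_a.shape[1] // self.tp_size
        return lora_a[:, self.tp_rank * shard:(self.tp_rank + 1) * shard]

    def slice_lora_b_weights(self, lora_b: torch.Tensor) -> torch.Tensor:
        # B (shape [out_features, r]) stays whole: each rank computes a
        # partial B @ (A_shard @ x_shard), and those partial sums ride on
        # the all-reduce the base row-parallel layer already performs, so
        # LoRA adds no extra communication kernel.
        return lora_b
```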
@Fridge003 (Collaborator) commented Mar 11, 2025

Great work! Also, to ensure the performance of LoRA after TP, please paste some benchmark results from before and after enabling TP. You can refer to the benchmark in #3161 as an example.

@aoshen524 (Contributor, Author) replied:

> Great work! Also, to ensure the performance of LoRA after TP, please paste some benchmark results from before and after enabling TP. You can refer to the benchmark in #3161 as an example.

Good advice. But no extra communication kernel is launched or used when supporting LoRA TP. Do you still recommend it?

@Fridge003 (Collaborator) replied:

> Good advice. But no extra communication kernel is launched or used when supporting LoRA TP. Do you still recommend it?

Yes, just make sure LoRA with TP is not too slow.

@aoshen524 (Contributor, Author) replied:

> Yes, just make sure LoRA with TP is not too slow.

Sure.

- Add checks for available GPUs before setting the device (see the sketch after this commit list)
- Raise informative errors for invalid GPU IDs or lack of CUDA support
- Refactor CUDA device count retrieval into a separate function
- Update GPU memory retrieval to use the new device count function
- Rename test file from test_lora_backend_tensor_parallel.py to test_lora_tp.py
- Remove 'backend' parameter from test functions, focusing on Triton backend
- Introduce 'tp_size' parameter to test different tensor parallel configurations
- Update test suite to reflect the new file name
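
As context for the GPU-availability checks described above, here is a minimal sketch of that kind of guard; the function names and error messages are assumptions for illustration, not the exact code added in this PR.

```python
# Minimal sketch (illustrative names/messages, not the PR's exact code) of
# checking GPU availability before selecting a device.
import torch


def get_cuda_device_count() -> int:
    # Centralized so other callers (e.g. GPU memory queries) share one path.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; a CUDA-capable GPU is required.")
    return torch.cuda.device_count()


def set_cuda_device(gpu_id: int) -> None:
    num_gpus = get_cuda_device_count()
    if gpu_id < 0 or gpu_id >= num_gpus:
        raise ValueError(
            f"Invalid GPU id {gpu_id}: only {num_gpus} CUDA device(s) are visible."
        )
    torch.cuda.set_device(gpu_id)
```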
@aoshen524 requested a review from ByronHsu as a code owner on March 15, 2025 at 14:37

@Fridge003 (Collaborator) left a comment

LGTM

@Fridge003 merged commit 588865f into sgl-project:main on Mar 19, 2025
34 of 36 checks passed