
Commit 588865f

[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)

Authored by aoshen524, ShenAo1111, and Fridge003
Co-authored-by: ShenAo1111 <[email protected]>
Co-authored-by: Baizhou Zhang <[email protected]>

1 parent 3196999 · commit 588865f

File tree: 13 files changed, +528 −103 lines


.github/workflows/pr-test.yml

Lines changed: 6 additions & 0 deletions
@@ -127,6 +127,12 @@ jobs:
           cd test/srt
           python3 test_mla_tp.py
 
+      - name: Test lora tensor parallelism (TP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt/models/lora
+          python3 test_lora_tp.py
+
   performance-test-1-gpu-part-1:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
benchmark/lora/launch_server.py

Lines changed: 16 additions & 1 deletion
@@ -22,7 +22,10 @@ def launch_server(args):
     cmd += f"--disable-radix --disable-cuda-graph "
     cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
     cmd += f"--max-running-requests {args.max_running_requests} "
-    cmd += f"--lora-backend {args.lora_backend}"
+    cmd += f"--lora-backend {args.lora_backend} "
+    cmd += f"--tp-size {args.tp_size} "
+    if args.disable_custom_all_reduce:
+        cmd += "--disable-custom-all-reduce"
     print(cmd)
     os.system(cmd)
 

@@ -48,6 +51,18 @@ def launch_server(args):
         type=str,
         default="triton",
     )
+    parser.add_argument(
+        "--tp-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for distributed inference",
+    )
+    # disable_custom_all_reduce
+    parser.add_argument(
+        "--disable-custom-all-reduce",
+        action="store_true",
+        help="Disable custom all reduce when device does not support p2p communication",
+    )
     args = parser.parse_args()
 
     launch_server(args)
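A quick sanity check of the two new flags (a sketch, not part of the commit; every other launcher option is assumed to fall back to the defaults defined in this script):

# Hypothetical invocation of the updated benchmark launcher with TP=2.
# --disable-custom-all-reduce is only needed on hosts without P2P GPU access.
import subprocess

subprocess.run(
    [
        "python3",
        "benchmark/lora/launch_server.py",
        "--tp-size", "2",
        "--disable-custom-all-reduce",
    ],
    check=True,
)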

python/sglang/srt/layers/linear.py

Lines changed: 2 additions & 0 deletions
@@ -782,6 +782,8 @@ def __init__(
         else:
             self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
             self.num_kv_head_replicas = 1
+        self.q_proj_shard_size = self.num_heads * self.head_size
+        self.kv_proj_shard_size = self.num_kv_heads * self.head_size
         input_size = self.hidden_size
         output_size = (
             (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
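The two new attributes are simply the per-rank output widths of the Q and K/V projections. A short worked example with an assumed Llama-style config (32 query heads, 8 KV heads, head size 128, TP=2) shows the values that QKVParallelLinearWithLoRA later relies on for slicing:

# Illustrative numbers only; not taken from the commit.
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 2

num_heads = total_num_heads // tp_size         # 16 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size   # 4 KV heads per rank

q_proj_shard_size = num_heads * head_size      # 16 * 128 = 2048
kv_proj_shard_size = num_kv_heads * head_size  # 4 * 128 = 512

print(q_proj_shard_size, kv_proj_shard_size)   # 2048 512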

python/sglang/srt/lora/layers.py

Lines changed: 68 additions & 0 deletions
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 import torch
 from torch import nn
 

@@ -38,8 +40,22 @@ def forward(self, x: torch.Tensor):
     def set_lora_info(self, *args):
         pass
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        pass
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        pass
+
 
 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+    """
+    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
+
+    Note: The current version does not yet implement the LoRA functionality.
+    This class behaves exactly the same as the base VocabParallelEmbedding.
+    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
+    """
+
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,

@@ -101,6 +117,16 @@ def forward(self, input_: torch.Tensor):
         output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        B = B[start_idx:end_idx, :]
+        return B
+
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(

@@ -120,6 +146,7 @@ def set_lora_info(
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
+            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
             self.B_buffer_gate_up = torch.cat(
                 (B_buffer[0], B_buffer[1]), dim=-2

@@ -142,6 +169,16 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
             else base_output + lora_output * self.scaling
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        # Since the outputs for both gate and up are identical, we use a random one.
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        return B[:, start_idx:end_idx, :]
+
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def init__(

@@ -210,6 +247,27 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
             else base_output + lora_output * self.scaling
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(
+        self, B: List[torch.Tensor], tp_rank: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        B_q, B_kv = B
+        base_layer = self.base_layer
+        q_proj_shard_size = base_layer.q_proj_shard_size
+        kv_proj_shard_size = base_layer.kv_proj_shard_size
+        num_kv_head_replicas = base_layer.num_kv_head_replicas
+
+        q_start_idx = q_proj_shard_size * tp_rank
+        q_end_idx = q_start_idx + q_proj_shard_size
+
+        kv_shard_id = tp_rank // num_kv_head_replicas
+        kv_start_idx = kv_proj_shard_size * kv_shard_id
+        kv_end_idx = kv_start_idx + kv_proj_shard_size
+
+        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(

@@ -274,6 +332,16 @@ def forward(self, input_: torch.Tensor):
         output_bias = self.base_layer.bias
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.input_size_per_partition
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        A = A[:, start_idx:end_idx].contiguous()
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        return B
+
 
 def get_lora_layer(
     layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
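To see why the slicing above reproduces the unsharded result, here is a self-contained sketch (plain torch, assumed shapes, no SGLang classes): for column-parallel layers LoRA A stays replicated and B is split along its output dimension, while row-parallel layers split A along the input dimension and keep B whole; concatenating (column-parallel) or summing (row-parallel) the per-rank partials recovers the full LoRA delta.

import torch

tp_size, r, in_dim, out_dim = 2, 16, 512, 512
x = torch.randn(1, in_dim)
A = torch.randn(r, in_dim)   # LoRA A
B = torch.randn(out_dim, r)  # LoRA B

# Column-parallel case (gate/up, q/k/v): A replicated, B sliced over output dim.
shard = out_dim // tp_size
partials = [
    x @ A.T @ B[rank * shard : (rank + 1) * shard, :].T for rank in range(tp_size)
]
assert torch.allclose(torch.cat(partials, dim=-1), x @ A.T @ B.T, atol=1e-3)

# Row-parallel case (o_proj, down_proj): the input is already sharded, so A is
# sliced over its input dim and B replicated; per-rank outputs are summed
# (the all-reduce that the base row-parallel layer performs anyway).
in_shard = in_dim // tp_size
partials = [
    x[:, rank * in_shard : (rank + 1) * in_shard]
    @ A[:, rank * in_shard : (rank + 1) * in_shard].T
    @ B.T
    for rank in range(tp_size)
]
assert torch.allclose(sum(partials), x @ A.T @ B.T, atol=1e-3)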

python/sglang/srt/lora/lora.py

Lines changed: 2 additions & 22 deletions
@@ -39,16 +39,9 @@ def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
         super().__init__()
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
-        self.weights: Dict[str, torch.Tensor] = {}
-        self.weight_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
 
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = None
+        # lora weights in cpu. The weights are loaded from checkpoint.
+        self.weights: Dict[str, torch.Tensor] = {}
 
 
 class LoRAAdapter(nn.Module):

@@ -77,19 +70,6 @@ def __init__(
         )
 
         self.weights: Dict[str, torch.Tensor] = {}
-        self.weights_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
-        for layer in self.layers:
-            layer.load_to_gpu()
-
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = None
-        for layer in self.layers:
-            layer.offload_from_gpu()
 
     # initialize the LoRA weights to cpu
     def initialize_weights(self):

python/sglang/srt/lora/lora_manager.py

Lines changed: 47 additions & 23 deletions
@@ -23,7 +23,7 @@
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
-from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.mem_pool import LoRAMemoryPool

@@ -51,13 +51,18 @@ def __init__(
         load_config: LoadConfig,
         dtype: torch.dtype,
         lora_backend: str = "triton",
+        tp_size: int = 1,
+        tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
         self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
         self.dtype: torch.dtype = dtype
+        self.device: torch.device = next(self.base_model.parameters()).device
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank
 
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")

@@ -110,7 +115,13 @@ def init_loras(self):
     def init_lora_memory_pool(self):
         # Initialize memory pool
         self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.max_lora_dim,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+            self.lora_modules,
         )
 
         # Initialize target lora modules in memory pool

@@ -131,12 +142,12 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device="cuda")
+            else torch.ones(bs, device=self.device)
         )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
+        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
 

@@ -150,22 +161,32 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
-        for module_name, module in self.lora_modules:
-            layer_id = get_layer_id(module_name)
-            if "qkv_proj" not in module_name:
-                weight_name = get_weight_name(
-                    module_name, self.lora_weight_names, LoRAType.LORA_A
-                )
-                module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
-                )
-            else:
-                module.set_lora_info(
-                    self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
-                    self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
-                )
+        for layer_id, modules in self.lora_modules.items():
+            for module_name, module in modules:
+                if "qkv_proj" in module_name:
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            "qkv_proj", layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            "q_proj", layer_id, LoRAType.LORA_B
+                        ),
+                        self.memory_pool.get_tensor(
+                            "kv_proj", layer_id, LoRAType.LORA_B
+                        ),
+                    )
+                else:
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_B
+                        ),
+                    )
 
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(

@@ -182,10 +203,13 @@ def convert_to_lora_layers(self):
         )
 
         # Monkey patch to use the LoRA version layers
-        self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
+            i: [] for i in range(self.base_hf_config.num_hidden_layers)
+        }
         for module_name, module in self.base_model.named_modules():
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
-                self.lora_modules.append(
+                layer_id = get_layer_id(module_name)
+                self.lora_modules[layer_id].append(
                     (module_name, self.set_lora_module(module_name, module))
                 )
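The manager now groups its LoRA-wrapped modules by layer id instead of keeping a flat list, which is what lets prepare_lora_batch fetch per-layer buffers directly. A rough standalone sketch of that grouping (the module names and the regex here are assumptions for illustration, not SGLang internals):

import re
from collections import defaultdict

def get_layer_id(module_name: str) -> int:
    # Assumes HF-style names such as "model.layers.3.self_attn.qkv_proj".
    return int(re.search(r"layers\.(\d+)\.", module_name).group(1))

target_modules = [
    "model.layers.0.self_attn.qkv_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.1.self_attn.qkv_proj",
    "model.layers.1.self_attn.o_proj",
]

# Dict[layer_id, List[module_name]]: mirrors the new self.lora_modules layout,
# with the wrapped BaseLayerWithLoRA objects replaced by plain names.
lora_modules = defaultdict(list)
for name in target_modules:
    lora_modules[get_layer_id(name)].append(name)

for layer_id, modules in sorted(lora_modules.items()):
    print(layer_id, modules)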
