
Commit c9ff88a

enable cuda graph for lora
Co-authored-by: Beichen Ma <[email protected]>
1 parent d2b8d0b commit c9ff88a

8 files changed (+299, -43 lines)


benchmark/lora/launch_server.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def launch_server(args):
     for i in range(NUM_LORAS):
         lora_name = f"lora{i}"
         cmd += f"{lora_name}={lora_path} "
-    cmd += f"--disable-radix --disable-cuda-graph "
+    cmd += f"--disable-radix "
     cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
     cmd += f"--max-running-requests {args.max_running_requests} "
     cmd += f"--lora-backend {args.lora_backend} "

docs/backend/server_arguments.md

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ Please consult the documentation below to learn more about the parameters you ma
 
 ## LoRA
 
-* `lora_paths`: You may provide a list of adapters to your model as a list. Each batch element will get model response with the corresponding lora adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option so you need to disable them manually. We are still working on through these [issues](https://github.com/sgl-project/sglang/issues/2929).
+* `lora_paths`: You may provide a list of adapters to your model as a list. Each batch element will get model response with the corresponding lora adapter applied. Currently `radix_attention` is not supported with this option so you need to disable it manually. We are still working on through these [issues](https://github.com/sgl-project/sglang/issues/2929).
 * `max_loras_per_batch`: Maximum number of LoRAs in a running batch including base model.
 * `lora_backend`: The backend of running GEMM kernels for Lora modules, can be one of `triton` or `flashinfer`. Defaults to be `triton`.
 
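To make the `lora_paths` option concrete, here is an illustrative request that selects one of the registered adapters for a single batch element. The endpoint and payload keys are assumptions based on the description above, not something this diff confirms.

```python
import requests

# Illustrative only: per-request LoRA selection against a locally running
# sglang server started with `--lora-paths lora0=<path> ...`.
# Treat the exact endpoint and field names here as assumptions.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Give me a one-line summary of LoRA.",
        "sampling_params": {"max_new_tokens": 64},
        "lora_path": "lora0",  # adapter name registered at launch time
    },
)
print(resp.json())
```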

python/sglang/srt/lora/layers.py

Lines changed: 40 additions & 7 deletions
@@ -127,6 +127,7 @@ def __init__(
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        self.B_buffer_gate_up = None
 
     def set_lora_info(
         self,
@@ -138,9 +139,20 @@ def set_lora_info(
         if self.lora_backend.fuse_stacked_lora_b:
             # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            self.B_buffer_gate_up = torch.cat(
-                (B_buffer[0], B_buffer[1]), dim=-2
-            ).contiguous()
+            if self.B_buffer_gate_up is None:
+                self.B_buffer_gate_up = torch.empty(
+                    (
+                        B_buffer[0].shape[0],
+                        2 * B_buffer[0].shape[1],
+                        B_buffer[0].shape[2],
+                    ),
+                    dtype=B_buffer[0].dtype,
+                    device=B_buffer[0].device,
+                ).contiguous()
+            # TODO: avoid using contiguous() in GPU.
+            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
+            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
+            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
         else:
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
@@ -171,12 +183,15 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
+    def __init__(
         self,
         base_layer: QKVParallelLinear,
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        self.output_offset = None
+        self.B_buffer_qkv = None
+        self.max_qkv_out_dim = 0
 
     def set_lora_info(
         self,
@@ -194,9 +209,27 @@ def set_lora_info(
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
+            # self.B_buffer_qkv = torch.cat(
+            #     (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
+            # ).contiguous()
+
+            if self.B_buffer_qkv is None:
+                self.B_buffer_qkv = torch.empty(
+                    (
+                        B_buffer_q[0].shape[0],
+                        output_dim_q + 2 * output_dim_kv,
+                        B_buffer_q[0].shape[2],
+                    ),
+                    dtype=B_buffer_q[0].dtype,
+                    device=B_buffer_q[0].device,
+                ).contiguous()
+            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
+            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
+                B_buffer_kv[0]
+            )
+            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
+                B_buffer_kv[1]
+            )
 
             # Offsets of q/k/v in output dimension
             self.output_offset = torch.tensor(
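Both the gate/up and q/k/v changes above rely on the same trick: allocate the fused LoRA-B buffer once with `torch.empty`, then write later adapter weights into it with `copy_`, so the buffer address recorded during CUDA graph capture stays valid instead of changing on every `torch.cat`. A minimal standalone sketch of that pattern follows; the `FusedLoRABuffer` class is hypothetical, not sglang code.

```python
import torch

# Allocate-once / copy-in-place pattern: torch.cat returns a new tensor (new
# address) each call, which breaks CUDA graph replay, so the fused buffer is
# created lazily and reused with in-place copies.
class FusedLoRABuffer:
    def __init__(self):
        self.fused = None  # allocated on first update

    def update(self, b0: torch.Tensor, b1: torch.Tensor) -> torch.Tensor:
        # b0, b1: (num_lora, output_dim, r)
        num_lora, output_dim, r = b0.shape
        if self.fused is None:
            self.fused = torch.empty(
                (num_lora, 2 * output_dim, r), dtype=b0.dtype, device=b0.device
            )
        # In-place writes keep self.fused at the same device address.
        self.fused[:, :output_dim, :].copy_(b0)
        self.fused[:, output_dim:, :].copy_(b1)
        return self.fused

buf = FusedLoRABuffer()
b_gate = torch.randn(4, 8, 16)
b_up = torch.randn(4, 8, 16)
fused = buf.update(b_gate, b_up)
assert fused.data_ptr() == buf.update(b_gate, b_up).data_ptr()  # address is stable
```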

python/sglang/srt/lora/lora_manager.py

Lines changed: 82 additions & 31 deletions
@@ -53,11 +53,13 @@ def __init__(
         lora_backend: str = "triton",
         tp_size: int = 1,
         tp_rank: int = 0,
+        max_bs_in_cuda_graph: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
         self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
+        self.max_bs_in_cuda_graph: int = max_bs_in_cuda_graph
         self.load_config: LoadConfig = load_config
         self.dtype: torch.dtype = dtype
         self.device: torch.device = next(self.base_model.parameters()).device
@@ -72,6 +74,23 @@ def __init__(
         self.init_loras()
         self.init_lora_memory_pool()
 
+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -140,39 +159,71 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         if cur_uids == set([None]):
             return
 
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-        lora_ranks = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-        )
-        scalings = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-        )
-        for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-            lora = self.loras[lora_path]
-            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-            scalings[weight_indices[i]] = lora.scaling
-
-        batch_info = LoRABatchInfo(
-            bs=bs,
-            seg_lens=seg_lens,
-            seg_indptr=seg_indptr,
-            max_len=max_len,
-            weight_indices=weight_indices,
-            lora_ranks=lora_ranks,
-            scalings=scalings,
-        )
+        if bs <= self.max_bs_in_cuda_graph:
+            # Do in-place update for cuda graph
+            self.cuda_graph_batch_info.bs = bs
+            if forward_batch.forward_mode.is_extend():
+                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
+                    forward_batch.extend_seq_lens
+                )
+            else:
+                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            self.cuda_graph_batch_info.seg_indptr[0] = 0
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                lora = self.loras[lora_path]
+                self.cuda_graph_batch_info.lora_ranks[
+                    self.cuda_graph_batch_info.weight_indices[i]
+                ] = lora.config.hf_config["r"]
+                self.cuda_graph_batch_info.scalings[
+                    self.cuda_graph_batch_info.weight_indices[i]
+                ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                lora = self.loras[lora_path]
+                lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
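`prepare_lora_batch` now has two paths: when the batch fits within `max_bs_in_cuda_graph` it rewrites the pre-allocated `cuda_graph_batch_info` tensors in place, and otherwise it falls back to building a fresh `LoRABatchInfo` as before. A condensed, standalone sketch of the in-place path is shown below, with hypothetical helper names and CPU tensors for illustration only.

```python
import torch
from dataclasses import dataclass

# The tensors are allocated once at max size; only their leading [:bs] slices
# are rewritten, so a captured CUDA graph keeps reading valid addresses.
@dataclass
class BatchInfo:
    bs: int
    seg_lens: torch.Tensor    # (max_bs,)
    seg_indptr: torch.Tensor  # (max_bs + 1,)
    max_len: int

def init_batch_info(max_bs: int) -> BatchInfo:
    return BatchInfo(
        bs=max_bs,
        seg_lens=torch.zeros(max_bs, dtype=torch.int32),
        seg_indptr=torch.zeros(max_bs + 1, dtype=torch.int32),
        max_len=0,
    )

def update_in_place(info: BatchInfo, seq_lens: torch.Tensor) -> None:
    bs = seq_lens.numel()
    info.bs = bs
    info.seg_lens[:bs].copy_(seq_lens)  # no new allocation
    info.seg_indptr[0] = 0
    torch.cumsum(info.seg_lens[:bs], dim=0, out=info.seg_indptr[1 : bs + 1])
    info.max_len = int(info.seg_lens[:bs].max())

info = init_batch_info(max_bs=8)
update_in_place(info, torch.tensor([3, 1, 5], dtype=torch.int32))
print(info.seg_indptr[:4])  # tensor([0, 3, 4, 9], dtype=torch.int32)
```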

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 14 additions & 0 deletions
@@ -220,6 +220,9 @@ def __init__(self, model_runner: ModelRunner):
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -403,6 +406,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +434,12 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )
 
+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
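The runner builds the capture-time `ForwardBatch` with a fixed, non-None `lora_paths` list and calls `prepare_lora_batch` before capture, so the LoRA kernels recorded into the graph already read from the pre-allocated batch-info tensors. For background, here is a generic PyTorch capture/replay sketch (no sglang specifics) showing the update-in-place-then-replay discipline this relies on.

```python
import torch

# A captured CUDA graph replays recorded kernels against the same memory
# addresses, so new inputs are copied into static buffers before replay
# instead of being passed as new tensors.
def demo_cuda_graph():
    if not torch.cuda.is_available():
        return
    static_x = torch.zeros(8, 16, device="cuda")
    weight = torch.randn(16, 16, device="cuda")
    static_out = torch.zeros(8, 16, device="cuda")

    # Warm up on a side stream before capture (as recommended by PyTorch docs).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_x @ weight)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass; the graph records these buffer addresses.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_x @ weight)

    # New inputs must be copied into the same static buffers before replay.
    static_x.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_out.abs().sum().item())

demo_cuda_graph()
```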

python/sglang/srt/server_args.py

Lines changed: 0 additions & 1 deletion
@@ -1230,7 +1230,6 @@ def check_server_args(self):
         assert (
             self.max_loras_per_batch > 0
             # FIXME
-            and (self.lora_paths is None or self.disable_cuda_graph)
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
