
Commit 3196999

Reduce computation and communication in DP attention (#4521)
1 parent 9e0186f commit 3196999

5 files changed: +65, -75 lines

python/sglang/srt/distributed/parallel_state.py
Lines changed: 6 additions & 2 deletions

@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
@@ -241,6 +244,7 @@ def __init__(
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster

         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
@@ -269,15 +273,15 @@ def __init__(
             HpuCommunicator,
         )

-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

         from sglang.srt.distributed.device_communicators.xpu_communicator import (
             XpuCommunicator,
         )

-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
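Note on the `= None` defaults in the last hunk: a bare annotation such as `self.hpu_communicator: Optional[HpuCommunicator]` only records a type hint and does not create the attribute, so a later `is None`/`is not None` check could raise `AttributeError` whenever the communicator is never constructed. A minimal sketch of the difference (illustrative class names, not SGLang code):

```python
from typing import Optional


class Comm:
    """Stand-in for HpuCommunicator / XpuCommunicator."""


class WithoutDefault:
    def __init__(self, enabled: bool):
        self.comm: Optional[Comm]  # annotation only; no attribute is created
        if enabled:
            self.comm = Comm()


class WithDefault:
    def __init__(self, enabled: bool):
        self.comm: Optional[Comm] = None  # attribute always exists
        if enabled:
            self.comm = Comm()


print(hasattr(WithoutDefault(enabled=False), "comm"))  # False -> a later `is None` check raises
print(hasattr(WithDefault(enabled=False), "comm"))     # True  -> safe to test `is not None`
```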

python/sglang/srt/layers/dp_attention.py
Lines changed: 21 additions & 21 deletions

@@ -53,10 +53,8 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
-        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
@@ -65,7 +63,7 @@
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()
@@ -216,6 +213,22 @@ def dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input
@@ -236,16 +249,3 @@ def dp_scatter(
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
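The refactor replaces the magic `layer_id != "embedding"` test with an explicit `is_partial` flag: `dp_gather_partial` lets every attention-TP rank contribute its shard before the all-reduce, while `dp_gather_replicate` lets only attention-TP rank 0 contribute so that replicated inputs are not summed multiple times. A toy single-process sketch of that contribution rule (plain Python, not the real `torch.distributed` path):

```python
# Toy model of _dp_gather's predicate: a rank writes its tokens into the global
# buffer only when the data is partial or it is attention-TP rank 0; the
# subsequent all-reduce then sums whatever was written.
def toy_gather(local_values_per_rank, is_partial):
    total = 0
    for rank, value in enumerate(local_values_per_rank):
        if is_partial or rank == 0:  # same condition as in _dp_gather
            total += value
    return total


replicated = [5, 5, 5, 5]  # every attention-TP rank holds identical values
partial = [1, 2, 3, 4]     # every rank holds a partial sum (e.g. a sharded matmul output)

print(toy_gather(replicated, is_partial=False))  # 5  -> dp_gather_replicate keeps a single copy
print(toy_gather(partial, is_partial=True))      # 10 -> dp_gather_partial sums all shards
```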

python/sglang/srt/layers/logits_processor.py
Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -428,7 +428,7 @@ def _get_logits(
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
             logits = torch.matmul(

python/sglang/srt/models/deepseek_v2.py
Lines changed: 35 additions & 45 deletions

@@ -33,7 +33,7 @@
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -939,47 +939,58 @@ def forward(
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
             )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )

         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
+                if get_attention_tp_rank() == 0:
+                    hidden_states += residual
                 hidden_states, local_hidden_states = (
                     forward_batch.gathered_buffer,
                     hidden_states,
                 )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
         hidden_states = self.mlp(hidden_states)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
@@ -1025,18 +1036,6 @@ def forward(
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:

-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -1087,15 +1086,6 @@ def forward(

         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
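The subtle piece of the decoder-layer change is `if get_attention_tp_rank() == 0: hidden_states += residual`: the attention output fed to `dp_gather_partial` is a partial sum that the gather's all-reduce completes, so adding the residual on every attention-TP rank would count it multiple times. A toy numeric check (plain Python, single process, assuming two attention-TP ranks):

```python
# Each attention-TP rank holds a partial attention output; the all-reduce inside
# dp_gather_partial sums the ranks' contributions.
attn_tp_partial_outputs = [1.0, 2.0]  # true attention output = 3.0
residual = 10.0

# Add the residual on rank 0 only, then reduce (sum) across ranks.
correct = sum(
    out + (residual if rank == 0 else 0.0)
    for rank, out in enumerate(attn_tp_partial_outputs)
)

# Add the residual on every rank, then reduce.
wrong = sum(out + residual for out in attn_tp_partial_outputs)

print(correct)  # 13.0 = attention output (3.0) + residual (10.0), counted once
print(wrong)    # 23.0 -> residual counted attn_tp_size times
```

The model-level hunks follow from the same reordering: each layer now ends with its own scatter back to local tokens, so the model no longer gathers `input_ids` at the embedding or scatters hidden states before the logits processor.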

test/srt/test_dp_attention.py
Lines changed: 1 addition & 5 deletions

@@ -11,7 +11,7 @@
 )


-class TestDPAttention(unittest.TestCase):
+class TestDPAttentionDP2TP2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -59,7 +59,3 @@ def test_mgsm_en(self):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
-
-
-if __name__ == "__main__":
-    unittest.main()

0 commit comments