Commit 33adf70

yubofredwang authored and jimoosciuc committed
Add unit tests for page_size > 1 and MLA, and an integration test for Flash Attention 3 (sgl-project#4760)
1 parent 6f05ba6 commit 33adf70

File tree

6 files changed: +739 / -230 lines

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -548,8 +548,9 @@ def forward_extend(
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
@@ -592,7 +593,6 @@ def forward_extend(
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
@@ -659,8 +659,10 @@ def forward_decode(

         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
```

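For context on the flashattention_backend.py change above: get_kv_buffer returns the key and value buffers as a pair, which is now unpacked directly, and each buffer is then viewed with an explicit page_size dimension. A minimal sketch of that paged view, using made-up sizes rather than the real KV pool:

```python
import torch

# Made-up sizes for illustration only; the real buffers are allocated by sglang's KV pool.
page_size, num_pages = 4, 8
num_kv_heads, head_dim = 2, 64

# A flat per-slot cache: one row per KV slot (num_pages * page_size slots in total).
key_cache = torch.zeros(num_pages * page_size, num_kv_heads, head_dim)
value_cache = torch.zeros(num_pages * page_size, num_kv_heads, head_dim)

# The backend views the slot dimension as (num_pages, page_size) so the attention
# kernel can address whole pages through a page table; with page_size == 1 this
# degenerates to one token per page.
key_cache = key_cache.view(-1, page_size, num_kv_heads, head_dim)
value_cache = value_cache.view(-1, page_size, num_kv_heads, head_dim)

assert key_cache.shape == (num_pages, page_size, num_kv_heads, head_dim)
```
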
python/sglang/srt/layers/quantization/__init__.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -63,10 +63,6 @@ def override_quantization_method(self, *args, **kwargs):
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -176,6 +172,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
```

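The in-function import added to get_linear_quant_method is the standard Python way to break a circular dependency: the import executes only when the function is first called, by which time both modules are fully initialized. A generic sketch of the pattern, with module and class names invented for illustration (not the real sglang modules):

```python
# quant.py (illustrative)
def get_linear_quant_method():
    # Deferred import: resolved at call time, so importing quant.py succeeds even
    # though embedding.py imports quant.py at module level.
    from embedding import ParallelLMHead
    return ParallelLMHead()


# embedding.py (illustrative)
# A top-level `from quant import ...` here, combined with a top-level
# `from embedding import ...` in quant.py, would fail with
# "cannot import name ... from partially initialized module".
import quant


class ParallelLMHead:
    pass
```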