Skip to content

Commit 072df75

Browse files
authored
Support for Qwen2.5-VL Model in bitsandbytes Format (#5003)
1 parent defede5 commit 072df75

File tree

6 files changed

+375
-45
lines changed

6 files changed

+375
-45
lines changed

.github/workflows/vllm-dependency-test.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ jobs:
3333
run: |
3434
bash scripts/ci_install_dependency.sh
3535
pip install "vllm>=0.6.4.post1,<=0.7.2"
36+
pip install "bitsandbytes>=0.44.0"
3637
3738
- name: Run VLLM dependency tests
3839
timeout-minutes: 60

python/sglang/srt/model_loader/loader.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1071,6 +1071,7 @@ def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
10711071

10721072
param_dict = dict(model.named_parameters())
10731073
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
1074+
model_type = model_config.hf_config.model_type
10741075
for quant_param_name in quant_state_dict:
10751076
non_stacked_param_name = quant_param_name
10761077

@@ -1079,11 +1080,24 @@ def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
10791080
weight_name,
10801081
index,
10811082
) in model.bitsandbytes_stacked_params_mapping.items():
1083+
if (
1084+
model_type in ["qwen2_vl", "qwen2_5_vl"]
1085+
and "visual" in quant_param_name
1086+
):
1087+
break
10821088
if shard_name in quant_param_name:
10831089
shard_index = index
10841090
quant_param_name = quant_param_name.replace(shard_name, weight_name)
10851091
break
10861092

1093+
if (
1094+
model_type in ["qwen2_vl", "qwen2_5_vl"]
1095+
and "visual" in quant_param_name
1096+
):
1097+
quant_param_name = quant_param_name.replace(
1098+
r"attn.qkv.", r"attn.qkv_proj."
1099+
)
1100+
10871101
if quant_param_name not in param_dict:
10881102
raise ValueError(
10891103
f"Parameter {quant_param_name} not found in the model."
@@ -1111,6 +1125,8 @@ def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
11111125
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
11121126

11131127
offsets = np.concatenate(([0], np.cumsum(num_elements)))
1128+
# Make torch infer_schema happy (compatible with vLLM)
1129+
offsets = torch.tensor(offsets).cpu()
11141130
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
11151131

11161132
if load_8bit:

python/sglang/srt/models/qwen2_5_vl.py

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -141,7 +141,7 @@ def __init__(
141141
embed_dim=dim,
142142
num_heads=num_heads,
143143
projection_size=dim,
144-
use_qkv_parallel=False,
144+
use_qkv_parallel=True,
145145
use_context_forward=use_context_forward,
146146
softmax_in_single_precision=softmax_in_single_precision,
147147
flatten_batch=flatten_batch,
@@ -325,7 +325,7 @@ def get_window_index(self, grid_thw):
325325

326326
@property
327327
def dtype(self) -> torch.dtype:
328-
return self.blocks[0].mlp.gate_proj.weight.dtype
328+
return self.patch_embed.proj.weight.dtype
329329

330330
@property
331331
def device(self) -> torch.device:
@@ -429,6 +429,25 @@ def forward(
429429

430430

431431
class Qwen2_5_VLForConditionalGeneration(nn.Module):
432+
# BitsAndBytes-specific attributes
433+
default_bitsandbytes_target_modules = [
434+
".gate_proj.",
435+
".down_proj.",
436+
".up_proj.",
437+
".q_proj.",
438+
".k_proj.",
439+
".v_proj.",
440+
".o_proj.",
441+
]
442+
bitsandbytes_stacked_params_mapping = {
443+
# shard_name, weight_name, index
444+
"q_proj": ("qkv_proj", 0),
445+
"k_proj": ("qkv_proj", 1),
446+
"v_proj": ("qkv_proj", 2),
447+
"gate_proj": ("gate_up_proj", 0),
448+
"up_proj": ("gate_up_proj", 1),
449+
}
450+
432451
def __init__(
433452
self,
434453
config: Qwen2_5_VLConfig,
@@ -441,9 +460,9 @@ def __init__(
441460
self.visual = Qwen2_5_VisionTransformer(
442461
config.vision_config,
443462
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
444-
# NOTE: Qwen2-VL vision encoder does not support any
445-
# quantization method now.
446-
quant_config=None,
463+
# NOTE: The Qwen2.5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
464+
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
465+
quant_config=quant_config,
447466
prefix=add_prefix("visual", prefix),
448467
)
449468

@@ -573,23 +592,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
573592
weight_loader(param, loaded_weight, shard_id)
574593
break
575594
else:
576-
if "visual" in name and "qkv.weight" in name:
577-
visual_num_heads = self.config.vision_config.num_heads
578-
visual_embed_dim = self.config.vision_config.hidden_size
579-
head_size = visual_embed_dim // visual_num_heads
580-
loaded_weight = loaded_weight.view(
581-
3, visual_num_heads, head_size, visual_embed_dim
582-
)
583-
loaded_weight = loaded_weight.transpose(0, 1)
584-
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
585-
elif "visual" in name and "qkv.bias" in name:
586-
visual_num_heads = self.config.vision_config.num_heads
587-
visual_embed_dim = self.config.vision_config.hidden_size
588-
head_size = visual_embed_dim // visual_num_heads
589-
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
590-
loaded_weight = loaded_weight.transpose(0, 1)
591-
loaded_weight = loaded_weight.reshape(-1)
592-
593595
if "visual" in name:
594596
# adapt to VisionAttention
595597
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

python/sglang/srt/models/qwen2_vl.py

Lines changed: 24 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -152,7 +152,7 @@ def __init__(
152152
embed_dim=dim,
153153
num_heads=num_heads,
154154
projection_size=dim,
155-
use_qkv_parallel=False,
155+
use_qkv_parallel=True,
156156
use_context_forward=use_context_forward,
157157
softmax_in_single_precision=softmax_in_single_precision,
158158
flatten_batch=True,
@@ -351,7 +351,7 @@ def __init__(
351351

352352
@property
353353
def dtype(self) -> torch.dtype:
354-
return next(self.parameters()).dtype
354+
return self.patch_embed.proj.weight.dtype
355355

356356
@property
357357
def device(self) -> torch.device:
@@ -423,6 +423,25 @@ def forward(
423423

424424

425425
class Qwen2VLForConditionalGeneration(nn.Module):
426+
# BitsAndBytes-specific attributes
427+
default_bitsandbytes_target_modules = [
428+
".gate_proj.",
429+
".down_proj.",
430+
".up_proj.",
431+
".q_proj.",
432+
".k_proj.",
433+
".v_proj.",
434+
".o_proj.",
435+
]
436+
bitsandbytes_stacked_params_mapping = {
437+
# shard_name, weight_name, index
438+
"q_proj": ("qkv_proj", 0),
439+
"k_proj": ("qkv_proj", 1),
440+
"v_proj": ("qkv_proj", 2),
441+
"gate_proj": ("gate_up_proj", 0),
442+
"up_proj": ("gate_up_proj", 1),
443+
}
444+
426445
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
427446
processor = cached_get_processor(self.config._name_or_path)
428447
grid_t, grid_h, grid_w = image_grid_thw
@@ -447,9 +466,9 @@ def __init__(
447466
self.visual = Qwen2VisionTransformer(
448467
config.vision_config,
449468
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
450-
# NOTE: Qwen2-VL vision encoder does not support any
451-
# quantization method now.
452-
quant_config=None,
469+
# NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
470+
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
471+
quant_config=quant_config,
453472
prefix=add_prefix("visual", prefix),
454473
)
455474

@@ -578,24 +597,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
578597
weight_loader(param, loaded_weight, shard_id)
579598
break
580599
else:
581-
582-
if "visual" in name and "qkv.weight" in name:
583-
visual_num_heads = self.config.vision_config.num_heads
584-
visual_embed_dim = self.config.vision_config.embed_dim
585-
head_size = visual_embed_dim // visual_num_heads
586-
loaded_weight = loaded_weight.view(
587-
3, visual_num_heads, head_size, visual_embed_dim
588-
)
589-
loaded_weight = loaded_weight.transpose(0, 1)
590-
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
591-
elif "visual" in name and "qkv.bias" in name:
592-
visual_num_heads = self.config.vision_config.num_heads
593-
visual_embed_dim = self.config.vision_config.embed_dim
594-
head_size = visual_embed_dim // visual_num_heads
595-
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
596-
loaded_weight = loaded_weight.transpose(0, 1)
597-
loaded_weight = loaded_weight.reshape(-1)
598-
599600
if "visual" in name:
600601
# adapt to VisionAttention
601602
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,7 @@ class TestFile:
9797
TestFile("test_awq.py"),
9898
TestFile("test_gguf.py", 78),
9999
TestFile("test_gptqmodel_dynamic.py", 72),
100+
TestFile("test_bnb.py"),
100101
],
101102
}
102103

0 commit comments

Comments
 (0)