Commit 2f56761

llama-model : support Qwen2 embedding models and pooling_mode_lasttoken (ggml-org#13245)

1 parent 7d21234 commit 2f56761
File tree

3 files changed: +45 -28 lines

  convert_hf_to_gguf.py        +42 -28
  gguf-py/gguf/constants.py     +2 -0
  src/llama-model.cpp           +1 -0

convert_hf_to_gguf.py (+42 -28)
@@ -455,8 +455,12 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
 
 
 class TextModel(ModelBase):
+    model_type = ModelType.TEXT
+    hf_arch: str
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.hf_arch = get_model_architecture(self.hparams, self.model_type)
 
         if "text_config" in self.hparams:
             # move the text_config to the root level
@@ -1075,10 +1079,36 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
         if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
             self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
 
+    def _try_set_pooling_type(self) -> None:
+        # get pooling path
+        pooling_path = None
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
+
+        # get pooling type
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            elif pooling["pooling_mode_lasttoken"]:
+                pooling_type = gguf.PoolingType.LAST
+            else:
+                raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
+            self.gguf_writer.add_pooling_type(pooling_type)
+
 
 class VisionModel(ModelBase):
+    model_type = ModelType.VISION
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
-    n_text_embd = 0
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
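
For reference, `_try_set_pooling_type` follows the standard sentence-transformers on-disk layout: `modules.json` lists the pipeline modules, and the pooling module's own directory carries a `config.json` of boolean mode flags. A minimal sketch of the metadata it expects (example values only; the `1_Pooling` path is the common convention, not taken from this commit):

    # modules.json -- lists pipeline modules; the Pooling entry points at its config dir
    modules_json = [
        {"idx": 1, "name": "1", "path": "1_Pooling",
         "type": "sentence_transformers.models.Pooling"},
    ]
    # 1_Pooling/config.json -- exactly one mode flag is normally true
    pooling_config = {
        "pooling_mode_cls_token":   False,
        "pooling_mode_mean_tokens": False,
        "pooling_mode_lasttoken":   True,   # -> gguf.PoolingType.LAST
    }
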
@@ -2542,7 +2572,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@ModelBase.register("Qwen2ForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
 class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
@@ -2554,12 +2584,18 @@ def set_vocab(self):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        self._try_set_pooling_type()
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "yarn":
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
                 self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
                 self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.hf_arch == "Qwen2Model":
+            name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
+        yield from super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
 class Qwen2VLModel(TextModel):
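
The `modify_tensors` override covers checkpoints exported under the bare `Qwen2Model` architecture (embedding models without an LM head), whose tensor names lack the `model.` prefix that `Qwen2ForCausalLM` uses. A sketch of the renaming step (the tensor name is a typical Qwen2 example, not taken from this diff):

    # Bare Qwen2Model checkpoints store e.g. "layers.0.self_attn.q_proj.weight";
    # prefixing with "model." lets the existing Qwen2ForCausalLM tensor map apply.
    name = "layers.0.self_attn.q_proj.weight"
    name = f"model.{name}"
    assert name == "model.layers.0.self_attn.q_proj.weight"
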
@@ -3396,29 +3432,7 @@ def __init__(self, *args, **kwargs):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_causal_attention(False)
-
-        # get pooling path
-        pooling_path = None
-        module_path = self.dir_model / "modules.json"
-        if module_path.is_file():
-            with open(module_path, encoding="utf-8") as f:
-                modules = json.load(f)
-            for mod in modules:
-                if mod["type"] == "sentence_transformers.models.Pooling":
-                    pooling_path = mod["path"]
-                    break
-
-        # get pooling type
-        if pooling_path is not None:
-            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
-                pooling = json.load(f)
-            if pooling["pooling_mode_mean_tokens"]:
-                pooling_type = gguf.PoolingType.MEAN
-            elif pooling["pooling_mode_cls_token"]:
-                pooling_type = gguf.PoolingType.CLS
-            else:
-                raise NotImplementedError("Only MEAN and CLS pooling types supported")
-            self.gguf_writer.add_pooling_type(pooling_type)
+        self._try_set_pooling_type()
 
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
@@ -5962,8 +5976,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
-def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
-    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
     arch = hparams["architectures"][0]
@@ -6034,7 +6047,8 @@ def main() -> None:
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
-        model_architecture = get_model_architecture(dir_model, model_type)
+        hparams = ModelBase.load_hparams(dir_model)
+        model_architecture = get_model_architecture(hparams, model_type)
         logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
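
With this refactor, callers load `hparams` once and pass the dict around instead of having `get_model_architecture` re-read `config.json` from disk. A rough sketch of the new call pattern (the config values are illustrative, and the return value assumes a plain text checkpoint with no `text_config`/`vision_config` override):

    # Hypothetical minimal hparams for a bare Qwen2 embedding checkpoint:
    hparams = {"architectures": ["Qwen2Model"], "hidden_size": 1024}
    arch = get_model_architecture(hparams, ModelType.TEXT)
    # arch == "Qwen2Model"; TextModel.__init__ also stores this as self.hf_arch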

gguf-py/gguf/constants.py (+2 -0)
@@ -2033,6 +2033,8 @@ class PoolingType(IntEnum):
     NONE = 0
     MEAN = 1
     CLS = 2
+    LAST = 3
+    RANK = 4
 
 
 class GGMLQuantizationType(IntEnum):
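
The new values mirror `enum llama_pooling_type` in llama.h. `LAST` takes the final token's hidden state (sentence-transformers' `pooling_mode_lasttoken`), which suits causal models like Qwen2 where only the last position attends to the whole sequence; `RANK` is added alongside it for reranking-style models. A minimal sketch of the pooling modes, assuming row `i` of `hidden` holds the last-layer state of token `i` (illustrative only; llama.cpp performs this internally in C++):

    import numpy as np

    def pool(hidden: np.ndarray, pooling_type: int) -> np.ndarray:
        # hidden: (n_tokens, n_embd) last-layer activations for one sequence
        if pooling_type == 1:      # PoolingType.MEAN
            return hidden.mean(axis=0)
        if pooling_type == 2:      # PoolingType.CLS (first token)
            return hidden[0]
        if pooling_type == 3:      # PoolingType.LAST (last token)
            return hidden[-1]
        raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")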

src/llama-model.cpp (+1 -0)
@@ -773,6 +773,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             // fall through
         case LLM_ARCH_QWEN2:
             {
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
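
The third argument `false` marks the key as optional, so existing Qwen2 GGUFs without a pooling type still load unchanged. To check what a converted embedding model actually carries, a small sketch using gguf-py's reader (the file name is hypothetical; `qwen2.pooling_type` is `LLM_KV_POOLING_TYPE` rendered for the qwen2 architecture):

    from gguf import GGUFReader

    reader = GGUFReader("qwen2-embedding.gguf")  # hypothetical convert_hf_to_gguf.py output
    field = reader.get_field("qwen2.pooling_type")
    if field is not None:
        print(field.parts[-1].tolist()[0])  # e.g. 3 == PoolingType.LAST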
