@@ -455,8 +455,12 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
455
455
456
456
457
457
class TextModel (ModelBase ):
458
+ model_type = ModelType .TEXT
459
+ hf_arch : str
460
+
458
461
def __init__ (self , * args , ** kwargs ):
459
462
super ().__init__ (* args , ** kwargs )
463
+ self .hf_arch = get_model_architecture (self .hparams , self .model_type )
460
464
461
465
if "text_config" in self .hparams :
462
466
# move the text_config to the root level
@@ -1075,10 +1079,36 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
1075
1079
if (field := vocab_reader .get_field (gguf .Keys .Tokenizer .ADD_EOS )) is not None :
1076
1080
self .gguf_writer .add_add_eos_token (field .parts [- 1 ].tolist ()[0 ])
1077
1081
1082
def _try_set_pooling_type(self) -> None:
    """Detect the sentence-transformers pooling mode and record it in the GGUF metadata.

    Looks for a ``modules.json`` next to the model (present in
    sentence-transformers checkpoints), finds the Pooling module's path, reads
    its ``config.json`` and maps the pooling flags onto a gguf.PoolingType.
    Silently does nothing when no pooling module is described.

    Raises:
        NotImplementedError: if a pooling config exists but requests a mode
            other than MEAN, CLS, or LAST.
    """
    # get pooling path
    pooling_path = None
    module_path = self.dir_model / "modules.json"
    if module_path.is_file():
        with open(module_path, encoding="utf-8") as f:
            modules = json.load(f)
        for mod in modules:
            if mod["type"] == "sentence_transformers.models.Pooling":
                pooling_path = mod["path"]
                break

    # get pooling type
    if pooling_path is not None:
        with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
            pooling = json.load(f)
        # use .get(): older Pooling configs do not contain every flag
        # (e.g. "pooling_mode_lasttoken" was added later) — a missing key
        # must fall through to the explicit NotImplementedError below,
        # not blow up with a bare KeyError
        if pooling.get("pooling_mode_mean_tokens"):
            pooling_type = gguf.PoolingType.MEAN
        elif pooling.get("pooling_mode_cls_token"):
            pooling_type = gguf.PoolingType.CLS
        elif pooling.get("pooling_mode_lasttoken"):
            pooling_type = gguf.PoolingType.LAST
        else:
            raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
        self.gguf_writer.add_pooling_type(pooling_type)
1078
1108
1079
1109
class VisionModel (ModelBase ):
1110
+ model_type = ModelType .VISION
1080
1111
model_arch = gguf .MODEL_ARCH .CLIP_VISION
1081
- n_text_embd = 0
1082
1112
preprocessor_config : dict [str , Any ]
1083
1113
global_config : dict [str , Any ]
1084
1114
@@ -2542,7 +2572,7 @@ def set_gguf_parameters(self):
2542
2572
self .gguf_writer .add_file_type (self .ftype )
2543
2573
2544
2574
2545
- @ModelBase .register ("Qwen2ForCausalLM" )
2575
+ @ModelBase .register ("Qwen2Model" , "Qwen2ForCausalLM" )
2546
2576
class Qwen2Model (TextModel ):
2547
2577
model_arch = gguf .MODEL_ARCH .QWEN2
2548
2578
@@ -2554,12 +2584,18 @@ def set_vocab(self):
2554
2584
2555
2585
def set_gguf_parameters(self):
    """Write Qwen2 hyperparameters, pooling type, and YARN rope-scaling metadata."""
    super().set_gguf_parameters()
    # embedding variants ship a sentence-transformers pooling config; record it if present
    self._try_set_pooling_type()
    # hoist the rope_scaling dict so we don't re-index hparams on every check
    rope_scaling = self.hparams.get("rope_scaling")
    if rope_scaling is not None and "factor" in rope_scaling:
        if rope_scaling.get("type") == "yarn":
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
2562
2593
2594
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Remap bare Qwen2Model tensor names onto the Qwen2ForCausalLM layout.

    Checkpoints exported as plain "Qwen2Model" lack the "model." prefix on
    their tensor names; prepend it so the shared mapping logic applies.
    """
    mapped = f"model.{name}" if self.hf_arch == "Qwen2Model" else name
    yield from super().modify_tensors(data_torch, mapped, bid)
2563
2599
2564
2600
@ModelBase .register ("Qwen2VLForConditionalGeneration" , "Qwen2_5_VLForConditionalGeneration" )
2565
2601
class Qwen2VLModel (TextModel ):
@@ -3396,29 +3432,7 @@ def __init__(self, *args, **kwargs):
3396
3432
def set_gguf_parameters(self):
    """Write encoder hyperparameters: non-causal attention plus any pooling type."""
    super().set_gguf_parameters()
    # bidirectional encoder — attention is not causally masked
    self.gguf_writer.add_causal_attention(False)
    # pooling type comes from the sentence-transformers module config, when present
    self._try_set_pooling_type()
3422
3436
3423
3437
def set_vocab (self ):
3424
3438
tokens , toktypes , tokpre = self .get_vocab_base ()
@@ -5962,8 +5976,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
5962
5976
return n
5963
5977
5964
5978
5965
- def get_model_architecture (dir_model : Path , model_type : ModelType , hparams : Any = None ) -> str :
5966
- hparams = ModelBase .load_hparams (dir_model ) if hparams is None else hparams
5979
+ def get_model_architecture (hparams : dict [str , Any ], model_type : ModelType ) -> str :
5967
5980
text_config = hparams .get ("text_config" , {})
5968
5981
vision_config = hparams .get ("vision_config" , {})
5969
5982
arch = hparams ["architectures" ][0 ]
@@ -6034,7 +6047,8 @@ def main() -> None:
6034
6047
with torch .inference_mode ():
6035
6048
output_type = ftype_map [args .outtype ]
6036
6049
model_type = ModelType .VISION if args .mmproj else ModelType .TEXT
6037
- model_architecture = get_model_architecture (dir_model , model_type )
6050
+ hparams = ModelBase .load_hparams (dir_model )
6051
+ model_architecture = get_model_architecture (hparams , model_type )
6038
6052
logger .info (f"Model architecture: { model_architecture } " )
6039
6053
try :
6040
6054
model_class = ModelBase .from_model_architecture (model_architecture , model_type = model_type )
0 commit comments