11
11
import os
12
12
import re
13
13
import sys
14
+ from abc import ABC , abstractmethod
14
15
from enum import IntEnum
15
16
from pathlib import Path
16
17
from hashlib import sha256
@@ -51,7 +52,7 @@ class ModelType(IntEnum):
51
52
AnyModel = TypeVar ("AnyModel" , bound = "type[ModelBase]" )
52
53
53
54
54
- class ModelBase :
55
+ class ModelBase ( ABC ) :
55
56
_model_classes : dict [ModelType , dict [str , type [ModelBase ]]] = {
56
57
ModelType .TEXT : {},
57
58
ModelType .VISION : {},
@@ -81,25 +82,11 @@ class ModelBase:
81
82
block_count : int
82
83
tensor_map : gguf .TensorNameMap
83
84
84
- def __init__ (
85
- self ,
86
- dir_model : Path ,
87
- ftype : gguf .LlamaFileType ,
88
- fname_out : Path ,
89
- hf_arch : str ,
90
- * ,
91
- is_big_endian : bool = False ,
92
- use_temp_file : bool = False ,
93
- eager : bool = False ,
94
- metadata_override : Path | None = None ,
95
- model_name : str | None = None ,
96
- split_max_tensors : int = 0 ,
97
- split_max_size : int = 0 ,
98
- dry_run : bool = False ,
99
- small_first_shard : bool = False ,
100
- hparams : dict [str , Any ] | None = None ,
101
- remote_hf_model_id : str | None = None ,
102
- ):
85
+ def __init__ (self , dir_model : Path , ftype : gguf .LlamaFileType , fname_out : Path , * , is_big_endian : bool = False ,
86
+ use_temp_file : bool = False , eager : bool = False ,
87
+ metadata_override : Path | None = None , model_name : str | None = None ,
88
+ split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False ,
89
+ small_first_shard : bool = False , hparams : dict [str , Any ] | None = None , remote_hf_model_id : str | None = None ):
103
90
if type (self ) is ModelBase or \
104
91
type (self ) is TextModel or \
105
92
type (self ) is VisionModel :
@@ -108,7 +95,6 @@ def __init__(
108
95
self .dir_model = dir_model
109
96
self .ftype = ftype
110
97
self .fname_out = fname_out
111
- self .hf_arch = hf_arch
112
98
self .is_big_endian = is_big_endian
113
99
self .endianess = gguf .GGUFEndian .BIG if is_big_endian else gguf .GGUFEndian .LITTLE
114
100
self .use_temp_file = use_temp_file
@@ -151,6 +137,11 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
151
137
self .gguf_writer = gguf .GGUFWriter (path = None , arch = gguf .MODEL_ARCH_NAMES [self .model_arch ], endianess = self .endianess , use_temp_file = self .use_temp_file ,
152
138
split_max_tensors = split_max_tensors , split_max_size = split_max_size , dry_run = dry_run , small_first_shard = small_first_shard )
153
139
140
+ @property
141
+ @abstractmethod
142
+ def model_type (self ):
143
+ raise NotImplementedError
144
+
154
145
@classmethod
155
146
def add_prefix_to_filename (cls , path : Path , prefix : str ) -> Path :
156
147
stem , suffix = path .stem , path .suffix
@@ -468,8 +459,11 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
468
459
469
460
470
461
class TextModel (ModelBase ):
462
+ model_type = ModelType .TEXT
463
+
471
464
def __init__ (self , * args , ** kwargs ):
472
465
super ().__init__ (* args , ** kwargs )
466
+ self .hf_arch = get_model_architecture (self .hparams , self .model_type )
473
467
474
468
if "text_config" in self .hparams :
475
469
# move the text_config to the root level
@@ -1116,8 +1110,8 @@ def _try_set_pooling_type(self) -> None:
1116
1110
1117
1111
1118
1112
class VisionModel (ModelBase ):
1113
+ model_type = ModelType .VISION
1119
1114
model_arch = gguf .MODEL_ARCH .CLIP_VISION
1120
- n_text_embd = 0
1121
1115
preprocessor_config : dict [str , Any ]
1122
1116
global_config : dict [str , Any ]
1123
1117
@@ -3558,15 +3552,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
3558
3552
class NomicBertModel (BertModel ):
3559
3553
model_arch = gguf .MODEL_ARCH .BERT
3560
3554
3561
- def __init__ (self , dir_model : Path , ftype : gguf .LlamaFileType , fname_out : Path , hf_arch : str , ** kwargs : Any ):
3555
+ def __init__ (self , dir_model : Path , ftype : gguf .LlamaFileType , fname_out : Path , ** kwargs : Any ):
3562
3556
hparams = kwargs .pop ("hparams" , None )
3563
3557
if hparams is None :
3564
3558
hparams = ModelBase .load_hparams (dir_model )
3565
3559
3566
3560
self .is_moe = bool (hparams .get ("moe_every_n_layers" ))
3567
3561
self .model_arch = gguf .MODEL_ARCH .NOMIC_BERT_MOE if self .is_moe else gguf .MODEL_ARCH .NOMIC_BERT
3568
3562
3569
- super ().__init__ (dir_model , ftype , fname_out , hf_arch , hparams = hparams , ** kwargs )
3563
+ super ().__init__ (dir_model , ftype , fname_out , hparams = hparams , ** kwargs )
3570
3564
3571
3565
self ._tokenizer_is_xlmroberta = self ._is_tokenizer_xlmroberta ()
3572
3566
if self ._tokenizer_is_xlmroberta :
@@ -5902,8 +5896,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
5902
5896
return n
5903
5897
5904
5898
5905
- def get_model_architecture (dir_model : Path , model_type : ModelType , hparams : Any = None ) -> str :
5906
- hparams = ModelBase .load_hparams (dir_model ) if hparams is None else hparams
5899
+ def get_model_architecture (hparams : dict [str , Any ], model_type : ModelType ) -> str :
5907
5900
text_config = hparams .get ("text_config" , {})
5908
5901
vision_config = hparams .get ("vision_config" , {})
5909
5902
arch = hparams ["architectures" ][0 ]
@@ -5974,15 +5967,16 @@ def main() -> None:
5974
5967
with torch .inference_mode ():
5975
5968
output_type = ftype_map [args .outtype ]
5976
5969
model_type = ModelType .VISION if args .mmproj else ModelType .TEXT
5977
- model_architecture = get_model_architecture (dir_model , model_type )
5970
+ hparams = ModelBase .load_hparams (dir_model )
5971
+ model_architecture = get_model_architecture (hparams , model_type )
5978
5972
logger .info (f"Model architecture: { model_architecture } " )
5979
5973
try :
5980
5974
model_class = ModelBase .from_model_architecture (model_architecture , model_type = model_type )
5981
5975
except NotImplementedError :
5982
5976
logger .error (f"Model { model_architecture } is not supported" )
5983
5977
sys .exit (1 )
5984
5978
5985
- model_instance = model_class (dir_model , output_type , fname_out , model_architecture ,
5979
+ model_instance = model_class (dir_model , output_type , fname_out ,
5986
5980
is_big_endian = args .bigendian , use_temp_file = args .use_temp_file ,
5987
5981
eager = args .no_lazy ,
5988
5982
metadata_override = args .metadata , model_name = args .model_name ,
0 commit comments