Skip to content

Commit ce73722

Browse files
committed
set hf_arch in TextModel.__init__
1 parent 11683f5 commit ce73722

File tree

1 file changed

+22
-28
lines changed

convert_hf_to_gguf.py

+22-28
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import re
1313
import sys
14+
from abc import ABC, abstractmethod
1415
from enum import IntEnum
1516
from pathlib import Path
1617
from hashlib import sha256
@@ -51,7 +52,7 @@ class ModelType(IntEnum):
5152
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
5253

5354

54-
class ModelBase:
55+
class ModelBase(ABC):
5556
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
5657
ModelType.TEXT: {},
5758
ModelType.VISION: {},
@@ -81,25 +82,11 @@ class ModelBase:
8182
block_count: int
8283
tensor_map: gguf.TensorNameMap
8384

84-
def __init__(
85-
self,
86-
dir_model : Path,
87-
ftype : gguf.LlamaFileType,
88-
fname_out : Path,
89-
hf_arch : str,
90-
*,
91-
is_big_endian : bool = False,
92-
use_temp_file : bool = False,
93-
eager : bool = False,
94-
metadata_override : Path | None = None,
95-
model_name : str | None = None,
96-
split_max_tensors : int = 0,
97-
split_max_size : int = 0,
98-
dry_run : bool = False,
99-
small_first_shard : bool = False,
100-
hparams : dict[str, Any] | None = None,
101-
remote_hf_model_id : str | None = None,
102-
):
85+
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
86+
use_temp_file: bool = False, eager: bool = False,
87+
metadata_override: Path | None = None, model_name: str | None = None,
88+
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
89+
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
10390
if type(self) is ModelBase or \
10491
type(self) is TextModel or \
10592
type(self) is VisionModel:
@@ -108,7 +95,6 @@ def __init__(
10895
self.dir_model = dir_model
10996
self.ftype = ftype
11097
self.fname_out = fname_out
111-
self.hf_arch = hf_arch
11298
self.is_big_endian = is_big_endian
11399
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
114100
self.use_temp_file = use_temp_file
@@ -151,6 +137,11 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
151137
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
152138
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
153139

140+
@property
141+
@abstractmethod
142+
def model_type(self):
143+
raise NotImplementedError
144+
154145
@classmethod
155146
def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
156147
stem, suffix = path.stem, path.suffix
@@ -468,8 +459,11 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
468459

469460

470461
class TextModel(ModelBase):
462+
model_type = ModelType.TEXT
463+
471464
def __init__(self, *args, **kwargs):
472465
super().__init__(*args, **kwargs)
466+
self.hf_arch = get_model_architecture(self.hparams, self.model_type)
473467

474468
if "text_config" in self.hparams:
475469
# move the text_config to the root level
@@ -1116,8 +1110,8 @@ def _try_set_pooling_type(self) -> None:
11161110

11171111

11181112
class VisionModel(ModelBase):
1113+
model_type = ModelType.VISION
11191114
model_arch = gguf.MODEL_ARCH.CLIP_VISION
1120-
n_text_embd = 0
11211115
preprocessor_config: dict[str, Any]
11221116
global_config: dict[str, Any]
11231117

@@ -3558,15 +3552,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
35583552
class NomicBertModel(BertModel):
35593553
model_arch = gguf.MODEL_ARCH.BERT
35603554

3561-
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, hf_arch: str, **kwargs: Any):
3555+
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
35623556
hparams = kwargs.pop("hparams", None)
35633557
if hparams is None:
35643558
hparams = ModelBase.load_hparams(dir_model)
35653559

35663560
self.is_moe = bool(hparams.get("moe_every_n_layers"))
35673561
self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
35683562

3569-
super().__init__(dir_model, ftype, fname_out, hf_arch, hparams=hparams, **kwargs)
3563+
super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
35703564

35713565
self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta()
35723566
if self._tokenizer_is_xlmroberta:
@@ -5902,8 +5896,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
59025896
return n
59035897

59045898

5905-
def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
5906-
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
5899+
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
59075900
text_config = hparams.get("text_config", {})
59085901
vision_config = hparams.get("vision_config", {})
59095902
arch = hparams["architectures"][0]
@@ -5974,15 +5967,16 @@ def main() -> None:
59745967
with torch.inference_mode():
59755968
output_type = ftype_map[args.outtype]
59765969
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
5977-
model_architecture = get_model_architecture(dir_model, model_type)
5970+
hparams = ModelBase.load_hparams(dir_model)
5971+
model_architecture = get_model_architecture(hparams, model_type)
59785972
logger.info(f"Model architecture: {model_architecture}")
59795973
try:
59805974
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
59815975
except NotImplementedError:
59825976
logger.error(f"Model {model_architecture} is not supported")
59835977
sys.exit(1)
59845978

5985-
model_instance = model_class(dir_model, output_type, fname_out, model_architecture,
5979+
model_instance = model_class(dir_model, output_type, fname_out,
59865980
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
59875981
eager=args.no_lazy,
59885982
metadata_override=args.metadata, model_name=args.model_name,

Comments (0)