Commit 5f5e39e

manyoso and cebtenzzre authored
model : Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture (#12466)
* Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture

  - Adds MoE-based embedding model supporting multilingual embeddings.
  - Selects architecture variant based on hyperparameter detection (MoE layers).
  - Removes unnecessary subclass initialization checks for clarity.

  https://www.nomic.ai/blog/posts/nomic-embed-text-v2

  Co-authored-by: Jared Van Bortel <[email protected]>

* fix tokenizer

* don't rename this tensor

---------

Co-authored-by: Jared Van Bortel <[email protected]>

1 parent eaea325 commit 5f5e39e

9 files changed, +247 -110 lines changed
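
The headline change is that the converter now picks the GGUF architecture from the checkpoint's hyperparameters instead of hard-coding it per class. A minimal sketch of that detection, assuming a local HF checkpoint directory and the gguf-py package from this commit (the standalone helper below is hypothetical; in the diff the logic lives in NomicBertModel.__init__):

import json
from pathlib import Path

import gguf


def detect_nomic_bert_arch(dir_model: Path) -> gguf.MODEL_ARCH:
    # Hypothetical helper: reads config.json directly for illustration;
    # the converter itself goes through ModelBase.load_hparams().
    with open(dir_model / "config.json", encoding="utf-8") as f:
        hparams = json.load(f)
    is_moe = bool(hparams.get("moe_every_n_layers"))
    return gguf.MODEL_ARCH.NOMIC_BERT_MOE if is_moe else gguf.MODEL_ARCH.NOMIC_BERT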

convert_hf_to_gguf.py  (+135, -92)
@@ -78,7 +78,7 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,

@@ -454,13 +454,6 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
-    @classmethod
-    def __init_subclass__(cls):
-        # can't use an abstract property, because overriding it without type errors
-        # would require using decorated functions instead of simply defining the property
-        if "model_arch" not in cls.__dict__:
-            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
-
     def set_vocab(self):
         self._set_vocab_gpt2()

@@ -3373,14 +3366,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         return [(self.map_tensor_name(name), data_torch)]

-
-@ModelBase.register("RobertaModel")
-class RobertaModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
+    def _xlmroberta_tokenizer_init(self) -> None:
         # we need the pad_token_id to know how to chop down position_embd matrix
         if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
             self._position_offset = 1 + pad_token_id

@@ -3389,82 +3375,7 @@ def __init__(self, *args, **kwargs):
         else:
             self._position_offset = None

-    def set_vocab(self):
-        """Support BPE tokenizers for roberta models"""
-        bpe_tok_path = self.dir_model / "tokenizer.json"
-        if bpe_tok_path.exists():
-            self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)
-
-            # we need this to validate the size of the token_type embeddings
-            # though currently we are passing all zeros to the token_type embeddings
-            # "Sequence A" or "Sequence B"
-            self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
-
-        else:
-            return super().set_vocab()
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # if name starts with "roberta.", remove the prefix
-        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
-        if name.startswith("roberta."):
-            name = name[8:]
-
-        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
-        if name == "embeddings.position_embeddings.weight":
-            if self._position_offset is not None:
-                data_torch = data_torch[self._position_offset:,:]
-
-        return super().modify_tensors(data_torch, name, bid)
-
-
-@ModelBase.register("NomicBertModel")
-class NomicBertModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.NOMIC_BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # the HF config claims n_ctx=8192, but it uses RoPE scaling
-        self.hparams["n_ctx"] = 2048
-
-        # SwigLU activation
-        assert self.hparams["activation_function"] == "swiglu"
-        # this doesn't do anything in the HF version
-        assert self.hparams["causal"] is False
-        # no bias tensors
-        assert self.hparams["qkv_proj_bias"] is False
-        assert self.hparams["mlp_fc1_bias"] is False
-        assert self.hparams["mlp_fc2_bias"] is False
-        # norm at end of layer
-        assert self.hparams["prenorm"] is False
-        # standard RoPE
-        assert self.hparams["rotary_emb_fraction"] == 1.0
-        assert self.hparams["rotary_emb_interleaved"] is False
-        assert self.hparams["rotary_emb_scale_base"] is None
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
-
-
-@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
-class XLMRobertaModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # we need the pad_token_id to know how to chop down position_embd matrix
-        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
-            self._position_offset = 1 + pad_token_id
-            if "max_position_embeddings" in self.hparams:
-                self.hparams["max_position_embeddings"] -= self._position_offset
-        else:
-            self._position_offset = None
-
-    def set_vocab(self):
+    def _xlmroberta_set_vocab(self) -> None:
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

@@ -3546,6 +3457,138 @@ def set_vocab(self):
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)

+
+@ModelBase.register("RobertaModel")
+class RobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # we need the pad_token_id to know how to chop down position_embd matrix
+        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
+            self._position_offset = 1 + pad_token_id
+            if "max_position_embeddings" in self.hparams:
+                self.hparams["max_position_embeddings"] -= self._position_offset
+        else:
+            self._position_offset = None
+
+    def set_vocab(self):
+        """Support BPE tokenizers for roberta models"""
+        bpe_tok_path = self.dir_model / "tokenizer.json"
+        if bpe_tok_path.exists():
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_add_bos_token(True)
+            self.gguf_writer.add_add_eos_token(True)
+
+            # we need this to validate the size of the token_type embeddings
+            # though currently we are passing all zeros to the token_type embeddings
+            # "Sequence A" or "Sequence B"
+            self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
+
+        else:
+            return super().set_vocab()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "roberta.", remove the prefix
+        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
+        if name.startswith("roberta."):
+            name = name[8:]
+
+        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
+        if name == "embeddings.position_embeddings.weight":
+            if self._position_offset is not None:
+                data_torch = data_torch[self._position_offset:,:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("NomicBertModel")
+class NomicBertModel(BertModel):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model)
+
+        self.is_moe = bool(hparams.get("moe_every_n_layers"))
+        self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
+
+        self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta()
+        if self._tokenizer_is_xlmroberta:
+            self._xlmroberta_tokenizer_init()
+
+        # the HF config claims n_ctx=8192, but it uses RoPE scaling
+        self.hparams["n_ctx"] = 2048
+
+        assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu"
+
+        # this doesn't do anything in the HF version
+        assert self.hparams["causal"] is False
+        # no bias tensors unless MoE
+        assert self.hparams["qkv_proj_bias"] == self.is_moe
+        assert self.hparams["mlp_fc1_bias"] == self.is_moe
+        assert self.hparams["mlp_fc2_bias"] == self.is_moe
+
+        # norm at end of layer
+        assert self.hparams["prenorm"] is False
+        # standard RoPE
+        assert self.hparams["rotary_emb_fraction"] == 1.0
+        assert self.hparams["rotary_emb_interleaved"] is False
+        assert self.hparams["rotary_emb_scale_base"] is None
+
+    def set_vocab(self) -> None:
+        if self._tokenizer_is_xlmroberta:
+            return self._xlmroberta_set_vocab()
+        return super().set_vocab()
+
+    def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]:
+        # If the tensor is an experts bias tensor, skip it by returning an empty list.
+        if "mlp.experts.bias" in name:
+            return []  # Explicitly return an empty list.
+
+        if "mlp.experts.mlp.w1" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            name += ".weight"
+
+        if "mlp.experts.mlp.w2" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.transpose(1, 2)
+            name += ".weight"
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+        if self.is_moe:
+            self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
+            self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+            self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
+
+    def _is_tokenizer_xlmroberta(self) -> bool:
+        with open(self.dir_model / "tokenizer.json") as f:
+            tokenizer_json = json.load(f)
+        toktyp = tokenizer_json["model"]["type"]
+        if toktyp == "Unigram":
+            return True
+        if toktyp == "WordPiece":
+            return False
+        raise ValueError(f"unknown tokenizer: {toktyp}")
+
+
+@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
+class XLMRobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._xlmroberta_tokenizer_init()
+
+    def set_vocab(self):
+        self._xlmroberta_set_vocab()
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # if name starts with "roberta.", remove the prefix
         # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
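
In modify_tensors above, the merged expert weights are reshaped into one 3D tensor per projection: w1 (mapped to FFN_UP_EXP) keeps the (experts, n_inner, n_embd) layout, while w2 (FFN_DOWN_EXP) is additionally transposed. A toy illustration of that reshape with made-up dimensions; the flat input shape here is an assumption, since only the total element count matters to view():

import torch

num_experts, n_embd, n_inner = 8, 16, 32  # toy sizes, not the real model's

# assume the HF checkpoint stores each expert's weights flattened together
w1 = torch.randn(num_experts * n_inner, n_embd)  # up-projection   -> ffn_up_exps
w2 = torch.randn(num_experts * n_inner, n_embd)  # down-projection -> ffn_down_exps

ffn_up_exps   = w1.view(num_experts, n_inner, n_embd)
ffn_down_exps = w2.view(num_experts, n_inner, n_embd).transpose(1, 2)

print(ffn_up_exps.shape)    # torch.Size([8, 32, 16])
print(ffn_down_exps.shape)  # torch.Size([8, 16, 32])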

gguf-py/gguf/constants.py  (+19)
@@ -104,6 +104,7 @@ class LLM:
         EXPERT_WEIGHTS_SCALE         = "{arch}.expert_weights_scale"
         EXPERT_WEIGHTS_NORM          = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC           = "{arch}.expert_gating_func"
+        MOE_EVERY_N_LAYERS           = "{arch}.moe_every_n_layers"
         POOLING_TYPE                 = "{arch}.pooling_type"
         LOGIT_SCALE                  = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID       = "{arch}.decoder_start_token_id"

@@ -267,6 +268,7 @@ class MODEL_ARCH(IntEnum):
     REFACT           = auto()
     BERT             = auto()
     NOMIC_BERT       = auto()
+    NOMIC_BERT_MOE   = auto()
     JINA_BERT_V2     = auto()
     BLOOM            = auto()
     STABLELM         = auto()

@@ -521,6 +523,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.REFACT:           "refact",
     MODEL_ARCH.BERT:             "bert",
     MODEL_ARCH.NOMIC_BERT:       "nomic-bert",
+    MODEL_ARCH.NOMIC_BERT_MOE:   "nomic-bert-moe",
     MODEL_ARCH.JINA_BERT_V2:     "jina-bert-v2",
     MODEL_ARCH.BLOOM:            "bloom",
     MODEL_ARCH.STABLELM:         "stablelm",

@@ -960,6 +963,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.NOMIC_BERT_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.JINA_BERT_V2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
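
A quick way to sanity-check the new architecture entry is to look it up through the gguf-py constants; a small sketch, assuming the gguf-py package from this commit:

import gguf

arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE
print(gguf.MODEL_ARCH_NAMES[arch])                                 # expected: nomic-bert-moe
print(gguf.MODEL_TENSOR.FFN_UP_EXP in gguf.MODEL_TENSORS[arch])    # expected: True
print(gguf.MODEL_TENSOR.FFN_GATE_INP in gguf.MODEL_TENSORS[arch])  # expected: True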

gguf-py/gguf/gguf_writer.py  (+3)
@@ -728,6 +728,9 @@ def add_expert_weights_norm(self, value: bool) -> None:
     def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
         self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)

+    def add_moe_every_n_layers(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
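
The new writer method stores the value as a uint32 under the {arch}.moe_every_n_layers key, alongside the expert-count keys the converter also writes for the MoE variant. A hedged usage sketch with placeholder values (not the model's real hyperparameters):

import gguf

writer = gguf.GGUFWriter("example-nomic-bert-moe.gguf", "nomic-bert-moe")
writer.add_moe_every_n_layers(2)  # nomic-bert-moe.moe_every_n_layers (uint32)
writer.add_expert_count(8)        # existing key: nomic-bert-moe.expert_count
writer.add_expert_used_count(2)   # existing key: nomic-bert-moe.expert_used_count
# the key-value pairs are only flushed by the writer's usual write_*_to_file calls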

gguf-py/gguf/tensor_mapping.py  (+4)
@@ -290,6 +290,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.router.layer",              # dbrx
             "model.layers.{bid}.block_sparse_moe.router.layer",       # granitemoe
             "language_model.model.layers.{bid}.feed_forward.router",  # llama4
+            "encoder.layers.{bid}.mlp.router.layer",                  # nomic-bert-moe
         ),

         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (

@@ -322,6 +323,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
             "model.layers.{bid}.feed_forward.w3",                     # internlm2
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
+            "encoder.layers.{bid}.mlp.fc1",                           # nomic-bert-moe
             "model.layers.{bid}.mlp.c_fc",                            # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3",                     # arctic

@@ -337,6 +339,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.experts.up_proj",                          # qwen2moe olmoe (merged)
             "model.layers.{bid}.block_sparse_moe.experts.w3",                  # phimoe (merged)
             "language_model.model.layers.{bid}.feed_forward.experts.up_proj",  # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w1",                         # nomic-bert-moe
         ),

         MODEL_TENSOR.FFN_UP_SHEXP: (

@@ -418,6 +421,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.output_linear",                 # granitemoe
             "model.layers.{bid}.block_sparse_moe.experts.w2",                    # phimoe (merged)
             "language_model.model.layers.{bid}.feed_forward.experts.down_proj",  # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w2",                           # nomic-bert-moe
         ),

         MODEL_TENSOR.FFN_DOWN_SHEXP: (
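
These four entries are what let map_tensor_name() resolve the HF MoE tensor names to GGUF names; a small lookup sketch, assuming the gguf-py package from this commit:

import gguf

# build the name map for the new architecture (12 blocks chosen arbitrarily)
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NOMIC_BERT_MOE, 12)

# the converter appends ".weight" to the raw w1/w2 names before mapping
print(tmap.get_name("encoder.layers.0.mlp.experts.mlp.w1.weight", try_suffixes=(".weight",)))
# expected: blk.0.ffn_up_exps.weight
print(tmap.get_name("encoder.layers.0.mlp.router.layer.weight", try_suffixes=(".weight",)))
# expected: blk.0.ffn_gate_inp.weight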

src/llama-arch.cpp  (+20)
@@ -19,6 +19,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_REFACT,           "refact"         },
     { LLM_ARCH_BERT,             "bert"           },
     { LLM_ARCH_NOMIC_BERT,       "nomic-bert"     },
+    { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe" },
     { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"   },
     { LLM_ARCH_BLOOM,            "bloom"          },
     { LLM_ARCH_STABLELM,         "stablelm"       },

@@ -106,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,      "%s.expert_weights_scale"   },
     { LLM_KV_EXPERT_WEIGHTS_NORM,       "%s.expert_weights_norm"    },
     { LLM_KV_EXPERT_GATING_FUNC,        "%s.expert_gating_func"     },
+    { LLM_KV_MOE_EVERY_N_LAYERS,        "%s.moe_every_n_layers"     },
     { LLM_KV_POOLING_TYPE,              "%s.pooling_type"           },
     { LLM_KV_LOGIT_SCALE,               "%s.logit_scale"            },
     { LLM_KV_DECODER_START_TOKEN_ID,    "%s.decoder_start_token_id" },

@@ -472,6 +474,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NOMIC_BERT_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {
