Commit 8846aac

model : gemma3n text-only (ggml-org#14400)
* gemma3n
* add llm_graph_input_one
1 parent a01047b commit 8846aac

13 files changed: +960, -15 lines

convert_hf_to_gguf.py

Lines changed: 118 additions & 6 deletions
@@ -310,6 +310,8 @@ def prepare_tensors(self):
                         gguf.MODEL_TENSOR.POSNET_NORM2,
                         gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                         gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+                        gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+                        gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                     )
                 )
                 or not new_name.endswith(".weight")
@@ -320,7 +322,11 @@ def prepare_tensors(self):
                 self.match_model_tensor_name(new_name, key, bid)
                 for key in (
                     gguf.MODEL_TENSOR.TOKEN_EMBD,
+                    gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
                     gguf.MODEL_TENSOR.OUTPUT,
+                    gguf.MODEL_TENSOR.ALTUP_ROUTER,
+                    gguf.MODEL_TENSOR.LAUREL_L,
+                    gguf.MODEL_TENSOR.LAUREL_R,
                 )
             ):
                 if self.ftype in (
@@ -921,13 +927,16 @@ def _create_vocab_sentencepiece(self):
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))

-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input", # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()

         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

-        for token_id in range(tokenizer.vocab_size()):
+        for token_id in range(vocab_size):
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.GetScore(token_id)
@@ -942,6 +951,10 @@ def _create_vocab_sentencepiece(self):
             elif tokenizer.IsByte(token_id):
                 toktype = SentencePieceTokenTypes.BYTE

+            if token_id >= vocab_size:
+                logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+                break
+
             tokens[token_id] = text
             scores[token_id] = score
             toktypes[token_id] = toktype
@@ -4217,6 +4230,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
+    norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -4238,9 +4252,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
         self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
@@ -4252,7 +4265,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused

-        if name.startswith("language_model."):
+        if "language_model." in name:
             name = name.replace("language_model.", "")

         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4267,8 +4280,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
+        # note: this is not the case on gemma3n
         if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift

         return [(self.map_tensor_name(name), data_torch)]

@@ -4325,6 +4339,104 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [] # skip other tensors


+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3N
+    norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
+
+    _altup_proj: list[Tensor] = []
+    _altup_unembd: list[Tensor] = []
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
+        self._altup_proj = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+        self._altup_unembd = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+
+    def set_vocab(self):
+        with open(self.dir_model / "chat_template.jinja") as f:
+            # quick hack to make sure chat template is added
+            self.gguf_writer.add_chat_template(f.read())
+        super().set_vocab()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+        self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+        self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+        self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
+        activation_sparsity_scale = []
+        for s in self.hparams["activation_sparsity_pattern"]:
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+            activation_sparsity_scale.append(std_multiplier.item())
+        self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
+        sliding_window_pattern = []
+        for t in self.hparams["layer_types"]:
+            sliding_window_pattern.append(t == "sliding_attention")
+        self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+    def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+        has_all = all(m.numel() > 0 for m in matrices)
+        if not has_all:
+            return None
+        else:
+            return torch.stack(matrices, dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("_scale"):
+            name = name + ".weight"
+
+        # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+        if "language_model." not in name:
+            return [] # skip non-language model tensors
+
+        if "altup_unembed_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_unembd[0] = data_torch
+            elif ".1." in name:
+                self._altup_unembd[1] = data_torch
+            elif ".2." in name:
+                self._altup_unembd[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_unembd)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+            else:
+                return []
+
+        if "altup_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_proj[0] = data_torch
+            elif ".1." in name:
+                self._altup_proj[1] = data_torch
+            elif ".2." in name:
+                self._altup_proj[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_proj)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+            else:
+                return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
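
For readers tracing the conversion logic above: the activation sparsity pattern is mapped to per-layer std multipliers through the inverse CDF of a standard normal, i.e. the z-score cutoff that keeps roughly the top (1 - s) fraction of activations. A minimal standalone sketch of that computation follows; the pattern values are made up for illustration and are not taken from any real gemma3n config.

import torch

# Sketch only: mirrors the loop in Gemma3NModel.set_gguf_parameters above.
activation_sparsity_pattern = [0.95, 0.90, 0.0]  # hypothetical per-layer sparsity targets
normal_dist = torch.distributions.normal.Normal(0, 1)
scales = [normal_dist.icdf(torch.tensor(s, dtype=torch.float32)).item()
          for s in activation_sparsity_pattern]
print(scales)  # approx [1.6449, 1.2816, -inf]; icdf(0.0) is -inf, i.e. effectively no cutoff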

gguf-py/gguf/constants.py

Lines changed: 75 additions & 0 deletions
@@ -118,6 +118,10 @@ class LLM:
         EMBEDDING_SCALE = "{arch}.embedding_scale"
         TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
         INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
+        ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
+        ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
+        ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
+        EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -142,6 +146,8 @@ class Attention:
         SCALE = "{arch}.attention.scale"
         KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
+        SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
+        SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"

     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -314,6 +320,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA = auto()
     GEMMA2 = auto()
     GEMMA3 = auto()
+    GEMMA3N = auto()
     STARCODER2 = auto()
     RWKV6 = auto()
     RWKV6QWEN2 = auto()
@@ -399,6 +406,22 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
     LAYER_OUT_NORM = auto()
+    PER_LAYER_TOKEN_EMBD = auto() # gemma3n
+    PER_LAYER_MODEL_PROJ = auto() # gemma3n
+    PER_LAYER_INP_GATE = auto() # gemma3n
+    PER_LAYER_PROJ = auto() # gemma3n
+    PER_LAYER_PROJ_NORM = auto() # gemma3n
+    PER_LAYER_POST_NORM = auto() # gemma3n
+    ALTUP_PROJ = auto() # gemma3n
+    ALTUP_UNEMBD_PROJ = auto() # gemma3n
+    ALTUP_CORRECT_COEF = auto() # gemma3n
+    ALTUP_CORRECT_SCALE = auto() # gemma3n
+    ALTUP_PREDICT_COEF = auto() # gemma3n
+    ALTUP_ROUTER = auto() # gemma3n
+    ALTUP_ROUTER_NORM = auto() # gemma3n
+    LAUREL_L = auto() # gemma3n
+    LAUREL_R = auto() # gemma3n
+    LAUREL_POST_NORM = auto() # gemma3n
     SSM_IN = auto()
     SSM_CONV1D = auto()
     SSM_X = auto()
@@ -597,6 +620,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GEMMA: "gemma",
     MODEL_ARCH.GEMMA2: "gemma2",
     MODEL_ARCH.GEMMA3: "gemma3",
+    MODEL_ARCH.GEMMA3N: "gemma3n",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -682,6 +706,22 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
+    MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
+    MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
+    MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
+    MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n
+    MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n
+    MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n
+    MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n
+    MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n
+    MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n
+    MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n
+    MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n
+    MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n
+    MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n
+    MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n
+    MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n
+    MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n
     MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
     MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
@@ -1486,6 +1526,41 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_PRE_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA3N: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+        # altup / laurel
+        MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
+        MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
+        MODEL_TENSOR.PER_LAYER_INP_GATE,
+        MODEL_TENSOR.PER_LAYER_PROJ,
+        MODEL_TENSOR.PER_LAYER_PROJ_NORM,
+        MODEL_TENSOR.PER_LAYER_POST_NORM,
+        MODEL_TENSOR.ALTUP_PROJ,
+        MODEL_TENSOR.ALTUP_UNEMBD_PROJ,
+        MODEL_TENSOR.ALTUP_CORRECT_COEF,
+        MODEL_TENSOR.ALTUP_CORRECT_SCALE,
+        MODEL_TENSOR.ALTUP_PREDICT_COEF,
+        MODEL_TENSOR.ALTUP_ROUTER,
+        MODEL_TENSOR.ALTUP_ROUTER_NORM,
+        MODEL_TENSOR.LAUREL_L,
+        MODEL_TENSOR.LAUREL_R,
+        MODEL_TENSOR.LAUREL_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
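
As a quick, informal sanity check on the mappings added above (not part of the commit), the new TENSOR_NAMES entries resolve to GGUF tensor names as shown below; this assumes the gguf-py package from this commit is importable.

from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

# Sketch only: format a few of the gemma3n tensor names introduced in this diff.
print(TENSOR_NAMES[MODEL_TENSOR.ALTUP_ROUTER].format(bid=3))  # blk.3.altup_router
print(TENSOR_NAMES[MODEL_TENSOR.LAUREL_L].format(bid=0))      # blk.0.laurel_l
print(TENSOR_NAMES[MODEL_TENSOR.PER_LAYER_TOKEN_EMBD])        # per_layer_token_embd (no per-block id)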

gguf-py/gguf/gguf_writer.py

Lines changed: 18 additions & 0 deletions
@@ -672,6 +672,18 @@ def add_parallel_residual(self, use: bool) -> None:
     def add_decoder_start_token_id(self, id: int) -> None:
         self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)

+    def add_embedding_length_per_layer_input(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
+
+    def add_altup_active_idx(self, val: int) -> None:
+        self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)
+
+    def add_altup_num_inputs(self, val: int) -> None:
+        self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)
+
+    def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
+
     def add_head_count(self, count: int | Sequence[int]) -> None:
         if isinstance(count, int):
             self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
@@ -702,6 +714,12 @@ def add_max_alibi_bias(self, bias: float) -> None:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

+    def add_shared_kv_layers(self, value: float) -> None:
+        self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+
+    def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
+        self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

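For context on how the new writer methods are used, they are called from Gemma3NModel.set_gguf_parameters in the conversion script above; a minimal sketch follows. The file name and every value here are placeholders, and a real conversion also writes the header, KV data, and tensors, which is omitted.

from gguf import GGUFWriter

# Sketch only: placeholder values, not real gemma3n hyperparameters.
writer = GGUFWriter("gemma3n-text-sketch.gguf", arch="gemma3n")
writer.add_altup_active_idx(0)
writer.add_altup_num_inputs(4)
writer.add_embedding_length_per_layer_input(256)
writer.add_shared_kv_layers(10)
writer.add_activation_sparsity_scale([1.6449, 1.6449, 0.0])
writer.add_sliding_window_pattern([True, True, False])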