
Commit 566c16f

model : add support for ERNIE 4.5 0.3B model (#14408)

Add Day-0 support for the Baidu ERNIE 4.5 0.3B model.

Signed-off-by: Weizhao Ouyang <[email protected]>

1 parent b25e927 · commit 566c16f

File tree: 6 files changed, +260 -0 lines changed


convert_hf_to_gguf.py

Lines changed: 46 additions & 0 deletions

@@ -2743,6 +2743,52 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Ernie4_5_ForCausalLM")
+class Ernie4_5Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.ERNIE4_5
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        num_heads = self.hparams["num_attention_heads"]
+        num_kv_heads = self.hparams["num_key_value_heads"]
+        head_dim = self.hparams["head_dim"]
+
+        if "ernie." in name:
+            name = name.replace("ernie.", "model.")
+        # split the qkv weights
+        # qkv_proj shape: [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size]
+        if "qkv_proj" in name:
+            name_q = name.replace("qkv_proj.weight", "q_proj.weight")
+            name_k = name.replace("qkv_proj.weight", "k_proj.weight")
+            name_v = name.replace("qkv_proj.weight", "v_proj.weight")
+            total_q_dim = num_heads * head_dim
+            total_k_dim = num_kv_heads * head_dim
+            total_v_dim = num_kv_heads * head_dim
+            q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0)
+            return [
+                (self.map_tensor_name(name_q), q_proj_weight),
+                (self.map_tensor_name(name_k), k_proj_weight),
+                (self.map_tensor_name(name_v), v_proj_weight)
+            ]
+        # split the up_gate_proj into gate and up
+        # up_gate_proj shape: [2 * intermediate_size, hidden_size]
+        if "up_gate_proj" in name:
+            name_up = name.replace("up_gate_proj.weight", "up_proj.weight")
+            name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight")
+            dim_half = data_torch.shape[0] // 2
+            gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0)
+            return [
+                (self.map_tensor_name(name_gate), gate_proj_weight),
+                (self.map_tensor_name(name_up), up_proj_weight)
+            ]
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",

gguf-py/gguf/constants.py

Lines changed: 16 additions & 0 deletions

@@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
     BAILINGMOE = auto()
     DOTS1 = auto()
     ARCEE = auto()
+    ERNIE4_5 = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):

@@ -654,6 +655,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BAILINGMOE: "bailingmoe",
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
+    MODEL_ARCH.ERNIE4_5: "ernie4_5",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -2177,6 +2179,20 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.ERNIE4_5: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
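
A quick sanity check of the new registrations from the gguf-py side — a sketch assuming the gguf package from this tree is on the import path:

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS

assert MODEL_ARCH_NAMES[MODEL_ARCH.ERNIE4_5] == "ernie4_5"
# the dense ERNIE4_5 mapping registers 12 tensor kinds
print(len(MODEL_TENSORS[MODEL_ARCH.ERNIE4_5]))  # 12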

src/llama-arch.cpp

Lines changed: 18 additions & 0 deletions

@@ -76,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
+    { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1658,6 +1659,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         }
     },
+    {
+        LLM_ARCH_ERNIE4_5,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT,      "output" },
+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
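
The C++ table uses printf-style "%d" placeholders for the layer index; gguf-py's TENSOR_NAMES table (assumed here to use "{bid}" placeholders) keeps the same layout. Rendering both for layer 0 shows the names line up:

from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

cpp_template = "blk.%d.attn_output"                 # from LLM_TENSOR_NAMES above
py_template  = TENSOR_NAMES[MODEL_TENSOR.ATTN_OUT]  # "blk.{bid}.attn_output"
assert cpp_template % 0 == py_template.format(bid=0) == "blk.0.attn_output"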

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -80,6 +80,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_ERNIE4_5,
     LLM_ARCH_UNKNOWN,
 };

src/llama-model.cpp

Lines changed: 178 additions & 0 deletions

@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_475M: return "475M";
         case LLM_TYPE_770M: return "770M";
         case LLM_TYPE_780M: return "780M";
+        case LLM_TYPE_0_3B: return "0.3B";
         case LLM_TYPE_0_5B: return "0.5B";
         case LLM_TYPE_0_6B: return "0.6B";
         case LLM_TYPE_1B: return "1B";
@@ -1504,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 18: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }

@@ -4344,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
 
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ERNIE4_5:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
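
The 2D create_tensor shapes are {ne0, ne1} = {input dim, output dim}. With the same illustrative placeholder hparams used earlier (not the shipped config), the attention shapes work out as:

# placeholder hparams, NOT read from the actual model config
n_embd, n_head, n_head_kv, head_dim = 1024, 16, 2, 128

n_embd_head_k = head_dim
n_embd_gqa    = n_embd_head_k * n_head_kv       # grouped-query K/V width

print({"wq": (n_embd, n_embd_head_k * n_head),  # (1024, 2048)
       "wk": (n_embd, n_embd_gqa),              # (1024, 256)
       "wv": (n_embd, n_embd_gqa),              # (1024, 256)
       "wo": (n_embd_head_k * n_head, n_embd)}) # (2048, 1024)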
@@ -14125,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
     }
 };
 
+struct llm_build_ernie4_5 : public llm_graph_context {
+    llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_arcee : public llm_graph_context {
     llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
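
As a readability aid, a compact PyTorch sketch of what one layer of the graph above computes: RMS norm, grouped-query attention, residual, RMS norm, SwiGLU FFN, residual. RoPE, the KV cache, and the optional Q/K/V biases are omitted for brevity; the weight dict p and head counts are hypothetical stand-ins for the per-layer tensors:

import torch
import torch.nn.functional as F

def rms_norm(x, w, eps=1e-6):
    # LLM_NORM_RMS with a hypothetical eps; the real value comes from GGUF metadata
    return w * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def ernie_layer(x, p, n_head=16, n_head_kv=2, head_dim=128):
    # self-attention block (RoPE and KV cache omitted)
    h = rms_norm(x, p["attn_norm"])
    q = (h @ p["wq"].T).view(-1, n_head,    head_dim).transpose(0, 1)
    k = (h @ p["wk"].T).view(-1, n_head_kv, head_dim).transpose(0, 1)
    v = (h @ p["wv"].T).view(-1, n_head_kv, head_dim).transpose(0, 1)
    k = k.repeat_interleave(n_head // n_head_kv, dim=0)  # broadcast KV heads (GQA)
    v = v.repeat_interleave(n_head // n_head_kv, dim=0)
    a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    x = x + a.transpose(0, 1).reshape(x.shape[0], -1) @ p["wo"].T
    # feed-forward block: LLM_FFN_SILU with LLM_FFN_PAR, i.e. SwiGLU
    h = rms_norm(x, p["ffn_norm"])
    return x + (F.silu(h @ p["ffn_gate"].T) * (h @ p["ffn_up"].T)) @ p["ffn_down"].T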
@@ -14635,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arcee>(*this, params, gf);
             } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -14786,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BAILINGMOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
+        case LLM_ARCH_ERNIE4_5:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2

src/llama-model.h

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ enum llm_type {
     LLM_TYPE_475M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
+    LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,
