@@ -52,6 +52,7 @@ const char * llm_type_name(llm_type type) {
52
52
case LLM_TYPE_475M: return "475M";
53
53
case LLM_TYPE_770M: return "770M";
54
54
case LLM_TYPE_780M: return "780M";
55
+ case LLM_TYPE_0_3B: return "0.3B";
55
56
case LLM_TYPE_0_5B: return "0.5B";
56
57
case LLM_TYPE_0_6B: return "0.6B";
57
58
case LLM_TYPE_1B: return "1B";
@@ -1509,6 +1510,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1509
1510
default: type = LLM_TYPE_UNKNOWN;
1510
1511
}
1511
1512
} break;
1513
+ case LLM_ARCH_ERNIE4_5:
1514
+ {
1515
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1516
+ switch (hparams.n_layer) {
1517
+ case 18: type = LLM_TYPE_0_3B; break;
1518
+ default: type = LLM_TYPE_UNKNOWN;
1519
+ }
1520
+ } break;
1512
1521
default: throw std::runtime_error("unsupported model architecture");
1513
1522
}
1514
1523
@@ -4440,6 +4449,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4440
4449
4441
4450
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4442
4451
4452
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4453
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4454
+ }
4455
+ } break;
4456
+ case LLM_ARCH_ERNIE4_5:
4457
+ {
4458
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4459
+
4460
+ // output
4461
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4462
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4463
+ // if output is NULL, init from the input tok embed
4464
+ if (output == NULL) {
4465
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4466
+ }
4467
+
4468
+ for (int i = 0; i < n_layer; ++i) {
4469
+ auto & layer = layers[i];
4470
+
4471
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4472
+
4473
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4474
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
4475
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
4476
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4477
+
4478
+ // optional bias tensors
4479
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4480
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4481
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4482
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4483
+
4484
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4485
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4443
4486
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4444
4487
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4445
4488
}
@@ -14225,6 +14268,136 @@ struct llm_build_dots1 : public llm_graph_context {
14225
14268
}
14226
14269
};
14227
14270
14271
+ struct llm_build_ernie4_5 : public llm_graph_context {
14272
+ llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14273
+ const int64_t n_embd_head = hparams.n_embd_head_v;
14274
+
14275
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14276
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14277
+
14278
+ ggml_tensor * cur;
14279
+ ggml_tensor * inpL;
14280
+
14281
+ inpL = build_inp_embd(model.tok_embd);
14282
+
14283
+ // inp_pos - contains the positions
14284
+ ggml_tensor * inp_pos = build_inp_pos();
14285
+
14286
+ auto * inp_attn = build_attn_inp_kv_unified();
14287
+
14288
+ for (int il = 0; il < n_layer; ++il) {
14289
+ ggml_tensor * inpSA = inpL;
14290
+
14291
+ // norm
14292
+ {
14293
+ cur = build_norm(inpL,
14294
+ model.layers[il].attn_norm, NULL,
14295
+ LLM_NORM_RMS, il);
14296
+ cb(cur, "attn_norm", il);
14297
+ }
14298
+
14299
+ // self-attention
14300
+ {
14301
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14302
+ cb(Qcur, "Qcur", il);
14303
+ if (model.layers[il].bq) {
14304
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14305
+ cb(Qcur, "Qcur", il);
14306
+ }
14307
+
14308
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14309
+ cb(Kcur, "Kcur", il);
14310
+ if (model.layers[il].bk) {
14311
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14312
+ cb(Kcur, "Kcur", il);
14313
+ }
14314
+
14315
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14316
+ cb(Vcur, "Vcur", il);
14317
+ if (model.layers[il].bv) {
14318
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14319
+ cb(Vcur, "Vcur", il);
14320
+ }
14321
+
14322
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14323
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14324
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14325
+
14326
+ Qcur = ggml_rope_ext(
14327
+ ctx0, Qcur, inp_pos, nullptr,
14328
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14329
+ ext_factor, attn_factor, beta_fast, beta_slow
14330
+ );
14331
+
14332
+ Kcur = ggml_rope_ext(
14333
+ ctx0, Kcur, inp_pos, nullptr,
14334
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14335
+ ext_factor, attn_factor, beta_fast, beta_slow
14336
+ );
14337
+
14338
+ cb(Qcur, "Qcur", il);
14339
+ cb(Kcur, "Kcur", il);
14340
+ cb(Vcur, "Vcur", il);
14341
+
14342
+ cur = build_attn(inp_attn, gf,
14343
+ model.layers[il].wo, NULL,
14344
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14345
+ }
14346
+
14347
+ if (il == n_layer - 1) {
14348
+ // skip computing output for unused tokens
14349
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14350
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14351
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14352
+ }
14353
+
14354
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14355
+ cb(ffn_inp, "ffn_inp", il);
14356
+
14357
+ // feed-forward network
14358
+ {
14359
+ cur = build_norm(ffn_inp,
14360
+ model.layers[il].ffn_norm, NULL,
14361
+ LLM_NORM_RMS, il);
14362
+ cb(cur, "ffn_norm", il);
14363
+
14364
+ cur = build_ffn(cur,
14365
+ model.layers[il].ffn_up, NULL, NULL,
14366
+ model.layers[il].ffn_gate, NULL, NULL,
14367
+ model.layers[il].ffn_down, NULL, NULL,
14368
+ NULL,
14369
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14370
+ cb(cur, "ffn_out", il);
14371
+ }
14372
+
14373
+ cur = ggml_add(ctx0, cur, ffn_inp);
14374
+
14375
+ cur = build_cvec(cur, il);
14376
+ cb(cur, "l_out", il);
14377
+
14378
+ // input for next layer
14379
+ inpL = cur;
14380
+ }
14381
+
14382
+ cur = inpL;
14383
+
14384
+ cur = build_norm(cur,
14385
+ model.output_norm, NULL,
14386
+ LLM_NORM_RMS, -1);
14387
+
14388
+ cb(cur, "result_norm", -1);
14389
+ res->t_embd = cur;
14390
+
14391
+ // lm_head
14392
+ cur = build_lora_mm(model.output, cur);
14393
+
14394
+ cb(cur, "result_output", -1);
14395
+ res->t_logits = cur;
14396
+
14397
+ ggml_build_forward_expand(gf, cur);
14398
+ }
14399
+ };
14400
+
14228
14401
struct llm_build_arcee : public llm_graph_context {
14229
14402
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14230
14403
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14735,6 +14908,10 @@ llm_graph_result_ptr llama_model::build_graph(
14735
14908
{
14736
14909
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
14737
14910
} break;
14911
+ case LLM_ARCH_ERNIE4_5:
14912
+ {
14913
+ llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
14914
+ } break;
14738
14915
default:
14739
14916
GGML_ABORT("fatal error");
14740
14917
}
@@ -14886,6 +15063,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
14886
15063
case LLM_ARCH_BAILINGMOE:
14887
15064
case LLM_ARCH_NEO_BERT:
14888
15065
case LLM_ARCH_ARCEE:
15066
+ case LLM_ARCH_ERNIE4_5:
14889
15067
return LLAMA_ROPE_TYPE_NORM;
14890
15068
14891
15069
// the pairs of head values are offset by n_rot/2
0 commit comments