Skip to content

Commit 8775bc4

Browse files
committed
test impl
1 parent c50e627 commit 8775bc4

File tree

2 files changed

+217
-21
lines changed

2 files changed

+217
-21
lines changed

tools/llava/clip-impl.h

+3
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
6161
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
6262
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
63+
#define TN_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
6364
#define TN_LN_1 "%s.blk.%d.ln1.%s"
6465
#define TN_LN_2 "%s.blk.%d.ln2.%s"
6566
#define TN_LN_PRE "%s.pre_ln.%s"
@@ -103,6 +104,7 @@ enum projector_type {
103104
PROJECTOR_TYPE_IDEFICS3,
104105
PROJECTOR_TYPE_PIXTRAL,
105106
PROJECTOR_TYPE_QWEN25VL,
107+
PROJECTOR_TYPE_LLAMA4,
106108
PROJECTOR_TYPE_UNKNOWN,
107109
};
108110

@@ -117,6 +119,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
117119
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
118120
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
119121
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
122+
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
120123
};
121124

122125
static projector_type clip_projector_type_from_string(const std::string & str) {

tools/llava/clip.cpp

+214-21
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ struct clip_hparams {
176176
};
177177

178178
struct clip_layer {
179+
// layernorm 1 (input norm)
180+
struct ggml_tensor * ln_1_w = nullptr;
181+
struct ggml_tensor * ln_1_b = nullptr;
182+
179183
// attention
180184
struct ggml_tensor * k_w = nullptr;
181185
struct ggml_tensor * k_b = nullptr;
@@ -187,29 +191,28 @@ struct clip_layer {
187191
struct ggml_tensor * o_w = nullptr;
188192
struct ggml_tensor * o_b = nullptr;
189193

190-
// layernorm 1
191-
struct ggml_tensor * ln_1_w = nullptr;
192-
struct ggml_tensor * ln_1_b = nullptr;
194+
// layernorm 2 (post-attn norm / pre-ffn norm)
195+
struct ggml_tensor * ln_2_w = nullptr;
196+
struct ggml_tensor * ln_2_b = nullptr;
193197

194198
// ff
195199
struct ggml_tensor * ff_i_w = nullptr; // legacy naming
196200
struct ggml_tensor * ff_i_b = nullptr; // legacy naming
197201
struct ggml_tensor * ff_o_w = nullptr; // legacy naming
198202
struct ggml_tensor * ff_o_b = nullptr; // legacy naming
203+
struct ggml_tensor * ff_g_w = nullptr; // legacy naming
204+
struct ggml_tensor * ff_g_b = nullptr; // legacy naming
199205

200-
struct ggml_tensor * ff_up_w = nullptr;
201-
struct ggml_tensor * ff_up_b = nullptr;
206+
struct ggml_tensor * ff_up_w = nullptr;
207+
struct ggml_tensor * ff_up_b = nullptr;
202208
struct ggml_tensor * ff_gate_w = nullptr;
203209
struct ggml_tensor * ff_gate_b = nullptr;
204210
struct ggml_tensor * ff_down_w = nullptr;
205211
struct ggml_tensor * ff_down_b = nullptr;
206212

207-
struct ggml_tensor * ff_g_w = NULL;
208-
struct ggml_tensor * ff_g_b = NULL;
209-
210-
// layernorm 2
211-
struct ggml_tensor * ln_2_w = nullptr;
212-
struct ggml_tensor * ln_2_b = nullptr;
213+
// post-ffn norm (output layer norm)
214+
struct ggml_tensor * post_ffn_norm_w = nullptr;
215+
struct ggml_tensor * post_ffn_norm_b = nullptr;
213216
};
214217

215218
struct clip_vision_model {
@@ -560,9 +563,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
560563
static ggml_tensor * build_rope_2d(
561564
ggml_context * ctx0,
562565
ggml_tensor * cur,
563-
ggml_tensor * pos_h,
564-
ggml_tensor * pos_w,
565-
const float freq_base
566+
ggml_tensor * pos_a, // first half
567+
ggml_tensor * pos_b, // second half
568+
const float freq_base,
569+
const bool interleave_freq
566570
) {
567571
const int64_t n_dim = cur->ne[0];
568572
const int64_t n_head = cur->ne[1];
@@ -576,7 +580,9 @@ static ggml_tensor * build_rope_2d(
576580
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
577581
// then for the second half, we use freq_scale to shift the inv_freq
578582
// ^ why? replace (2i) with (2i+1) in the above equation
579-
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
583+
const float freq_scale_odd = interleave_freq
584+
? std::pow(freq_base, (float)-2/n_dim)
585+
: 1.0;
580586

581587
// first half
582588
ggml_tensor * first;
@@ -589,7 +595,7 @@ static ggml_tensor * build_rope_2d(
589595
first = ggml_rope_ext(
590596
ctx0,
591597
first,
592-
pos_h, // positions
598+
pos_a, // positions
593599
nullptr, // freq factors
594600
n_dim/2, // n_dims
595601
0, 0, freq_base,
@@ -609,7 +615,7 @@ static ggml_tensor * build_rope_2d(
609615
second = ggml_rope_ext(
610616
ctx0,
611617
second,
612-
pos_w, // positions
618+
pos_b, // positions
613619
nullptr, // freq factors
614620
n_dim/2, // n_dims
615621
0, 0, freq_base,
@@ -687,13 +693,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
687693
struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
688694

689695
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
690-
Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
696+
Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta, true);
691697
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
692698

693699
struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
694700

695701
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
696-
K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
702+
K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta, true);
697703
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
698704

699705
struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
@@ -809,6 +815,174 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
809815
return gf;
810816
}
811817

818+
static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) {
819+
const auto & model = ctx->vision_model;
820+
const auto & hparams = model.hparams;
821+
822+
const int patch_size = hparams.patch_size;
823+
const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
824+
const int hidden_size = hparams.hidden_size;
825+
const int n_head = hparams.n_head;
826+
const int d_head = hidden_size / n_head;
827+
const int n_layer = hparams.n_layer;
828+
const float eps = hparams.eps;
829+
830+
struct ggml_init_params params = {
831+
/*.mem_size =*/ ctx->buf_compute_meta.size(),
832+
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
833+
/*.no_alloc =*/ true,
834+
};
835+
836+
ggml_context_ptr ctx0_ptr(ggml_init(params));
837+
auto ctx0 = ctx0_ptr.get();
838+
839+
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
840+
841+
// input raw
842+
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
843+
ggml_set_name(inp_raw, "inp_raw");
844+
ggml_set_input(inp_raw);
845+
846+
// 2D input positions
847+
struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
848+
ggml_set_name(pos_h, "pos_h");
849+
ggml_set_input(pos_h);
850+
struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
851+
ggml_set_name(pos_w, "pos_w");
852+
ggml_set_input(pos_w);
853+
854+
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
855+
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
856+
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
857+
inp = ggml_add(ctx0, inp, model.patch_bias);
858+
859+
// position embeddings
860+
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
861+
862+
// loop over layers
863+
for (int il = 0; il < n_layer; il++) {
864+
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
865+
866+
// layernorm1
867+
{
868+
cur = ggml_norm(ctx0, cur, eps);
869+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
870+
}
871+
872+
// self-attention
873+
{
874+
875+
struct ggml_tensor * Q =
876+
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
877+
878+
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
879+
Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false);
880+
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
881+
882+
struct ggml_tensor * K =
883+
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
884+
885+
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
886+
K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false);
887+
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
888+
889+
struct ggml_tensor * V =
890+
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
891+
892+
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
893+
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
894+
895+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
896+
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
897+
898+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
899+
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
900+
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
901+
902+
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
903+
}
904+
905+
// attention output
906+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
907+
908+
// re-add the layer input, e.g., residual
909+
cur = ggml_add(ctx0, cur, embeddings);
910+
911+
embeddings = cur; // embeddings = residual, cur = hidden_states
912+
913+
// layernorm2
914+
{
915+
cur = ggml_norm(ctx0, cur, eps);
916+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
917+
}
918+
919+
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
920+
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
921+
922+
if (ctx->use_silu) {
923+
cur = ggml_silu(ctx0, cur);
924+
} else if (ctx->use_gelu) {
925+
cur = ggml_gelu(ctx0, cur);
926+
} else {
927+
GGML_ABORT("llama4: Unsupported activation");
928+
}
929+
930+
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
931+
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
932+
933+
// residual 2
934+
cur = ggml_add(ctx0, embeddings, cur);
935+
936+
// norm output
937+
{
938+
cur = ggml_norm(ctx0, cur, eps);
939+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].post_ffn_norm_w), model.layers[il].post_ffn_norm_b);
940+
}
941+
942+
embeddings = cur;
943+
}
944+
945+
// post-layernorm
946+
if (model.post_ln_w) {
947+
embeddings = ggml_norm(ctx0, embeddings, eps);
948+
ggml_set_name(embeddings, "post_ln");
949+
950+
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
951+
}
952+
953+
// Llama4VisionPixelShuffleMLP
954+
{
955+
ggml_tensor * cur = embeddings;
956+
const int scale_factor = model.hparams.proj_scale_factor;
957+
const int n_embd = cur->ne[0];
958+
const int seq = cur->ne[1];
959+
const int bsz = 1; // batch size, always 1 for now since we don't support batching
960+
const int height = std::sqrt(seq);
961+
const int width = std::sqrt(seq);
962+
GGML_ASSERT(scale_factor != 0);
963+
cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
964+
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
965+
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
966+
n_embd * scale_factor * scale_factor,
967+
height / scale_factor,
968+
width / scale_factor,
969+
bsz);
970+
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
971+
cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
972+
n_embd * scale_factor * scale_factor,
973+
seq / (scale_factor * scale_factor),
974+
bsz);
975+
976+
cur = ggml_mul_mat(ctx0, model.projection, cur);
977+
embeddings = cur;
978+
}
979+
980+
// build the graph
981+
ggml_build_forward_expand(gf, embeddings);
982+
983+
return gf;
984+
}
985+
812986
static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
813987
const auto & model = ctx->vision_model;
814988
const auto & hparams = model.hparams;
@@ -1599,6 +1773,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
15991773
{
16001774
res = clip_image_build_graph_qwen25vl(ctx, imgs);
16011775
} break;
1776+
case PROJECTOR_TYPE_LLAMA4:
1777+
{
1778+
res = clip_image_build_graph_llama4(ctx, *imgs.entries[0]);
1779+
} break;
16021780
default:
16031781
{
16041782
// TODO: we should have one build_* function per model
@@ -1781,6 +1959,10 @@ struct clip_model_loader {
17811959
{
17821960
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
17831961
} break;
1962+
case PROJECTOR_TYPE_LLAMA4:
1963+
{
1964+
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
1965+
} break;
17841966
default:
17851967
break;
17861968
}
@@ -1867,6 +2049,9 @@ struct clip_model_loader {
18672049
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
18682050
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
18692051

2052+
layer.post_ffn_norm_b = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "bias"), false);
2053+
layer.post_ffn_norm_w = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "weight"), false);
2054+
18702055
// new naming
18712056
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
18722057
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
@@ -2008,6 +2193,12 @@ struct clip_model_loader {
20082193
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
20092194
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
20102195
} break;
2196+
case PROJECTOR_TYPE_LLAMA4:
2197+
{
2198+
vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
2199+
vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
2200+
vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
2201+
} break;
20112202
default:
20122203
GGML_ASSERT(false && "unknown projector type");
20132204
}
@@ -2796,7 +2987,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
27962987
}
27972988
else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
27982989
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
2799-
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
2990+
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
2991+
|| ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
28002992
clip_image_u8 resized_image;
28012993
int sz = params.image_size;
28022994
image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
@@ -2968,7 +3160,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
29683160
n_patches = x_patch * y_patch;
29693161
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
29703162
n_patches = 256;
2971-
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
3163+
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
29723164
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
29733165
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
29743166
int n_merge = ctx->vision_model.hparams.spatial_merge_size;
@@ -3550,6 +3742,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
35503742
case PROJECTOR_TYPE_GEMMA3:
35513743
return ctx->vision_model.mm_input_proj_w->ne[0];
35523744
case PROJECTOR_TYPE_IDEFICS3:
3745+
case PROJECTOR_TYPE_LLAMA4:
35533746
return ctx->vision_model.projection->ne[1];
35543747
default:
35553748
GGML_ABORT("Unknown projector type");

0 commit comments

Comments
 (0)