@@ -176,6 +176,10 @@ struct clip_hparams {
176
176
};
177
177
178
178
struct clip_layer {
179
+ // layernorm 1 (input norm)
180
+ struct ggml_tensor * ln_1_w = nullptr ;
181
+ struct ggml_tensor * ln_1_b = nullptr ;
182
+
179
183
// attention
180
184
struct ggml_tensor * k_w = nullptr ;
181
185
struct ggml_tensor * k_b = nullptr ;
@@ -187,29 +191,28 @@ struct clip_layer {
187
191
struct ggml_tensor * o_w = nullptr ;
188
192
struct ggml_tensor * o_b = nullptr ;
189
193
190
- // layernorm 1
191
- struct ggml_tensor * ln_1_w = nullptr ;
192
- struct ggml_tensor * ln_1_b = nullptr ;
194
+ // layernorm 2 (post-attn norm / pre-ffn norm)
195
+ struct ggml_tensor * ln_2_w = nullptr ;
196
+ struct ggml_tensor * ln_2_b = nullptr ;
193
197
194
198
// ff
195
199
struct ggml_tensor * ff_i_w = nullptr ; // legacy naming
196
200
struct ggml_tensor * ff_i_b = nullptr ; // legacy naming
197
201
struct ggml_tensor * ff_o_w = nullptr ; // legacy naming
198
202
struct ggml_tensor * ff_o_b = nullptr ; // legacy naming
203
+ struct ggml_tensor * ff_g_w = nullptr ; // legacy naming
204
+ struct ggml_tensor * ff_g_b = nullptr ; // legacy naming
199
205
200
- struct ggml_tensor * ff_up_w = nullptr ;
201
- struct ggml_tensor * ff_up_b = nullptr ;
206
+ struct ggml_tensor * ff_up_w = nullptr ;
207
+ struct ggml_tensor * ff_up_b = nullptr ;
202
208
struct ggml_tensor * ff_gate_w = nullptr ;
203
209
struct ggml_tensor * ff_gate_b = nullptr ;
204
210
struct ggml_tensor * ff_down_w = nullptr ;
205
211
struct ggml_tensor * ff_down_b = nullptr ;
206
212
207
- struct ggml_tensor * ff_g_w = NULL ;
208
- struct ggml_tensor * ff_g_b = NULL ;
209
-
210
- // layernorm 2
211
- struct ggml_tensor * ln_2_w = nullptr ;
212
- struct ggml_tensor * ln_2_b = nullptr ;
213
+ // post-ffn norm (output layer norm)
214
+ struct ggml_tensor * post_ffn_norm_w = nullptr ;
215
+ struct ggml_tensor * post_ffn_norm_b = nullptr ;
213
216
};
214
217
215
218
struct clip_vision_model {
@@ -560,9 +563,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
560
563
static ggml_tensor * build_rope_2d (
561
564
ggml_context * ctx0,
562
565
ggml_tensor * cur,
563
- ggml_tensor * pos_h,
564
- ggml_tensor * pos_w,
565
- const float freq_base
566
+ ggml_tensor * pos_a, // first half
567
+ ggml_tensor * pos_b, // second half
568
+ const float freq_base,
569
+ const bool interleave_freq
566
570
) {
567
571
const int64_t n_dim = cur->ne [0 ];
568
572
const int64_t n_head = cur->ne [1 ];
@@ -576,7 +580,9 @@ static ggml_tensor * build_rope_2d(
576
580
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
577
581
// then for the second half, we use freq_scale to shift the inv_freq
578
582
// ^ why? replace (2i) with (2i+1) in the above equation
579
- const float freq_scale_odd = std::pow (freq_base, (float )-2 /n_dim);
583
+ const float freq_scale_odd = interleave_freq
584
+ ? std::pow (freq_base, (float )-2 /n_dim)
585
+ : 1.0 ;
580
586
581
587
// first half
582
588
ggml_tensor * first;
@@ -589,7 +595,7 @@ static ggml_tensor * build_rope_2d(
589
595
first = ggml_rope_ext (
590
596
ctx0,
591
597
first,
592
- pos_h , // positions
598
+ pos_a , // positions
593
599
nullptr , // freq factors
594
600
n_dim/2 , // n_dims
595
601
0 , 0 , freq_base,
@@ -609,7 +615,7 @@ static ggml_tensor * build_rope_2d(
609
615
second = ggml_rope_ext (
610
616
ctx0,
611
617
second,
612
- pos_w , // positions
618
+ pos_b , // positions
613
619
nullptr , // freq factors
614
620
n_dim/2 , // n_dims
615
621
0 , 0 , freq_base,
@@ -687,13 +693,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
687
693
struct ggml_tensor * Q = ggml_mul_mat (ctx0, model.layers [il].q_w , cur);
688
694
689
695
Q = ggml_reshape_3d (ctx0, Q, d_head, n_head, num_patches);
690
- Q = build_rope_2d (ctx0, Q, pos_h, pos_w, hparams.rope_theta );
696
+ Q = build_rope_2d (ctx0, Q, pos_h, pos_w, hparams.rope_theta , true );
691
697
Q = ggml_cont (ctx0, ggml_permute (ctx0, Q, 0 , 2 , 1 , 3 ));
692
698
693
699
struct ggml_tensor * K = ggml_mul_mat (ctx0, model.layers [il].k_w , cur);
694
700
695
701
K = ggml_reshape_3d (ctx0, K, d_head, n_head, num_patches);
696
- K = build_rope_2d (ctx0, K, pos_h, pos_w, hparams.rope_theta );
702
+ K = build_rope_2d (ctx0, K, pos_h, pos_w, hparams.rope_theta , true );
697
703
K = ggml_cont (ctx0, ggml_permute (ctx0, K, 0 , 2 , 1 , 3 ));
698
704
699
705
struct ggml_tensor * V = ggml_mul_mat (ctx0, model.layers [il].v_w , cur);
@@ -809,6 +815,174 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
809
815
return gf;
810
816
}
811
817
818
+ static ggml_cgraph * clip_image_build_graph_llama4 (clip_ctx * ctx, const clip_image_f32 & img) {
819
+ const auto & model = ctx->vision_model ;
820
+ const auto & hparams = model.hparams ;
821
+
822
+ const int patch_size = hparams.patch_size ;
823
+ const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
824
+ const int hidden_size = hparams.hidden_size ;
825
+ const int n_head = hparams.n_head ;
826
+ const int d_head = hidden_size / n_head;
827
+ const int n_layer = hparams.n_layer ;
828
+ const float eps = hparams.eps ;
829
+
830
+ struct ggml_init_params params = {
831
+ /* .mem_size =*/ ctx->buf_compute_meta .size (),
832
+ /* .mem_buffer =*/ ctx->buf_compute_meta .data (),
833
+ /* .no_alloc =*/ true ,
834
+ };
835
+
836
+ ggml_context_ptr ctx0_ptr (ggml_init (params));
837
+ auto ctx0 = ctx0_ptr.get ();
838
+
839
+ struct ggml_cgraph * gf = ggml_new_graph (ctx0);
840
+
841
+ // input raw
842
+ struct ggml_tensor * inp_raw = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, img.nx , img.ny , 3 );
843
+ ggml_set_name (inp_raw, " inp_raw" );
844
+ ggml_set_input (inp_raw);
845
+
846
+ // 2D input positions
847
+ struct ggml_tensor * pos_h = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, num_patches);
848
+ ggml_set_name (pos_h, " pos_h" );
849
+ ggml_set_input (pos_h);
850
+ struct ggml_tensor * pos_w = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, num_patches);
851
+ ggml_set_name (pos_w, " pos_w" );
852
+ ggml_set_input (pos_w);
853
+
854
+ struct ggml_tensor * inp = ggml_conv_2d (ctx0, model.patch_embeddings_0 , inp_raw, patch_size, patch_size, 0 , 0 , 1 , 1 );
855
+ inp = ggml_reshape_2d (ctx0, inp, num_patches, hidden_size);
856
+ inp = ggml_cont (ctx0, ggml_transpose (ctx0, inp));
857
+ inp = ggml_add (ctx0, inp, model.patch_bias );
858
+
859
+ // position embeddings
860
+ struct ggml_tensor * embeddings = ggml_add (ctx0, inp, model.position_embeddings );
861
+
862
+ // loop over layers
863
+ for (int il = 0 ; il < n_layer; il++) {
864
+ struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
865
+
866
+ // layernorm1
867
+ {
868
+ cur = ggml_norm (ctx0, cur, eps);
869
+ cur = ggml_add (ctx0, ggml_mul (ctx0, cur, model.layers [il].ln_1_w ), model.layers [il].ln_1_b );
870
+ }
871
+
872
+ // self-attention
873
+ {
874
+
875
+ struct ggml_tensor * Q =
876
+ ggml_add (ctx0, ggml_mul_mat (ctx0, model.layers [il].q_w , cur), model.layers [il].q_b );
877
+
878
+ Q = ggml_reshape_3d (ctx0, Q, d_head, n_head, num_patches);
879
+ Q = build_rope_2d (ctx0, Q, pos_w, pos_h, hparams.rope_theta , false );
880
+ Q = ggml_cont (ctx0, ggml_permute (ctx0, Q, 0 , 2 , 1 , 3 ));
881
+
882
+ struct ggml_tensor * K =
883
+ ggml_add (ctx0, ggml_mul_mat (ctx0, model.layers [il].k_w , cur), model.layers [il].k_b );
884
+
885
+ K = ggml_reshape_3d (ctx0, K, d_head, n_head, num_patches);
886
+ K = build_rope_2d (ctx0, K, pos_w, pos_h, hparams.rope_theta , false );
887
+ K = ggml_cont (ctx0, ggml_permute (ctx0, K, 0 , 2 , 1 , 3 ));
888
+
889
+ struct ggml_tensor * V =
890
+ ggml_add (ctx0, ggml_mul_mat (ctx0, model.layers [il].v_w , cur), model.layers [il].v_b );
891
+
892
+ V = ggml_reshape_3d (ctx0, V, d_head, n_head, num_patches);
893
+ V = ggml_cont (ctx0, ggml_permute (ctx0, V, 1 , 2 , 0 , 3 ));
894
+
895
+ struct ggml_tensor * KQ = ggml_mul_mat (ctx0, K, Q);
896
+ KQ = ggml_soft_max_ext (ctx0, KQ, nullptr , 1 .0f / sqrtf ((float )d_head), 0 .0f );
897
+
898
+ struct ggml_tensor * KQV = ggml_mul_mat (ctx0, V, KQ);
899
+ KQV = ggml_reshape_3d (ctx0, KQV, d_head, num_patches, n_head);
900
+ KQV = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
901
+
902
+ cur = ggml_cont_2d (ctx0, KQV, hidden_size, num_patches);
903
+ }
904
+
905
+ // attention output
906
+ cur = ggml_add (ctx0, ggml_mul_mat (ctx0, model.layers [il].o_w , cur), model.layers [il].o_b );
907
+
908
+ // re-add the layer input, e.g., residual
909
+ cur = ggml_add (ctx0, cur, embeddings);
910
+
911
+ embeddings = cur; // embeddings = residual, cur = hidden_states
912
+
913
+ // layernorm2
914
+ {
915
+ cur = ggml_norm (ctx0, cur, eps);
916
+ cur = ggml_add (ctx0, ggml_mul (ctx0, cur, model.layers [il].ln_2_w ), model.layers [il].ln_2_b );
917
+ }
918
+
919
+ cur = ggml_mul_mat (ctx0, model.layers [il].ff_i_w , cur);
920
+ cur = ggml_add (ctx0, cur, model.layers [il].ff_i_b );
921
+
922
+ if (ctx->use_silu ) {
923
+ cur = ggml_silu (ctx0, cur);
924
+ } else if (ctx->use_gelu ) {
925
+ cur = ggml_gelu (ctx0, cur);
926
+ } else {
927
+ GGML_ABORT (" llama4: Unsupported activation" );
928
+ }
929
+
930
+ cur = ggml_mul_mat (ctx0, model.layers [il].ff_o_w , cur);
931
+ cur = ggml_add (ctx0, cur, model.layers [il].ff_o_b );
932
+
933
+ // residual 2
934
+ cur = ggml_add (ctx0, embeddings, cur);
935
+
936
+ // norm output
937
+ {
938
+ cur = ggml_norm (ctx0, cur, eps);
939
+ cur = ggml_add (ctx0, ggml_mul (ctx0, cur, model.layers [il].post_ffn_norm_w ), model.layers [il].post_ffn_norm_b );
940
+ }
941
+
942
+ embeddings = cur;
943
+ }
944
+
945
+ // post-layernorm
946
+ if (model.post_ln_w ) {
947
+ embeddings = ggml_norm (ctx0, embeddings, eps);
948
+ ggml_set_name (embeddings, " post_ln" );
949
+
950
+ embeddings = ggml_add (ctx0, ggml_mul (ctx0, embeddings, model.post_ln_w ), model.post_ln_b );
951
+ }
952
+
953
+ // Llama4VisionPixelShuffleMLP
954
+ {
955
+ ggml_tensor * cur = embeddings;
956
+ const int scale_factor = model.hparams .proj_scale_factor ;
957
+ const int n_embd = cur->ne [0 ];
958
+ const int seq = cur->ne [1 ];
959
+ const int bsz = 1 ; // batch size, always 1 for now since we don't support batching
960
+ const int height = std::sqrt (seq);
961
+ const int width = std::sqrt (seq);
962
+ GGML_ASSERT (scale_factor != 0 );
963
+ cur = ggml_reshape_4d (ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
964
+ cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
965
+ cur = ggml_reshape_4d (ctx0, ggml_cont (ctx0, cur),
966
+ n_embd * scale_factor * scale_factor,
967
+ height / scale_factor,
968
+ width / scale_factor,
969
+ bsz);
970
+ cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
971
+ cur = ggml_reshape_3d (ctx0, ggml_cont (ctx0, cur),
972
+ n_embd * scale_factor * scale_factor,
973
+ seq / (scale_factor * scale_factor),
974
+ bsz);
975
+
976
+ cur = ggml_mul_mat (ctx0, model.projection , cur);
977
+ embeddings = cur;
978
+ }
979
+
980
+ // build the graph
981
+ ggml_build_forward_expand (gf, embeddings);
982
+
983
+ return gf;
984
+ }
985
+
812
986
static ggml_cgraph * clip_image_build_graph_qwen25vl (clip_ctx * ctx, const clip_image_f32_batch & imgs) {
813
987
const auto & model = ctx->vision_model ;
814
988
const auto & hparams = model.hparams ;
@@ -1599,6 +1773,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
1599
1773
{
1600
1774
res = clip_image_build_graph_qwen25vl (ctx, imgs);
1601
1775
} break ;
1776
+ case PROJECTOR_TYPE_LLAMA4:
1777
+ {
1778
+ res = clip_image_build_graph_llama4 (ctx, *imgs.entries [0 ]);
1779
+ } break ;
1602
1780
default :
1603
1781
{
1604
1782
// TODO: we should have one build_* function per model
@@ -1781,6 +1959,10 @@ struct clip_model_loader {
1781
1959
{
1782
1960
get_u32 (KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern );
1783
1961
} break ;
1962
+ case PROJECTOR_TYPE_LLAMA4:
1963
+ {
1964
+ get_u32 (KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor );
1965
+ } break ;
1784
1966
default :
1785
1967
break ;
1786
1968
}
@@ -1867,6 +2049,9 @@ struct clip_model_loader {
1867
2049
layer.ln_1_b = get_tensor (string_format (TN_LN_1, " v" , il, " bias" ), false );
1868
2050
layer.ln_2_b = get_tensor (string_format (TN_LN_2, " v" , il, " bias" ), false );
1869
2051
2052
+ layer.post_ffn_norm_b = get_tensor (string_format (TN_FFN_POST_NORM, " v" , il, " bias" ), false );
2053
+ layer.post_ffn_norm_w = get_tensor (string_format (TN_FFN_POST_NORM, " v" , il, " weight" ), false );
2054
+
1870
2055
// new naming
1871
2056
layer.ff_up_w = get_tensor (string_format (TN_FFN_UP, " v" , il, " weight" ));
1872
2057
layer.ff_up_b = get_tensor (string_format (TN_FFN_UP, " v" , il, " bias" ), false );
@@ -2008,6 +2193,12 @@ struct clip_model_loader {
2008
2193
vision_model.mm_input_norm_w = get_tensor (TN_MM_INP_NORM, false );
2009
2194
vision_model.mm_patch_merger_w = get_tensor (TN_MM_PATCH_MERGER, false );
2010
2195
} break ;
2196
+ case PROJECTOR_TYPE_LLAMA4:
2197
+ {
2198
+ vision_model.mm_model_proj = get_tensor (TN_MM_PROJECTOR);
2199
+ vision_model.mm_model_mlp_1_w = get_tensor (string_format (TN_MVLM_PROJ_MLP, 1 , " weight" ));
2200
+ vision_model.mm_model_mlp_2_w = get_tensor (string_format (TN_MVLM_PROJ_MLP, 2 , " weight" ));
2201
+ } break ;
2011
2202
default :
2012
2203
GGML_ASSERT (false && " unknown projector type" );
2013
2204
}
@@ -2796,7 +2987,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
2796
2987
}
2797
2988
else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
2798
2989
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
2799
- || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
2990
+ || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
2991
+ || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
2800
2992
clip_image_u8 resized_image;
2801
2993
int sz = params.image_size ;
2802
2994
image_manipulation::resize_and_pad_image (*img, resized_image, {sz, sz});
@@ -2968,7 +3160,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
2968
3160
n_patches = x_patch * y_patch;
2969
3161
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
2970
3162
n_patches = 256 ;
2971
- } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
3163
+ } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx-> proj_type == PROJECTOR_TYPE_LLAMA4 ) {
2972
3164
n_patches /= ctx->vision_model .hparams .proj_scale_factor ;
2973
3165
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
2974
3166
int n_merge = ctx->vision_model .hparams .spatial_merge_size ;
@@ -3550,6 +3742,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
3550
3742
case PROJECTOR_TYPE_GEMMA3:
3551
3743
return ctx->vision_model .mm_input_proj_w ->ne [0 ];
3552
3744
case PROJECTOR_TYPE_IDEFICS3:
3745
+ case PROJECTOR_TYPE_LLAMA4:
3553
3746
return ctx->vision_model .projection ->ne [1 ];
3554
3747
default :
3555
3748
GGML_ABORT (" Unknown projector type" );
0 commit comments