Skip to content

Commit 9d1a4d6

Browse files
committed
Llama4UnfoldConvolution
1 parent 224cbba commit 9d1a4d6

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

tools/llava/clip.cpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -842,15 +842,15 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
842842
ggml_set_name(inp_raw, "inp_raw");
843843
ggml_set_input(inp_raw);
844844

845-
// create patches
846-
ggml_tensor * patch_embd_view = ggml_view_4d(ctx0, model.patch_embeddings_0,
847-
patch_size, patch_size, 3, hidden_size,
848-
ggml_row_size(model.patch_embeddings_0->type, patch_size),
849-
ggml_row_size(model.patch_embeddings_0->type, patch_size * patch_size),
850-
ggml_row_size(model.patch_embeddings_0->type, patch_size * patch_size * 3), 0);
851-
ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
852-
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
853-
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
845+
// Llama4UnfoldConvolution
846+
ggml_tensor * inp;
847+
{
848+
ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
849+
patch_size, patch_size, 3, hidden_size);
850+
inp = ggml_im2col(ctx0, kernel, inp_raw, patch_size, patch_size, 0, 0, 1, 1, true, inp_raw->type);
851+
inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
852+
inp = ggml_reshape_2d(ctx0, inp, hidden_size, num_patches);
853+
}
854854

855855
// add CLS
856856
inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
@@ -3578,12 +3578,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
35783578
// last pos is always kept 0, it's for CLS
35793579
// dimension H
35803580
for (int i = 0; i < num_patches; i++) {
3581-
pos_data[i] = i / n_patches_per_col;
3581+
pos_data[i] = (i / n_patches_per_col) + 1;
35823582
}
35833583
set_input_i32("pos_h", pos_data);
35843584
// dimension W
35853585
for (int i = 0; i < num_patches; i++) {
3586-
pos_data[i] = i % n_patches_per_col;
3586+
pos_data[i] = (i % n_patches_per_col) + 1;
35873587
}
35883588
set_input_i32("pos_w", pos_data);
35893589
} break;

0 commit comments

Comments
 (0)