diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 838999531e580..cbc88134795a1 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1778,6 +1778,12 @@ class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing `num_attention_heads` in config.json
+        if self.hf_arch == "VLlama3ForCausalLM":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2b089f84a841a..003b0172c77b0 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -977,15 +977,12 @@ class TensorNameMap:
             "visual.blocks.{bid}.norm2", # qwen2vl
         ),
 
-        # some namings are messed up because the original llava code swapped fc1 and fc2
-        # we have no better way to fix it, just be careful
-        # new models like pixtral use the correct naming
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
+            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
-            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
+            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
         ),
 
@@ -997,9 +994,9 @@ class TensorNameMap:
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
             "vpm.encoder.layers.{bid}.mlp.fc2",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
+            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
             "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
-            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
+            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),
 
diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp
index 3b60a526eedd8..8bd5e790f4394 100644
--- a/tools/llava/clip.cpp
+++ b/tools/llava/clip.cpp
@@ -155,8 +155,8 @@ enum patch_merge_type {
 struct clip_hparams {
     int32_t image_size;
     int32_t patch_size;
-    int32_t hidden_size;
-    int32_t n_intermediate;
+    int32_t n_embd;
+    int32_t n_ff;
     int32_t projection_dim;
     int32_t n_head;
     int32_t n_layer;
@@ -191,12 +191,6 @@ struct clip_layer {
     struct ggml_tensor * ln_1_w = nullptr;
     struct ggml_tensor * ln_1_b = nullptr;
 
-    // ff
-    struct ggml_tensor * ff_i_w = nullptr; // legacy naming
-    struct ggml_tensor * ff_i_b = nullptr; // legacy naming
-    struct ggml_tensor * ff_o_w = nullptr; // legacy naming
-    struct ggml_tensor * ff_o_b = nullptr; // legacy naming
-
     struct ggml_tensor * ff_up_w = nullptr;
     struct ggml_tensor * ff_up_b = nullptr;
     struct ggml_tensor * ff_gate_w = nullptr;
@@ -204,9 +198,6 @@ struct clip_layer {
     struct ggml_tensor * ff_down_w = nullptr;
     struct ggml_tensor * ff_down_b = nullptr;
 
-    struct ggml_tensor * ff_g_w = NULL;
-    struct ggml_tensor * ff_g_b = NULL;
-
     // layernorm 2
     struct ggml_tensor * ln_2_w = nullptr;
     struct ggml_tensor * ln_2_b = nullptr;
@@ -388,9 +379,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 
     const int patch_size = hparams.patch_size;
     const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int hidden_size = hparams.hidden_size;
+    const int n_embd = hparams.n_embd;
     const int n_head = hparams.n_head;
-    const int d_head = hidden_size / n_head;
+    const int d_head = n_embd / n_head;
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
 
@@ -411,7 +402,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
     ggml_set_input(inp_raw);
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
     inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
 
     inp = ggml_add(ctx0, inp, model.patch_bias);
@@ -456,7 +447,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
             KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
         }
 
         // attention output
@@ -473,14 +464,14 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
 
         // siglip uses gelu
         cur = ggml_gelu(ctx0, cur);
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -504,11 +495,11 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
         const int kernel_size = patches_per_image / tokens_per_side;
 
         embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
-        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, n_embd, batch_size);
 
         // doing a pool2d to reduce the number of output tokens to 256
         embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], n_embd, batch_size);
         embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
 
         // apply norm before projection
@@ -637,9 +628,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     const int n_patches_x = image_size_width / patch_size;
     const int n_patches_y = image_size_height / patch_size;
     const int num_patches = n_patches_x * n_patches_y;
-    const int hidden_size = hparams.hidden_size;
+    const int n_embd = hparams.n_embd;
     const int n_head = hparams.n_head;
-    const int d_head = hidden_size / n_head;
+    const int d_head = n_embd / n_head;
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
     const int n_merge = hparams.spatial_merge_size;
@@ -669,7 +660,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     ggml_set_input(pos_w);
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
     inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
 
     struct ggml_tensor * embeddings = inp;
@@ -710,7 +701,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
             KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
 
             cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
         }
@@ -753,8 +744,8 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
 
         // reshape image tokens to 2D grid
-        cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
-        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
         cur = ggml_cont(ctx0, cur);
 
         // torch.nn.functional.unfold is just an im2col under the hood
@@ -762,7 +753,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
         cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
 
-        // project to hidden_size
+        // project to n_embd
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
         cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
         embeddings = cur;
@@ -785,9 +776,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     // arrangement of the [IMG_BREAK] token
     {
         // not efficient, but works
-        // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
+        // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
-        // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
+        // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
 
         const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
         const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
@@ -827,9 +818,9 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     const int patches_h = image_size_height / patch_size;
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
     const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position
-    const int hidden_size = hparams.hidden_size;
+    const int n_embd = hparams.n_embd;
     const int n_head = hparams.n_head;
-    const int d_head = hidden_size / n_head;
+    const int d_head = n_embd / n_head;
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
 
@@ -864,14 +855,14 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
     inp = ggml_reshape_4d(
         ctx0, inp,
-        hidden_size * 2, patches_w / 2, patches_h, batch_size);
+        n_embd * 2, patches_w / 2, patches_h, batch_size);
     inp = ggml_reshape_4d(
         ctx0, inp,
-        hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+        n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
     inp = ggml_reshape_3d(
         ctx0, inp,
-        hidden_size, patches_w * patches_h, batch_size);
+        n_embd, patches_w * patches_h, batch_size);
 
     if (model.patch_bias) {
         // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
@@ -904,11 +895,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         ggml_set_name(window_mask, "window_mask");
         ggml_set_input(window_mask);
 
-        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
         GGML_ASSERT(batch_size == 1);
-        embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * 4, patches_w * patches_h * batch_size / 4);
         embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, patches_w * patches_h, batch_size);
     }
 
     // loop over layers
@@ -961,7 +952,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+            cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
         }
 
         // attention output
@@ -978,11 +969,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
 
         // mlp
         // ffn_up
-        auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
+        auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_up_b);
 
-        auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
-        cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+        auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
+        cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_gate_b);
         // TODO : only 2 of these 3 are actually used, should we remove one of them?
         if (ctx->use_gelu) {
             cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
@@ -994,8 +985,8 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         cur = ggml_mul(ctx0, cur_gate, cur_up);
 
         // ffn_down
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -1011,7 +1002,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
     }
 
-    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
 
     embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
     embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
@@ -1028,7 +1019,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         ggml_set_name(window_idx, "window_idx");
         ggml_set_input(window_idx);
 
-        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
         GGML_ASSERT(batch_size == 1);
         embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
         embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
@@ -1074,9 +1065,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     const int patches_h = image_size_height / patch_size;
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
     const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
-    const int hidden_size = hparams.hidden_size;
+    const int n_embd = hparams.n_embd;
     const int n_head = hparams.n_head;
-    const int d_head = hidden_size / n_head;
+    const int d_head = n_embd / n_head;
     const float eps = hparams.eps;
 
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
@@ -1114,17 +1105,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
         inp = ggml_reshape_4d(
             ctx0, inp,
-            hidden_size * 2, patches_w / 2, patches_h, batch_size);
+            n_embd * 2, patches_w / 2, patches_h, batch_size);
         inp = ggml_reshape_4d(
             ctx0, inp,
-            hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+            n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
         inp = ggml_reshape_3d(
             ctx0, inp,
-            hidden_size, patches_w * patches_h, batch_size);
+            n_embd, patches_w * patches_h, batch_size);
     }
     else {
-        inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+        inp = ggml_reshape_3d(ctx0, inp, num_patches, n_embd, batch_size);
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
     }
 
@@ -1137,7 +1128,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
     // concat class_embeddings and patch_embeddings
     if (model.class_embedding) {
-        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, num_positions, batch_size);
         embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
         embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
             embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@@ -1234,7 +1225,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+            cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
         }
 
         // attention output
@@ -1252,8 +1243,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
 
         if (ctx->use_gelu) {
             cur = ggml_gelu_inplace(ctx0, cur);
@@ -1263,8 +1254,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             cur = ggml_gelu_quick_inplace(ctx0, cur);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -1496,9 +1487,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         }
 
         { // attention
-            int hidden_size = clip_n_mmproj_embd(ctx);
+            int n_embd = clip_n_mmproj_embd(ctx);
             const int d_head = 128;
-            int n_head = hidden_size/d_head;
+            int n_head = n_embd/d_head;
             int num_query = 96;
             if (ctx->minicpmv_version == 2) {
                 num_query = 96;
@@ -1528,7 +1519,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+            KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
 
             embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
         }
@@ -1571,7 +1562,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
@@ -1696,9 +1687,9 @@ struct clip_model_loader {
             get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
             get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
 
-            get_u32(KEY_N_EMBD, hparams.hidden_size);
+            get_u32(KEY_N_EMBD, hparams.n_embd);
             get_u32(KEY_N_HEAD, hparams.n_head);
-            get_u32(KEY_N_FF, hparams.n_intermediate);
+            get_u32(KEY_N_FF, hparams.n_ff);
             get_u32(KEY_N_BLOCK, hparams.n_layer);
             get_u32(KEY_PROJ_DIM, hparams.projection_dim);
             get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
@@ -1807,6 +1798,7 @@ struct clip_model_loader {
     }
 
     void load_tensors() {
+        auto & hparams = ctx_clip.vision_model.hparams;
         std::map<std::string, size_t> tensor_offset;
         std::vector<ggml_tensor *> tensors_to_load;
 
@@ -1860,8 +1852,8 @@ struct clip_model_loader {
         vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
 
         // layers
-        vision_model.layers.resize(vision_model.hparams.n_layer);
-        for (int il = 0; il < vision_model.hparams.n_layer; ++il) {
+        vision_model.layers.resize(hparams.n_layer);
+        for (int il = 0; il < hparams.n_layer; ++il) {
             auto & layer = vision_model.layers[il];
             layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
             layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
@@ -1884,13 +1876,18 @@ struct clip_model_loader {
             layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
             layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
 
-            // legacy naming (the in and out is reversed! don't ask me why)
-            layer.ff_i_w = layer.ff_down_w;
-            layer.ff_o_w = layer.ff_up_w;
-            layer.ff_g_w = layer.ff_gate_w;
-            layer.ff_i_b = layer.ff_down_b;
-            layer.ff_o_b = layer.ff_up_b;
-            layer.ff_g_b = layer.ff_gate_b;
+            // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
+            // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
+            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+                // swap up and down weights
+                ggml_tensor * tmp = layer.ff_up_w;
+                layer.ff_up_w = layer.ff_down_w;
+                layer.ff_down_w = tmp;
+                // swap up and down biases
+                tmp = layer.ff_up_b;
+                layer.ff_up_b = layer.ff_down_b;
+                layer.ff_down_b = tmp;
+            }
         }
 
         switch (ctx_clip.proj_type) {
@@ -2904,7 +2901,7 @@ int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
 }
 
 int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.hidden_size;
+    return ctx->vision_model.hparams.n_embd;
 }
 
 const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
diff --git a/tools/llava/mtmd-cli.cpp b/tools/llava/mtmd-cli.cpp
index 474e7c4f8357e..e3db823799674 100644
--- a/tools/llava/mtmd-cli.cpp
+++ b/tools/llava/mtmd-cli.cpp
@@ -92,6 +92,10 @@ struct mtmd_cli_context {
         batch = llama_batch_init(params.n_batch, 0, 1);
         n_batch = params.n_batch;
 
+        if (!model || !lctx) {
+            exit(1);
+        }
+
        if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
             LOG_ERR("Model does not have chat template.\n");
             LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
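
For reference, below is a minimal standalone sketch (not part of the patch) of the FFN up/down swap heuristic that clip_model_loader::load_tensors() applies in the hunk above: a correct down-projection maps n_ff to n_embd, so its input dimension ne[0] should be n_ff; if it equals n_embd instead, the exporter wrote fc1/fc2 swapped and the two tensors are exchanged. The fake_tensor struct and the SigLIP-like dimensions are assumptions made purely for illustration; the real loader performs this check on ggml_tensor shapes read from the GGUF file.

// Standalone illustration of the up/down swap heuristic (assumed shapes, not the real loader).
#include <cstdint>
#include <cstdio>
#include <utility>

struct fake_tensor { int64_t ne[2]; }; // stand-in for ggml_tensor, shapes only

static void maybe_swap_ffn(fake_tensor *& up, fake_tensor *& down, int64_t n_embd) {
    // down should have input dim n_ff; if it is n_embd, the names were swapped on export
    if (up && down && down->ne[0] == n_embd) {
        std::swap(up, down); // same fix as the patch, expressed with std::swap
    }
}

int main() {
    const int64_t n_embd = 1152, n_ff = 4304; // hypothetical SigLIP-like sizes
    fake_tensor a = {{n_embd, n_ff}};         // mislabeled "down" (really the up weight)
    fake_tensor b = {{n_ff, n_embd}};         // mislabeled "up"   (really the down weight)
    fake_tensor * up = &b, * down = &a;
    maybe_swap_ffn(up, down, n_embd);
    std::printf("up: [%lld, %lld], down: [%lld, %lld]\n",
                (long long) up->ne[0], (long long) up->ne[1],
                (long long) down->ne[0], (long long) down->ne[1]);
    return 0;
}

The patch swaps the biases the same way; the sketch only shows the weights because their shapes are what the heuristic keys on.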