
Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture #12466


Merged: 3 commits merged into ggml-org:master on Apr 28, 2025

Conversation

@manyoso (Contributor) commented Mar 19, 2025:

  • Adds MoE-based embedding model supporting multilingual embeddings.
  • Selects the architecture variant based on hyperparameter detection (MoE layers); see the sketch below.
  • Removes unnecessary subclass initialization checks for clarity.

https://www.nomic.ai/blog/posts/nomic-embed-text-v2

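For anyone curious what the hyperparameter-based variant selection mentioned above might look like, here is a minimal Python sketch. The config key names (moe_every_n_layers, num_experts), the threshold logic, and the file path are assumptions for illustration, not the PR's exact conversion code.

import json

def pick_arch_variant(config_path):
    # Minimal sketch (assumed key names): decide between the dense and MoE
    # graph variants by looking for MoE hyperparameters in the HF config.
    with open(config_path) as f:
        hparams = json.load(f)
    if hparams.get("moe_every_n_layers", 0) > 0 and hparams.get("num_experts", 0) > 1:
        return "nomic-bert-moe"   # MoE FFN present in some layers
    return "nomic-bert"           # plain dense FFN in every layer

print(pick_arch_variant("models/nomic-embed-text-v2-moe/config.json"))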

@github-actions github-actions bot added the python python script changes label Mar 19, 2025
@manyoso manyoso marked this pull request as draft March 19, 2025 13:46
@manyoso manyoso marked this pull request as ready for review March 19, 2025 15:54
@ngxson (comment marked as resolved)
@manyoso manyoso marked this pull request as draft March 19, 2025 17:58
@manyoso (comment marked as resolved)
manyoso and others added 2 commits April 23, 2025 16:02 (the commit message mirrors the PR description above; Co-authored-by: Jared Van Bortel <[email protected]>)
@cebtenzzre cebtenzzre marked this pull request as ready for review April 23, 2025 20:04
@cebtenzzre cebtenzzre requested a review from ngxson April 23, 2025 20:04
@cebtenzzre (Collaborator) commented:
The MoE model is now using the correct tokenizer (XLMRoberta), and norm_w is now correctly set to false. Getting an MSE of about 6e-7 compared to the HF embeddings with a simple prompt, so the implementation should be ready to use.
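For context, a comparison like the one described can be sketched in a few lines of Python. The sentence-transformers loading path, the normalization settings, and the server URL below are assumptions for illustration, not the exact script behind the number above.

# Rough MSE check between llama-server embeddings and HF embeddings.
# Assumes llama-server is running with --embeddings on localhost:8080.
import requests
from sentence_transformers import SentenceTransformer

text = "search_document: Hello!"

hf_model = SentenceTransformer("nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True)
hf_vec = hf_model.encode([text], normalize_embeddings=True)[0]

resp = requests.post("http://localhost:8080/v1/embeddings", json={"input": [text]}).json()
llama_vec = resp["data"][0]["embedding"]

mse = sum((a - b) ** 2 for a, b in zip(hf_vec, llama_vec)) / len(llama_vec)
print(f"MSE vs HF embedding: {mse:.2e}")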

@anudit commented Apr 28, 2025

@ngxson (Collaborator) left a comment:

Looks good overall, just a small comment.

@@ -907,31 +907,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }

-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
+    ggml_tensor * tmp = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@ngxson (Collaborator):

I think we can still call this up, right? There are no other places where we reassign tmp to another value.

@cebtenzzre (Collaborator) replied Apr 28, 2025:
The only reason to call it tmp would be that that's what (non-moe) build_ffn calls it, which makes it easier to compare the two functions. In that function, up, down, and gate refer to weight tensors, and not output tensors. But up is fine here.
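For readers who have not seen build_moe_ffn before, the following NumPy sketch shows roughly what the function computes for a single token: router softmax, top-k expert selection, gated up/down projections, and a weighted sum over the selected experts. It is a conceptual illustration under simplified assumptions (single token, ReLU gate activation, norm_w as an optional flag), not the actual ggml graph code.

import numpy as np

def moe_ffn(x, w_gate_inp, w_up, w_gate, w_down, top_k=2, norm_w=False):
    # x: [n_embd]; w_gate_inp: [n_expert, n_embd]
    # w_up, w_gate: [n_expert, n_ff, n_embd]; w_down: [n_expert, n_embd, n_ff]
    logits = w_gate_inp @ x                    # router scores per expert
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                       # softmax over experts
    selected = np.argsort(-probs)[:top_k]      # indices of the top-k experts
    weights = probs[selected]
    if norm_w:                                 # optional renormalization; off for this model
        weights = weights / weights.sum()
    out = np.zeros_like(x)
    for w, e in zip(weights, selected):
        up = w_up[e] @ x                       # "ffn_moe_up"
        gate = np.maximum(w_gate[e] @ x, 0.0)  # "ffn_moe_gate" (activation simplified to ReLU)
        out += w * (w_down[e] @ (up * gate))   # "ffn_moe_down", scaled by the router weight
    return out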

@ggerganov (Member) left a comment:
Based on the documentation, I put together the following usage example for the search_document instruction:

llama-server \
  -m models/nomic-embed-text-v2-moe/ggml-model-f16.gguf \
  --embeddings

curl http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{"input": ["search_document: Hello!", "search_document: ¡Hola!", "search_document: Goodbye"]}' | jq

Is this correct?

Could you also show an example of how the search_query is to be used?

@cebtenzzre (Collaborator) commented Apr 28, 2025:

Is this correct?

Yes, that's the basic way the prefixes should be used.

Could you also show an example of how the search_query is to be used?

It's a prefix like search_document but typically used to embed the query in a retrieval pipeline. This is easier to demonstrate in Python:

import requests

def dot(va, vb):
    # dot product between two embedding vectors
    return sum(a*b for a, b in zip(va, vb))

def embed(texts):
    # request embeddings from the server's OpenAI-compatible endpoint
    resp = requests.post('http://localhost:8080/v1/embeddings', json=dict(input=texts)).json()
    return [d['embedding'] for d in resp['data']]

docs = ['嵌入很酷', '骆驼很酷']  # 'embeddings are cool', 'llamas are cool'
docs_embed = embed(['search_document: '+d for d in docs])

query = '跟我讲讲嵌入'  # 'tell me about embeddings'
query_embed = embed(['search_query: '+query])[0]
print(f'query: {query!r}')
for d, e in zip(docs, docs_embed):
    print(f'similarity {dot(query_embed, e):.2f}: {d!r}')

Output:

query: '跟我讲讲嵌入'
similarity 0.48: '嵌入很酷'
similarity 0.19: '骆驼很酷'

search_query is used with a query to retrieve texts that help inform the response (RAG). The query should be prefixed with search_document instead when the goal is to find the most semantically similar text.
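As a small follow-up to that last point, and reusing the embed and dot helpers from the script above (so this assumes the same running server), a purely symmetric text-to-text comparison would prefix both sides with search_document:

# Symmetric similarity: both sides use the search_document prefix.
text = '嵌入很酷'  # 'embeddings are cool'
text_embed = embed(['search_document: ' + text])[0]
for d, e in zip(docs, docs_embed):
    print(f'similarity {dot(text_embed, e):.2f}: {d!r}')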

@ggerganov ggerganov merged commit 5f5e39e into ggml-org:master Apr 28, 2025
41 of 51 checks passed
@ggerganov (Member) commented:

@cebtenzzre The readme at https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe says that the max sequence length is 512:

[screenshot of the Hugging Face model card stating a max sequence length of 512]

Is this correct, or is it 2048 as specified in the model configuration?
