@@ -10,6 +10,7 @@
 #include <iomanip>
 #include <iostream>
 #include <map>
+#include <numeric>
 #include <random>
 #include <sstream>
 #include <stdexcept>
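
`<numeric>` is the only new include; the hunk does not show its consumer, but a natural candidate in an embedding path is L2 normalization via `std::inner_product`. A minimal sketch under that assumption (the helper name `normalizeVector` is hypothetical, not from this patch):

```cpp
#include <cmath>
#include <numeric>
#include <vector>

// Illustrative only: L2-normalize an embedding in place.
// std::inner_product is what <numeric> provides here.
static void normalizeVector(std::vector<float> &emb) {
    double sumSq = std::inner_product(emb.begin(), emb.end(), emb.begin(), 0.0);
    float norm = std::sqrt(static_cast<float>(sumSq));
    if (norm > 0.0f) {
        for (float &v : emb)
            v /= norm;
    }
}
```
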
@@ -345,7 +346,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     d_ptr->ctx_params.n_threads = d_ptr->n_threads;
     d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
 
-    if (m_supportsEmbedding)
+    if (isEmbedding)
         d_ptr->ctx_params.embeddings = true;
 
     d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
@@ -612,22 +613,22 @@ struct EmbModelGroup {
     std::vector<const char *> names;
 };
 
-static const EmbModelSpec NOPREFIX_SPEC {nullptr, nullptr};
+static const EmbModelSpec NOPREFIX_SPEC {"", ""};
 static const EmbModelSpec NOMIC_SPEC {"search_document", "search_query", {"clustering", "classification"}};
 static const EmbModelSpec E5_SPEC {"passage", "query"};
 
 static const EmbModelSpec NOMIC_1_5_SPEC {
-    "search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]"
+    "search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]",
 };
 static const EmbModelSpec LLM_EMBEDDER_SPEC {
     "Represent this document for retrieval",
     "Represent this query for retrieving relevant documents",
 };
 static const EmbModelSpec BGE_SPEC {
-    nullptr, "Represent this sentence for searching relevant passages",
+    "", "Represent this sentence for searching relevant passages",
 };
 static const EmbModelSpec E5_MISTRAL_SPEC {
-    nullptr, "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
+    "", "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
 };
 
 static const EmbModelGroup EMBEDDING_MODEL_SPECS[] {
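
The substance of this hunk is the switch from `nullptr` to `""` for the prefix fields: downstream code can then do `std::string(spec->docPrefix)` unconditionally, whereas constructing a `std::string` from a null `const char *` is undefined behavior. The `EmbModelSpec` definition itself is outside the hunk, but the initializers imply its shape; a guessed reconstruction, with illustrative field names only:

```cpp
#include <vector>

// Guessed reconstruction -- field order and types are implied by the
// initializers above; the names are not from this patch.
struct EmbModelSpec {
    const char *docPrefix;                  // e.g. "search_document"
    const char *queryPrefix;                // e.g. "search_query"
    std::vector<const char *> otherTypes;   // e.g. {"clustering", "classification"}
    bool matryoshkaCapable = false;         // supports truncated (Matryoshka) embeddings
    const char *recommendedDims = nullptr;  // e.g. "[768, 512, 384, 256, 128]"
};
```
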
@@ -738,18 +739,20 @@ void LLamaModel::embedInternal(
     const llama_token bos_token = llama_token_bos(d_ptr->model);
     const llama_token eos_token = llama_token_eos(d_ptr->model);
 
-    assert(shouldAddBOS());
-    bool addEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
+    bool useBOS = shouldAddBOS();
+    bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
 
     // no EOS, optional BOS
-    auto tokenize = [this, addEOS](std::string text, TokenString &tokens, bool addBOS) {
-        if (!text.empty() && text[0] != ' ')
+    auto tokenize = [this, useBOS, useEOS, eos_token](std::string text, TokenString &tokens, bool wantBOS) {
+        if (!text.empty() && text[0] != ' ') {
             text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
+        }
+        wantBOS &= useBOS;
 
         tokens.resize(text.length()+4);
-        int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), addBOS, false);
-        assert(addEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
-        tokens.resize(n_tokens - addEOS); // erase EOS/SEP
+        int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
+        assert(useEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
+        tokens.resize(n_tokens - useEOS); // erase EOS/SEP
     };
 
     // tokenize the texts
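
Two behavioral changes here: the hard `assert(shouldAddBOS())` becomes a runtime `useBOS` flag that is ANDed into the caller's request (`wantBOS &= useBOS`), and the EOS/SEP token that WPM (BERT-style) vocabularies always append is stripped so chunk token streams can be concatenated and re-terminated later. A standalone sketch of that post-processing step, mirroring (not copying) the lambda:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative: the lambda's EOS handling on a plain token buffer.
// useEOS is true for WPM vocabularies, which always append EOS/SEP.
std::vector<int32_t> trimTokens(std::vector<int32_t> tokens, bool useEOS, int32_t eos_token) {
    assert(useEOS == (eos_token != -1 && !tokens.empty() && tokens.back() == eos_token));
    if (useEOS)
        tokens.pop_back(); // erase EOS/SEP so chunks can be joined later
    return tokens;
}
```
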
@@ -784,7 +787,7 @@ void LLamaModel::embedInternal(
     }
 
     const uint32_t n_batch = llama_n_batch(d_ptr->ctx);
-    const uint32_t max_len = n_batch - (prefixTokens.size() + addEOS); // minus BOS/CLS and EOS/SEP
+    const uint32_t max_len = n_batch - (prefixTokens.size() + useEOS); // minus BOS/CLS and EOS/SEP
     if (chunkOverlap >= max_len) {
         throw std::logic_error("max chunk length of " + std::to_string(max_len) + " is smaller than overlap of " +
             std::to_string(chunkOverlap) + " tokens");
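
A worked instance of the bound, under assumed values: with `n_batch = 2048`, a two-token `prefixTokens`, and a WPM vocabulary (`useEOS` converts to 1), `max_len = 2048 - (2 + 1) = 2045`, so any `chunkOverlap` of 2045 or more trips the `std::logic_error` above.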