Commit 8464ee3

take out attention_type; add in llama_set_embeddings
1 parent 56d9aee commit 8464ee3

5 files changed: +19 -41 lines changed

common/common.cpp

Lines changed: 0 additions & 14 deletions
@@ -532,17 +532,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
-    if (arg == "--attention") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
-        std::string value(argv[i]);
-        /**/ if (value == "causal")     { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
-        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
-        else { invalid_param = true; }
-        return true;
-    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1457,8 +1446,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --pooling {none,mean,cls,last}\n");
     printf("                            pooling type for embeddings, use model default if unspecified\n");
-    printf("  --attn-type {causal,non-causal}\n");
-    printf("                            attention type for generation, use model default if unspecified\n");
     printf("  -dt N, --defrag-thold N\n");
     printf("                            KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  --ignore-eos              ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -2056,7 +2043,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
-    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;

common/common.h

Lines changed: 0 additions & 1 deletion
@@ -95,7 +95,6 @@ struct gpt_params {
 
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
-    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type
 
     // // sampling parameters
     struct llama_sampling_params sparams;

examples/gritlm/gritlm.cpp

Lines changed: 10 additions & 12 deletions
@@ -44,6 +44,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_embeddings(ctx, true);
+        llama_set_causal_attn(ctx, false);
 
         // run model
         llama_decode(ctx, batch);
@@ -97,6 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
    llama_token eos_token = llama_token_eos(mdl);
 
    llama_kv_cache_clear(ctx);
+   llama_set_embeddings(ctx, false);
+   llama_set_causal_attn(ctx, true);
+
    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
    std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -163,13 +168,7 @@ int main(int argc, char * argv[]) {
    llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
    // create generation context
-   llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
-
-   // create embedding context
-   cparams.embeddings = true;
-   cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-   cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
-   llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
+   llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
    // ### Embedding/Representation ###
    // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -187,8 +186,8 @@ int main(int argc, char * argv[]) {
    };
 
    // No need to add instruction for retrieval documents
-   const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
-   const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
+   const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
+   const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
 
    const int n_embd = llama_n_embd(mdl);
 
@@ -207,11 +206,10 @@ int main(int argc, char * argv[]) {
    // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
    {
        const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-       std::string response = generate(ctx_gen, prompt, true);
+       std::string response = generate(ctx, prompt, true);
    }
 
-   llama_free(ctx_gen);
-   llama_free(ctx_emb);
+   llama_free(ctx);
    llama_free_model(mdl);
    llama_backend_free();
 

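The net effect of the gritlm.cpp change is that a single context now serves both phases, with the mode flipped at run time instead of being baked into llama_context_params. Below is a minimal sketch of the resulting call pattern; the helper name is hypothetical, batching, tokenization, and output handling are omitted, and the context is assumed to be created with pooling disabled as in the example.

    #include "llama.h"

    // Sketch only: the same llama_context is switched between embedding and
    // generation mode at run time instead of allocating two contexts.
    static void embed_then_generate(llama_context * ctx, const llama_batch & emb_batch, const llama_batch & gen_batch) {
        // embedding pass: bidirectional attention, embeddings output enabled
        llama_kv_cache_clear(ctx);
        llama_set_embeddings(ctx, true);
        llama_set_causal_attn(ctx, false);
        llama_decode(ctx, emb_batch);   // read results via llama_get_embeddings_ith()

        // generation pass: causal attention, logits output enabled
        llama_kv_cache_clear(ctx);
        llama_set_embeddings(ctx, false);
        llama_set_causal_attn(ctx, true);
        llama_decode(ctx, gen_batch);   // read results via llama_get_logits_ith()
    }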
llama.cpp

Lines changed: 5 additions & 7 deletions
@@ -15277,7 +15277,6 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -15514,12 +15513,7 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
+    cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17232,6 +17226,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }

llama.h

Lines changed: 4 additions & 7 deletions
@@ -161,12 +161,6 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };
 
-    enum llama_attention_type {
-        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
-        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
-        LLAMA_ATTENTION_TYPE_NONCAUSAL   = 1,
-    };
-
    enum llama_split_mode {
        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -282,7 +276,6 @@ extern "C" {
 
        enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-       enum llama_attention_type    attention_type;    // causal, non-causal, or unspecified
 
        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
        float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -766,6 +759,10 @@ extern "C" {
    // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
+   // Set whether the model is in embeddings model or not
+   // If true, embeddings will be returned but logits will not
+   LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
    // Set whether to use causal attention or not
    // If set to true, the model will only attend to the past tokens
    LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

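For callers that previously requested non-causal attention through llama_context_params, the replacement suggested by this header change is to create the context with default attention and toggle the flags on the live context. A minimal sketch under that assumption (the helper name is hypothetical; error handling is trimmed):

    #include "llama.h"

    // Sketch only: the removed creation-time attention_type setting is replaced
    // by llama_set_causal_attn(); llama_set_embeddings() likewise switches the
    // context to returning embeddings after creation.
    static llama_context * make_embedding_context(llama_model * mdl) {
        llama_context_params cparams = llama_context_default_params();
        llama_context * ctx = llama_new_context_with_model(mdl, cparams);
        if (ctx == nullptr) {
            return nullptr;
        }
        llama_set_embeddings(ctx, true);    // return embeddings instead of logits
        llama_set_causal_attn(ctx, false);  // bidirectional attention, as embedding models expect
        return ctx;
    }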