Commit cf348a6

main : add option to save full output to session (LostRuins#1338)
* main : add option to save full output to session
* split behavior into --session and --prompt-cache
* restore original implementation with new names
* PR comments
* move the check for incompatible parameters to gpt_params_parse
* Fix whitespace

Co-authored-by: DannyDaemonic <[email protected]>
1 parent e6a46b0 commit cf348a6

File tree: 4 files changed, 30 additions and 18 deletions

- examples/common.cpp (+14, -3)
- examples/common.h (+4, -3)
- examples/main/README.md (+2, -2)
- examples/main/main.cpp (+10, -10)

examples/common.cpp (+14, -3)
```diff
@@ -118,12 +118,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.prompt = argv[i];
         } else if (arg == "-e") {
             escape_prompt = true;
-        } else if (arg == "--session") {
+        } else if (arg == "--prompt-cache") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params.path_session = argv[i];
+            params.path_prompt_cache = argv[i];
+        } else if (arg == "--prompt-cache-all") {
+            params.prompt_cache_all = true;
         } else if (arg == "-f" || arg == "--file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -342,6 +344,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         gpt_print_usage(argc, argv, default_params);
         exit(1);
     }
+    if (params.prompt_cache_all &&
+            (params.interactive || params.interactive_first ||
+             params.instruct || params.antiprompt.size())) {
+        fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
+        gpt_print_usage(argc, argv, default_params);
+        exit(1);
+    }
     if (escape_prompt) {
         process_escapes(params.prompt);
     }
@@ -367,7 +376,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: empty)\n");
     fprintf(stderr, "  -e                    process prompt escape sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, "  --session FNAME       file to cache model state in (may be large!) (default: none)\n");
+    fprintf(stderr, "  --prompt-cache FNAME  file to cache prompt state for faster startup (default: none)\n");
+    fprintf(stderr, "  --prompt-cache-all    if specified, saves user input and generations to cache as well.\n");
+    fprintf(stderr, "                        not supported with --interactive or other interactive options\n");
     fprintf(stderr, "  --random-prompt       start with a randomized prompt.\n");
     fprintf(stderr, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
     fprintf(stderr, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
```

examples/common.h (+4, -3)
```diff
@@ -46,9 +46,9 @@ struct gpt_params {

     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string path_session = "";  // path to file for saving/loading model eval state
-    std::string input_prefix = "";  // string to prefix user inputs with
-    std::string input_suffix = "";  // string to suffix user inputs with
+    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix      = ""; // string to prefix user inputs with
+    std::string input_suffix      = ""; // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

     std::string lora_adapter = ""; // lora adapter path
@@ -58,6 +58,7 @@ struct gpt_params {
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
     bool interactive = false; // interactive mode
+    bool prompt_cache_all = false; // save user input and generations to prompt cache

     bool embedding = false; // get only sentence embedding
     bool interactive_first = false; // wait for user input immediately
```

examples/main/README.md (+2, -2)
```diff
@@ -270,9 +270,9 @@ These options help improve the performance and memory usage of the LLaMA models.

 -   `-b N, --batch_size N`: Set the batch size for prompt processing (default: 512). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.

-### Session Caching
+### Prompt Caching

--   `--session FNAME`: Specify a file to load/save the session, which caches the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The session file is created during the first run and is reused in subsequent runs. If you change your prompt such that 75% or less of the session is reusable, the existing session file will be overwritten with a new, updated version to maintain optimal performance.
+-   `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.

 ### Quantization

```
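A note on why the cache mainly helps startup: a saved state is reusable only up to the first token where a new prompt diverges from the cached one. The sketch below illustrates that prefix-matching idea; the token type and function are illustrative inventions, while llama.cpp itself works with llama_token sequences:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

using token_t = int; // illustrative; llama.cpp uses llama_token

// Length of the longest shared prefix between the cached tokens and a new
// prompt. Evaluation can resume at this position instead of position 0.
static std::size_t matching_prefix(const std::vector<token_t> & cached,
                                   const std::vector<token_t> & prompt) {
    std::size_t n = 0;
    while (n < cached.size() && n < prompt.size() && cached[n] == prompt[n]) {
        n++;
    }
    return n;
}

int main() {
    const std::vector<token_t> cached = {12, 7, 301, 88, 40};
    const std::vector<token_t> prompt = {12, 7, 301, 9, 15, 2};

    // Three tokens match, so only prompt[3..] needs to be evaluated.
    printf("reusable prefix: %zu of %zu prompt tokens\n",
           matching_prefix(cached, prompt), prompt.size());
    return 0;
}
```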

examples/main/main.cpp (+10, -10)
```diff
@@ -139,7 +139,7 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');

-    std::string path_session = params.path_session;
+    std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;

     if (!path_session.empty()) {
@@ -292,14 +292,9 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }

-    bool is_antiprompt = false;
-    bool input_echo = true;
-
-    // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
-    // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
-    // initial prompt so it doesn't need to be an exact match.
-    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
-
+    bool is_antiprompt = false;
+    bool input_echo    = true;
+    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();

     int n_past = 0;
     int n_remain = params.n_predict;
@@ -328,7 +323,7 @@ int main(int argc, char ** argv) {
             embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());

             // stop saving session if we run out of context
-            path_session = "";
+            path_session.clear();

             //printf("\n---\n");
             //printf("resetting: '");
@@ -603,6 +598,11 @@ int main(int argc, char ** argv) {
         }
     }

+    if (!path_session.empty() && params.prompt_cache_all) {
+        fprintf(stderr, "\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
+        llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+    }
+
     llama_print_timings(ctx);
     llama_free(ctx);
```
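One behavioral detail worth calling out in the second hunk above: the old code skipped re-saving when at least 75% of the prompt matched the loaded session, whereas the new code re-saves on any partial match, so the cache tracks the exact prompt (and, with --prompt-cache-all, the final output too). A small self-contained illustration with made-up token counts:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical counts: 90 of 100 prompt tokens matched the loaded cache.
    const std::size_t n_matching_session_tokens = 90;
    const std::size_t prompt_size               = 100;

    // Old rule: skip re-saving whenever at least 75% of the prompt matched.
    const bool old_resave = n_matching_session_tokens < (prompt_size * 3 / 4);
    // New rule: re-save on any partial match, so the cache converges on the
    // exact prompt (and, with --prompt-cache-all, the generated output too).
    const bool new_resave = n_matching_session_tokens < prompt_size;

    printf("old rule re-saves: %s\n", old_resave ? "yes" : "no"); // no
    printf("new rule re-saves: %s\n", new_resave ? "yes" : "no"); // yes
    return 0;
}
```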
