 #include <signal.h>
 #endif
 
-static bool g_is_generating = false;
+// volatile, because these flags are written from the signal handler (an interrupt)
+static volatile bool g_is_generating  = false;
+static volatile bool g_is_interrupted = false;
 
 /**
  * Please note that this is NOT production-ready code.
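
Why `volatile`: the SIGINT handler can run between any two statements of the main thread, so without it the compiler could legally cache the flag in a register and never observe the handler's write. Strictly, the C++ standard only guarantees `volatile std::sig_atomic_t` (or lock-free atomics) for communication with a signal handler; `volatile bool` is fine on the platforms this code targets. A minimal, self-contained sketch of the pattern (illustrative names, not this file's code):

```cpp
#include <csignal>

// sig_atomic_t is the type the standard guarantees is safe to
// write from a signal handler.
static volatile std::sig_atomic_t g_stop = 0;

static void on_sigint(int) { g_stop = 1; }

int main() {
    std::signal(SIGINT, on_sigint);
    while (!g_stop) {
        // ... one unit of interruptible work per iteration ...
    }
    return 130; // conventional exit status for SIGINT: 128 + 2
}
```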
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
             g_is_generating = false;
         } else {
             console::cleanup();
-            LOG("\nInterrupted by user\n");
-            _exit(130);
+            if (g_is_interrupted) {
+                _exit(1);
+            }
+            g_is_interrupted = true;
         }
     }
 }
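
The handler now implements a two-stage Ctrl-C: the first SIGINT only raises `g_is_interrupted` and returns, letting the program unwind and report normally; a second SIGINT while the first is still being honored force-quits. Note that `_exit()` is async-signal-safe while `exit()` is not (it runs `atexit` handlers and flushes stdio), and the same concern is why the old `LOG(...)` call moves out of the handler and into `main()` (last hunk below). Condensed, the pattern is (a sketch, not this file's exact code):

```cpp
#include <csignal>
#include <unistd.h>

static volatile std::sig_atomic_t g_interrupted = 0;

static void sigint_handler(int /*signo*/) {
    if (g_interrupted) {
        _exit(1);      // second Ctrl-C: give up on a graceful shutdown
    }
    g_interrupted = 1; // first Ctrl-C: request a graceful stop
}
```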
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
-        if (i > n_predict || !g_is_generating) {
+        if (i > n_predict || !g_is_generating || g_is_interrupted) {
             printf("\n");
             break;
         }
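
(Aside: the `i > n_predict` clause is dead code, since the loop condition `i < n_predict` already rules it out; the checks that can actually fire here are the two flags.)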
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
+        if (g_is_interrupted) {
+            printf("\n");
+            break;
+        }
+
         // eval the token
         common_batch_clear(ctx.batch);
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
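
The flag is polled between units of work: once at the top of the loop and again after the token is printed but before the comparatively expensive decode call, so a Ctrl-C takes effect within one decode step instead of running out the remaining `n_predict` tokens. Reduced to essentials, the idiom looks roughly like:

```cpp
// Cooperative cancellation (illustrative): poll the flag between units of
// work so interruption latency is bounded by one unit, not the whole job.
for (int i = 0; i < n_predict && !g_is_interrupted; i++) {
    // sample, print, and decode one token ...
}
if (g_is_interrupted) {
    printf("\n"); // leave the terminal on a fresh line after partial output
}
```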
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     text.add_special = add_bos;
     text.parse_special = true;
     mtmd_input_chunks chunks;
+
+    if (g_is_interrupted) return 0;
+
     int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
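
Note the convention this early return establishes: an interrupted `eval_message()` reports 0 (success), not an error, so callers must distinguish "interrupted" from "failed" by checking `g_is_interrupted` themselves, which is exactly what the `!g_is_interrupted &&` guard and the `if (g_is_interrupted) break;` lines in `main()` below do.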
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 #endif
     }
 
+    if (g_is_interrupted) return 130;
+
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find("<__image__>") == std::string::npos) {
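
This first check in `main()` presumably catches a Ctrl-C delivered during startup (argument parsing, model loading), exiting with 130 right away instead of starting a generation that would be cancelled on its first token.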
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
             return 1;
         }
 
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
         std::vector<std::string> images_fname;
         std::string content;
 
-        while (true) {
+        while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
             console::set_display(console::user_input);
             std::string line;
             console::readline(line, false);
+            if (g_is_interrupted) break;
             console::set_display(console::reset);
             line = string_strip(line);
             if (line.empty()) {
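
The placement of the new `break` matters: `console::readline()` blocks waiting for input, and the SIGINT handler merely sets the flag and returns, so the read comes back (empty or cut short) and the loop has to test `g_is_interrupted` itself before treating what it got as user input. The `while (!g_is_interrupted)` rewrite covers the same race at the top of each turn.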
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
             msg.role = "user";
             msg.content = content;
             int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (g_is_interrupted) break;
             if (ret == 2) {
                 // non-fatal error
                 images_fname.clear();
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
+    if (g_is_interrupted) LOG("\nInterrupted by user\n");
     llama_perf_context_print(ctx.lctx);
-    return 0;
+    return g_is_interrupted ? 130 : 0;
 }
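
With the "Interrupted by user" message and the final return both in `main()`, the normal shutdown path, including `llama_perf_context_print()`, still runs after a single Ctrl-C. Returning 130 follows the shell convention that a process terminated by signal N reports 128 + N (SIGINT is 2), so scripts observe the same status whether the process died from the signal or unwound cooperatively.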