File tree Expand file tree Collapse file tree 5 files changed +26
-5
lines changed Expand file tree Collapse file tree 5 files changed +26
-5
lines changed Original file line number Diff line number Diff line change @@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
166
166
167
167
// note: tracking the other way around is not necessary for now
168
168
// seq_cpl[s0][s1] = true;
169
+
170
+ has_cpl = true ;
169
171
}
170
172
}
171
173
}
@@ -466,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
466
468
return ubatch_add (idxs, idxs.size (), false );
467
469
}
468
470
469
- llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch) {
471
+ llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch, bool sequential) {
472
+ if (sequential && has_cpl) {
473
+ LLAMA_LOG_ERROR (" %s: sequential split is not supported when there are coupled sequences in the input batch\n " , __func__);
474
+
475
+ return {};
476
+ }
477
+
470
478
std::vector<seq_set_t > cur_seq_set;
471
479
480
+ llama_seq_id last_seq_id = -1 ;
481
+
472
482
// determine the non-overlapping sequence sets participating in this ubatch
473
483
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
474
484
if (used[i]) {
@@ -485,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
485
495
}
486
496
}
487
497
498
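+ // accept only consecutive sequence ids: the first seq id of each added token must be exactly last_seq_id + 1 (any id is accepted while cur_seq_set is empty)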
499
+ if (sequential) {
500
+ add = add && (cur_seq_set.empty () || batch.seq_id [i][0 ] == last_seq_id + 1 );
501
+ }
502
+
488
503
if (add) {
489
504
cur_seq_set.push_back (seq_set[i]);
490
505
506
+ last_seq_id = batch.seq_id [i][0 ];
507
+
491
508
if (cur_seq_set.size () > n_ubatch) {
492
509
break ;
493
510
}
Original file line number Diff line number Diff line change @@ -70,7 +70,8 @@ class llama_batch_allocr {
70
70
llama_ubatch split_simple (uint32_t n_ubatch);
71
71
72
72
// make ubatches of equal-length sequences sets
73
- llama_ubatch split_equal (uint32_t n_ubatch);
73
+ // if sequential == true, the sequence sets added to the ubatch have consecutive, increasing sequence ids (each one is the previous id + 1)
74
+ llama_ubatch split_equal (uint32_t n_ubatch, bool sequential);
74
75
75
76
// sequence-set-wise split - each ubatch contains a single sequence-set
76
77
llama_ubatch split_seq (uint32_t n_ubatch);
@@ -113,6 +114,9 @@ class llama_batch_allocr {
113
114
using pos_set_t = std::set<llama_pos>;
114
115
using seq_cpl_t = std::vector<bool >;
115
116
117
+ // helper flag to quickly determine if there are any coupled sequences in the batch
118
+ bool has_cpl;
119
+
116
120
std::vector<pos_set_t > seq_pos; // seq_pos[s]: the set of positions in sequence s
117
121
std::vector<seq_cpl_t > seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
118
122
Original file line number Diff line number Diff line change @@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
140
140
141
141
std::vector<llama_ubatch> ubatches;
142
142
while (true ) {
143
- auto ubatch = balloc.split_equal (n_ubatch);
143
+ auto ubatch = balloc.split_equal (n_ubatch, false );
144
144
145
145
if (ubatch.n_tokens == 0 ) {
146
146
break ;
Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
70
70
// if all tokens are output, split by sequence
71
71
ubatch = balloc.split_seq (n_ubatch);
72
72
} else {
73
- ubatch = balloc.split_equal (n_ubatch);
73
+ ubatch = balloc.split_equal (n_ubatch, false );
74
74
}
75
75
76
76
if (ubatch.n_tokens == 0 ) {
Original file line number Diff line number Diff line change @@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
374
374
// if all tokens are output, split by sequence
375
375
ubatch = balloc.split_seq (n_ubatch);
376
376
} else {
377
- ubatch = balloc.split_equal (n_ubatch);
377
+ ubatch = balloc.split_equal (n_ubatch, false );
378
378
}
379
379
380
380
if (balloc.get_n_used () < balloc.get_n_tokens ()) {
You can’t perform that action at this time.
0 commit comments