@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
189
189
return ubatch;
190
190
}
191
191
192
- void llama_sbatch::from_batch (const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
192
+ llama_sbatch::llama_sbatch (const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
193
GGML_ASSERT (batch.n_tokens >= 0 );
194
194
this ->batch = &batch;
195
195
this ->n_embd = n_embd;
@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
203
203
for (size_t i = 0 ; i < n_tokens; ++i) {
204
204
ids[i] = i;
205
205
}
206
+
206
207
if (simple_split) {
207
208
seq.resize (1 );
208
209
llama_sbatch_seq & s = seq[0 ];
@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
212
213
s.length = n_tokens;
213
214
return ;
214
215
}
216
+
215
217
std::sort (ids.begin (), ids.end (),
216
218
[&batch](size_t a, size_t b) {
217
219
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id [a] : 1 ;
@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
239
241
return n_seq_a > n_seq_b;
240
242
}
241
243
);
244
+
242
245
// init seq
243
246
llama_sbatch_seq * last_seq = nullptr ;
244
247
@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
262
265
seq.push_back (new_seq);
263
266
last_seq = &seq.back ();
264
267
}
268
+
265
269
// keep shared prompts first at the end, then sort by length descending.
266
270
std::sort (seq.begin (), seq.end (),
267
271
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
0 commit comments