@@ -274,7 +274,8 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
         input_ids = processed_output.input_ids
         num_predicted_tokens += 1
         input_ids, valid_tokens = self.maybe_insert_beacon_tokens(input_ids, num_predicted_tokens)
-        position_ids = position_ids[:, -1:] + torch.arange(input_ids.shape[1])
+        # TODO we should only consider position_ids upto the valid range for each batch element
+        position_ids = position_ids[:, -1:] + torch.arange(1, input_ids.shape[1] + 1)

         new_input = ContinuousBatchInput(
             input_ids=input_ids,
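The decode-side change above is an off-by-one fix: once `maybe_insert_beacon_tokens` can grow `input_ids` beyond a single token, the position ids for the new tokens must continue past the last cached position rather than repeat it. A minimal standalone sketch of the difference (the tensor values here are made up for illustration):

```python
import torch

# Hypothetical state: both sequences last decoded at position 9,
# and beacon insertion expanded this step to 3 tokens.
position_ids = torch.tensor([[9], [9]])
num_new_tokens = 3

# Old update: offsets 0..2, so the first new token re-uses position 9.
old = position_ids[:, -1:] + torch.arange(num_new_tokens)         # [[ 9, 10, 11], [ 9, 10, 11]]

# New update: offsets 1..3, positions continue at 10 as intended.
new = position_ids[:, -1:] + torch.arange(1, num_new_tokens + 1)  # [[10, 11, 12], [10, 11, 12]]
```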
@@ -285,6 +286,33 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):

         return new_input, processed_output

+    def pad_and_shift_input_ids_position_ids(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        new_seq_len: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Pads new_input_ids to match the new seq len
+        and creates updated position_ids based on current_position_ids' last position.
+
+        Returns:
+            padded_input_ids (torch.Tensor): [batch_size, current_seq_len]
+            updated_position_ids (torch.Tensor): [batch_size, current_seq_len]
+        """
+        assert input_ids.shape[1] == 1, "During prefill the new input_ids must be of length 1"
+
+        if new_seq_len == input_ids.shape[1]:
+            return input_ids, position_ids[:, -1:] + 1
+
+        pad_len = new_seq_len - 1
+        padded_input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=self.device_pad_token)
+
+        # Create updated position_ids starting from the last position + 1, increasing by 1 each step
+        updated_position_ids = position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device)
+
+        return padded_input_ids, updated_position_ids
+
     def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
         logger.debug(f"Prefilling {self.num_empty_slots} slots")
         prompts: List[RecognitionPrompt] = [
@@ -337,7 +365,6 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
         valid_tokens = torch.ones((input_ids.shape[0], 1), device=self.model.device)  # No extra tokens during prefill
         processed_outputs = self.process_outputs(outputs, valid_tokens=valid_tokens)
         # Update to account for the newly generated tokens
-        position_ids = position_ids[:, -1:] + 1
         self.kv_cache.attention_mask[idxs_to_merge] = attention_mask[:len(idxs_to_merge)]

         # Find text lenghts of each
@@ -369,11 +396,13 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
         )

         # Merging inputs for next steps
-        # TODO If the current inputs contain padded position ids or input ids, we have to merge them with padding
         current_input_ids = current_inputs.input_ids
-        current_input_ids[idxs_to_merge] = processed_outputs.input_ids
-
         current_position_ids = current_inputs.position_ids
+
+        input_ids, position_ids = self.pad_and_shift_input_ids_position_ids(
+            processed_outputs.input_ids, position_ids, new_seq_len=current_input_ids.shape[1]
+        )
+        current_input_ids[idxs_to_merge] = input_ids
         current_position_ids[idxs_to_merge] = position_ids

         current_valid_tokens = current_inputs.valid_tokens
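Taken together, the last two hunks route a freshly prefilled sequence's single generated token through `pad_and_shift_input_ids_position_ids` before scattering it into the running decode batch, so its width and position ids line up with what the batch currently uses. A standalone sketch of that behavior (the pad token id, shapes, and values below are assumptions for illustration, not taken from the PR):

```python
import torch

PAD_TOKEN_ID = 0  # assumption; the PR uses self.device_pad_token

def pad_and_shift(input_ids, position_ids, new_seq_len, pad_token_id=PAD_TOKEN_ID):
    # Mirrors pad_and_shift_input_ids_position_ids as a free function.
    assert input_ids.shape[1] == 1
    if new_seq_len == input_ids.shape[1]:
        return input_ids, position_ids[:, -1:] + 1
    padded = torch.nn.functional.pad(input_ids, (0, new_seq_len - 1), value=pad_token_id)
    shifted = position_ids[:, -1:] + torch.arange(1, new_seq_len + 1)
    return padded, shifted

# Two newly prefilled sequences, one generated token each, merging into a
# decode batch whose inputs are currently 3 tokens wide.
new_token_ids = torch.tensor([[101], [202]])
last_positions = torch.tensor([[17], [23]])

ids, pos = pad_and_shift(new_token_ids, last_positions, new_seq_len=3)
# ids -> [[101, 0, 0], [202, 0, 0]]      right-padded to the batch width
# pos -> [[18, 19, 20], [24, 25, 26]]    positions continue past each prompt
```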