
Commit a4d365e

Special handling when inserting beacon tokens into the seq
1 parent a18cbff commit a4d365e

File tree

1 file changed (+34, -5 lines)


surya/foundation/__init__.py

Lines changed: 34 additions & 5 deletions
@@ -274,7 +274,8 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
         input_ids = processed_output.input_ids
         num_predicted_tokens += 1
         input_ids, valid_tokens = self.maybe_insert_beacon_tokens(input_ids, num_predicted_tokens)
-        position_ids = position_ids[:, -1:] + torch.arange(input_ids.shape[1])
+        # TODO we should only consider position_ids upto the valid range for each batch element
+        position_ids = position_ids[:, -1:] + torch.arange(1, input_ids.shape[1] + 1)

         new_input = ContinuousBatchInput(
             input_ids=input_ids,
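
A minimal standalone sketch (not part of the commit) of why this hunk switches to torch.arange(1, n + 1): when maybe_insert_beacon_tokens widens the decode step from one token to several, the new positions must continue strictly after the last recorded position instead of repeating it. The token ids and positions below are made up purely for illustration.

import torch

last_position_ids = torch.tensor([[11], [7]])          # [batch, 1], last position per sequence
step_input_ids = torch.tensor([[42, 101], [42, 101]])  # predicted token plus an inserted beacon token

n = step_input_ids.shape[1]
# Old: last + arange(n) starts at the last position again (off by one).
# New: last + arange(1, n + 1) gives strictly increasing follow-on positions.
position_ids = last_position_ids + torch.arange(1, n + 1)
print(position_ids)  # -> [[12, 13], [8, 9]]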
@@ -285,6 +286,33 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):

         return new_input, processed_output

+    def pad_and_shift_input_ids_position_ids(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        new_seq_len: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Pads new_input_ids to match the new seq len
+        and creates updated position_ids based on current_position_ids' last position.
+
+        Returns:
+            padded_input_ids (torch.Tensor): [batch_size, current_seq_len]
+            updated_position_ids (torch.Tensor): [batch_size, current_seq_len]
+        """
+        assert input_ids.shape[1] == 1, "During prefill the new input_ids must be of length 1"
+
+        if new_seq_len == input_ids.shape[1]:
+            return input_ids, position_ids[:, -1:] + 1
+
+        pad_len = new_seq_len - 1
+        padded_input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=self.device_pad_token)
+
+        # Create updated position_ids starting from the last position + 1, increasing by 1 each step
+        updated_position_ids = position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device)
+
+        return padded_input_ids, updated_position_ids
+
     def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
         logger.debug(f"Prefilling {self.num_empty_slots} slots")
         prompts: List[RecognitionPrompt] = [
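
To make the new helper concrete, here is a standalone mirror of pad_and_shift_input_ids_position_ids, for illustration only; the pad token value and the example shapes are assumptions, and the real method uses self.device_pad_token and allocates on self.model.device.

import torch
import torch.nn.functional as F

PAD_TOKEN = 0  # stand-in for self.device_pad_token

def pad_and_shift(input_ids, position_ids, new_seq_len):
    # Pad the single prefill token out to new_seq_len and continue the
    # position ids right after the last prompt position.
    assert input_ids.shape[1] == 1
    if new_seq_len == input_ids.shape[1]:
        return input_ids, position_ids[:, -1:] + 1
    padded = F.pad(input_ids, (0, new_seq_len - 1), value=PAD_TOKEN)
    shifted = position_ids[:, -1:] + torch.arange(1, new_seq_len + 1)
    return padded, shifted

ids, pos = pad_and_shift(torch.tensor([[57]]), torch.tensor([[0, 1, 2, 3]]), new_seq_len=3)
print(ids)  # -> [[57, 0, 0]]
print(pos)  # -> [[4, 5, 6]]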
@@ -337,7 +365,6 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
         valid_tokens = torch.ones((input_ids.shape[0], 1), device=self.model.device) # No extra tokens during prefill
         processed_outputs = self.process_outputs(outputs, valid_tokens=valid_tokens)
         # Update to account for the newly generated tokens
-        position_ids = position_ids[:, -1:] + 1
         self.kv_cache.attention_mask[idxs_to_merge] = attention_mask[:len(idxs_to_merge)]

         # Find text lenghts of each
@@ -369,11 +396,13 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
         )

         # Merging inputs for next steps
-        # TODO If the current inputs contain padded position ids or input ids, we have to merge them with padding
         current_input_ids = current_inputs.input_ids
-        current_input_ids[idxs_to_merge] = processed_outputs.input_ids
-
         current_position_ids = current_inputs.position_ids
+
+        input_ids, position_ids = self.pad_and_shift_input_ids_position_ids(
+            processed_outputs.input_ids, position_ids, new_seq_len=current_input_ids.shape[1]
+        )
+        current_input_ids[idxs_to_merge] = input_ids
         current_position_ids[idxs_to_merge] = position_ids

         current_valid_tokens = current_inputs.valid_tokens
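
A rough end-to-end sketch of the merge this hunk now performs (shapes, ids, and the pad value are assumed, not taken from the repo): a freshly prefilled slot contributes a single token, which has to be padded to the running decode step's width and scattered into idxs_to_merge along with its shifted positions, since beacon insertion can make the decode step wider than one token.

import torch
import torch.nn.functional as F

PAD_TOKEN = 0  # stand-in for self.device_pad_token

# Running decode batch whose current step is 2 tokens wide after beacon insertion.
current_input_ids = torch.tensor([[42, 101], [42, 101]])
current_position_ids = torch.tensor([[12, 13], [8, 9]])

# Newly prefilled slot: one generated token, prompt positions 0..3.
new_ids = torch.tensor([[57]])
prompt_pos = torch.tensor([[0, 1, 2, 3]])

# Pad the 1-token step to the decode width and shift its positions,
# in the spirit of pad_and_shift_input_ids_position_ids, then scatter it in.
width = current_input_ids.shape[1]
padded_ids = F.pad(new_ids, (0, width - new_ids.shape[1]), value=PAD_TOKEN)
shifted_pos = prompt_pos[:, -1:] + torch.arange(1, width + 1)

idxs_to_merge = torch.tensor([1])
current_input_ids[idxs_to_merge] = padded_ids
current_position_ids[idxs_to_merge] = shifted_pos
print(current_input_ids)     # -> [[42, 101], [57, 0]]
print(current_position_ids)  # -> [[12, 13], [4, 5]]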
