Commit 5c7f088

Get multi-token predictions from the model

1 parent 660deb0
File tree: 2 files changed, +32 -11 lines changed

surya/common/surya/__init__.py

Lines changed: 27 additions & 6 deletions
@@ -304,6 +304,31 @@ def get_2d_learned_embeddings(
             all_embeddings, dim=0
         )  # Shape is num_image_tokens x embed_dim
 
+    def get_logits(self, hidden_states):
+        assert hidden_states.shape[1] == 1, "Multi-output predictions are only applied to the last token"
+
+        all_lm_logits = []
+        all_bbox_logits = []
+
+        current_hidden = hidden_states
+
+        # Loop includes the initial prediction (i=0) plus multi_output_distance additional predictions
+        for i in range(self.config.multi_output_distance + 1):
+            if i > 0:
+                current_hidden = self.multi_output_projections[i - 1](current_hidden)
+
+            lm_logits = self.lm_head(current_hidden)
+            bbox_logits = F.sigmoid(self.bbox_head(current_hidden))
+
+            all_lm_logits.append(lm_logits)
+            all_bbox_logits.append(bbox_logits)
+
+        # Concatenate along the sequence dimension (dim=1)
+        final_lm_logits = torch.cat(all_lm_logits, dim=1)
+        final_bbox_logits = torch.cat(all_bbox_logits, dim=1)
+
+        return final_lm_logits, final_bbox_logits
+
     def forward(
         self,
         input_ids=None,
@@ -317,7 +342,6 @@ def forward(
         output_hidden_states=False,
         output_attentions=False,
         use_cache=False,
-        logits_to_keep=None,
         encoder_chunk_size=None,
         cache_idxs=None,
         num_valid_tokens=None,
@@ -386,12 +410,9 @@ def forward(
 
         hidden_states = outputs.last_hidden_state
         # Only keep the last token's hidden state, should bring down memory usage during inference
-        if logits_to_keep is not None:
-            hidden_states = hidden_states[:, -logits_to_keep:, :]
-
+        hidden_states = hidden_states[:, -1:, :]
         hidden_states = hidden_states.contiguous()
-        bbox_logits = F.sigmoid(self.bbox_head(hidden_states))
-        lm_logits = self.lm_head(hidden_states)
+        lm_logits, bbox_logits = self.get_logits(hidden_states)
 
         return SuryaModelOutput(
             bbox_logits=bbox_logits,
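
For orientation, here is what the new get_logits does: it expands the final token's hidden state into multi_output_distance + 1 stacked predictions along dim 1, where index 0 is the ordinary next-token prediction and each projection maps the hidden state one step further ahead. Below is a minimal, self-contained sketch of that scheme. The sizes, the MultiTokenHead class name, and the assumption that multi_output_projections is an nn.ModuleList of Linear layers are illustrative guesses, not taken from the Surya codebase.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes (hypothetical, not the real Surya config)
HIDDEN, VOCAB, BBOX_DIM, MULTI_OUTPUT_DISTANCE = 64, 100, 6, 2

class MultiTokenHead(nn.Module):
    """Toy version of the get_logits logic added in this commit."""

    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(HIDDEN, VOCAB)
        self.bbox_head = nn.Linear(HIDDEN, BBOX_DIM)
        # Assumption: one projection per extra look-ahead step, each mapping
        # the hidden state for step i to a hidden state for step i + 1
        self.multi_output_projections = nn.ModuleList(
            [nn.Linear(HIDDEN, HIDDEN) for _ in range(MULTI_OUTPUT_DISTANCE)]
        )

    def get_logits(self, hidden_states):
        # hidden_states: [batch, 1, HIDDEN] - only the last token's state
        assert hidden_states.shape[1] == 1
        all_lm_logits, all_bbox_logits = [], []
        current_hidden = hidden_states
        for i in range(MULTI_OUTPUT_DISTANCE + 1):
            if i > 0:
                current_hidden = self.multi_output_projections[i - 1](current_hidden)
            all_lm_logits.append(self.lm_head(current_hidden))
            all_bbox_logits.append(F.sigmoid(self.bbox_head(current_hidden)))
        # Stack the per-step predictions along the sequence dimension
        return torch.cat(all_lm_logits, dim=1), torch.cat(all_bbox_logits, dim=1)

head = MultiTokenHead()
lm, bbox = head.get_logits(torch.randn(2, 1, HIDDEN))
print(lm.shape, bbox.shape)  # torch.Size([2, 3, 100]) torch.Size([2, 3, 6])

One consequence visible in the forward hunk above: since get_logits asserts a single input position, forward now always slices hidden_states[:, -1:, :] rather than honoring a caller-supplied logits_to_keep.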

surya/foundation/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -180,8 +180,11 @@ def process_outputs(self, outputs: SuryaModelOutput, num_valid_tokens: torch.Tensor
         lm_logits = outputs["lm_logits"].float()  # shape: [B, T, V]
         bbox_logits = outputs["bbox_logits"].float()  # shape: [B, T, D]
 
-        next_token_logits = lm_logits[:, -1:, :]
-        next_bbox_logits = bbox_logits[:, -1:, :]
+        # We make multi-token predictions - currently only the first predicted token is used
+        # TODO: Add support for using all the predictions
+        # TODO: This requires a change to the beacon token logic
+        next_token_logits = lm_logits[:, :1, :]
+        next_bbox_logits = bbox_logits[:, :1, :]
 
         # Get predictions
         preds = torch.argmax(next_token_logits, dim=-1)  # shape: [B, 1]
@@ -263,8 +266,6 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
             position_ids=position_ids,
             use_cache=True,
             past_key_values=self.kv_cache,
-            # We may pass multiple input ids per batch element (right padded) and we need the original size to index into them
-            logits_to_keep=None,
             prefill=False,
             num_valid_tokens=num_valid_tokens
         )
@@ -371,7 +372,6 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
             inputs_embeds=None,
             past_key_values=self.kv_cache,
             use_cache=True,
-            logits_to_keep=1,
             encoder_chunk_size=self.get_encoder_chunk_size(),
             cache_idxs=idxs_to_merge,
             prefill=True,
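
The process_outputs hunk above is also why the slice changes from [:, -1:, :] to [:, :1, :]: the logits now contain multi_output_distance + 1 stacked future positions along dim 1, so the first position, not the last, is the ordinary next-token prediction. A small shape check with toy tensors (hypothetical sizes, not Surya code):

import torch

B, V = 2, 100
# Suppose multi_output_distance = 2: the model emits 3 stacked
# predictions along dim 1, for positions t+1, t+2, and t+3
lm_logits = torch.randn(B, 3, V)

next_token_logits = lm_logits[:, :1, :]  # t+1: the ordinary next token
furthest_logits = lm_logits[:, -1:, :]   # t+3: decoding this as "next" would skip ahead
preds = torch.argmax(next_token_logits, dim=-1)  # shape: [B, 1]
print(next_token_logits.shape, preds.shape)  # [B, 1, V] and [B, 1]

Per the TODOs in the first hunk of this file, the remaining stacked positions are computed but unused until the beacon token logic can accommodate them.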
