@@ -304,6 +304,31 @@ def get_2d_learned_embeddings(
304
304
all_embeddings , dim = 0
305
305
) # Shape is num_image_tokens x embed_dim
306
306
307
+ def get_logits (self , hidden_states ):
308
+ assert hidden_states .shape [1 ] == 1 , "Multi output predictions only applied on the last token"
309
+
310
+ all_lm_logits = []
311
+ all_bbox_logits = []
312
+
313
+ current_hidden = hidden_states
314
+
315
+ # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions
316
+ for i in range (self .config .multi_output_distance + 1 ):
317
+ if i > 0 :
318
+ current_hidden = self .multi_output_projections [i - 1 ](current_hidden )
319
+
320
+ lm_logits = self .lm_head (current_hidden )
321
+ bbox_logits = F .sigmoid (self .bbox_head (current_hidden ))
322
+
323
+ all_lm_logits .append (lm_logits )
324
+ all_bbox_logits .append (bbox_logits )
325
+
326
+ # Concatenate along sequence dimension (dim=1)
327
+ final_lm_logits = torch .cat (all_lm_logits , dim = 1 )
328
+ final_bbox_logits = torch .cat (all_bbox_logits , dim = 1 )
329
+
330
+ return final_lm_logits , final_bbox_logits
331
+
307
332
def forward (
308
333
self ,
309
334
input_ids = None ,
@@ -317,7 +342,6 @@ def forward(
317
342
output_hidden_states = False ,
318
343
output_attentions = False ,
319
344
use_cache = False ,
320
- logits_to_keep = None ,
321
345
encoder_chunk_size = None ,
322
346
cache_idxs = None ,
323
347
num_valid_tokens = None ,
@@ -386,12 +410,9 @@ def forward(
386
410
387
411
hidden_states = outputs .last_hidden_state
388
412
# Only keep the last `logits_to_keep` logits, should bring down memory usage during inference
389
- if logits_to_keep is not None :
390
- hidden_states = hidden_states [:, - logits_to_keep :, :]
391
-
413
+ hidden_states = hidden_states [:, - 1 :, :]
392
414
hidden_states = hidden_states .contiguous ()
393
- bbox_logits = F .sigmoid (self .bbox_head (hidden_states ))
394
- lm_logits = self .lm_head (hidden_states )
415
+ lm_logits , bbox_logits = self .get_logits (hidden_states )
395
416
396
417
return SuryaModelOutput (
397
418
bbox_logits = bbox_logits ,
0 commit comments