Skip to content

Commit 660deb0

Browse files
committed
Cleanup
1 parent 6576449 commit 660deb0

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

surya/foundation/__init__.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -180,17 +180,6 @@ def process_outputs(self, outputs: SuryaModelOutput, num_valid_tokens: torch.Ten
180180
lm_logits = outputs["lm_logits"].float() # shape: [B, T, V]
181181
bbox_logits = outputs["bbox_logits"].float() # shape: [B, T, D]
182182

183-
# token_indices = num_valid_tokens - 1 # shape: [B]
184-
# token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, lm_logits.size(-1)) # shape: [B, 1, V]
185-
# token_indices = token_indices.to(torch.int64) # gather expects int64 for index
186-
187-
# bbox_indices = num_valid_tokens - 1
188-
# bbox_indices = bbox_indices.view(-1, 1, 1).expand(-1, 1, bbox_logits.size(-1)) # shape: [B, 1, D]
189-
# bbox_indices = bbox_indices.to(torch.int64) # gather expects int64 for index
190-
191-
# # Gather logits at valid token positions
192-
# next_token_logits = torch.gather(lm_logits, dim=1, index=token_indices) # shape: [B, 1, V]
193-
# next_bbox_logits = torch.gather(bbox_logits, dim=1, index=bbox_indices) # shape: [B, 1, D]
194183
next_token_logits = lm_logits[:, -1:, :]
195184
next_bbox_logits = bbox_logits[:, -1:, :]
196185

0 commit comments

Comments
 (0)