@@ -180,17 +180,6 @@ def process_outputs(self, outputs: SuryaModelOutput, num_valid_tokens: torch.Ten
180
180
lm_logits = outputs ["lm_logits" ].float () # shape: [B, T, V]
181
181
bbox_logits = outputs ["bbox_logits" ].float () # shape: [B, T, D]
182
182
183
- # token_indices = num_valid_tokens - 1 # shape: [B]
184
- # token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, lm_logits.size(-1)) # shape: [B, 1, V]
185
- # token_indices = token_indices.to(torch.int64) # gather expects int64 for index
186
-
187
- # bbox_indices = num_valid_tokens - 1
188
- # bbox_indices = bbox_indices.view(-1, 1, 1).expand(-1, 1, bbox_logits.size(-1)) # shape: [B, 1, D]
189
- # bbox_indices = bbox_indices.to(torch.int64) # gather expects int64 for index
190
-
191
- # # Gather logits at valid token positions
192
- # next_token_logits = torch.gather(lm_logits, dim=1, index=token_indices) # shape: [B, 1, V]
193
- # next_bbox_logits = torch.gather(bbox_logits, dim=1, index=bbox_indices) # shape: [B, 1, D]
194
183
next_token_logits = lm_logits [:, - 1 :, :]
195
184
next_bbox_logits = bbox_logits [:, - 1 :, :]
196
185
0 commit comments