import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter

from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.config import SuryaModelConfig
@@ -108,6 +110,13 @@ def __init__(
        self.bbox_head = nn.Linear(config.hidden_size, 6)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

+        if self.config.multi_output_distance is not None and self.config.multi_output_distance > 0:
+            self.multi_output_embeds = nn.Embedding(
+                config.max_multi_out,
+                config.hidden_size,
+                padding_idx=0,
+            )
+
    def tie_weights(self):
        self._tie_weights()

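Note on the new `multi_output_embeds` table: `nn.Embedding` with `padding_idx=0` keeps row 0 as a zero vector and never updates it, so index 0 can act as a no-op slot. A minimal sketch of that PyTorch behavior (toy sizes, not part of this diff):

import torch
import torch.nn as nn

# Stand-ins for config.max_multi_out and config.hidden_size
embeds = nn.Embedding(4, 8, padding_idx=0)

idx = torch.tensor([[0, 2, 3]])
out = embeds(idx)                          # shape (1, 3, 8)
print(out[0, 0].abs().sum())               # tensor(0.) - the padding row is all zeros
out.sum().backward()
print(embeds.weight.grad[0].abs().sum())   # tensor(0.) - no gradient reaches the padding row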
@@ -279,6 +288,7 @@ def forward(
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
+        cache_position=None,
        past_key_values=None,
        output_hidden_states=False,
        output_attentions=False,
@@ -309,11 +319,33 @@ def forward(
        kwargs["cu_seqlens_k"] = cu_seqlens_k
        kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k

+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            past_key_values,
+            output_attentions,
+        )
+
+        attention_mask = causal_mask
        outputs = self.decoder(
-            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
+            cache_position=cache_position,
            past_key_values=past_key_values,
            return_dict=True,
            use_cache=use_cache,
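Note on the `cache_position` default added above: when it is not supplied, it becomes the positions of the new tokens offset by however many tokens are already held in the KV cache, and `position_ids` falls back to the same values with a batch dimension added. A small standalone illustration of the arithmetic (example sizes assumed, not part of this diff):

import torch

past_seen_tokens = 5   # e.g. past_key_values.get_seq_length()
new_tokens = 3         # inputs_embeds.shape[1] for this forward pass

cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)  # add a batch dimension

print(cache_position)  # tensor([5, 6, 7])
print(position_ids)    # tensor([[5, 6, 7]])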
@@ -336,3 +368,128 @@ def forward(
            attentions=outputs.attentions if output_attentions else None,
            past_key_values=outputs.past_key_values,
        )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            return attention_mask
+
+        # For SDPA, when possible, we rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
+        )
+
+        # We always pass in a 2D attention mask from the processor - in both static and dynamic cache cases
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by the F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        config: SuryaModelConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or does nothing if the input `attention_mask` is already 4D.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+            config (`SuryaModelConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class currently being used to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length),
+                fill_value=min_dtype,
+                dtype=dtype,
+                device=device,
+            )
+            diagonal_attend_mask = torch.arange(
+                target_length, device=device
+            ) > cache_position.reshape(-1, 1)
+            # NOTE - Removed sliding window handling here from original impl. since we manage it differently
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = (
+                    causal_mask.clone()
+                )  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+                    :, None, None, :
+                ].to(causal_mask.device)
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
+        return causal_mask
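Note on the mask construction above: the helper starts from a fully masked `(sequence_length, target_length)` grid, unmasks keys at or before each query's `cache_position`, broadcasts to `(batch_size, 1, query_length, key_value_length)`, and then re-applies the 2D padding mask additively. A minimal standalone sketch of that expansion (toy sizes and the padding pattern are assumptions for illustration, not values from this PR):

import torch

batch_size, sequence_length, target_length = 1, 2, 5  # 3 cached tokens + 2 new queries
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Positions of the two new query tokens within the full sequence
cache_position = torch.tensor([3, 4])

# Start fully masked, then allow each query to attend to keys at or before its position
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# 2D padding mask from the processor: 1 = real token, 0 = padding (first key padded here)
attention_mask = torch.tensor([[0, 1, 1, 1, 1]], dtype=dtype)
padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)

print(causal_mask[0, 0])  # 0.0 where attention is allowed, min_dtype where it is blocked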