Commit 5e8adad

Update to new foundation model
1 parent 13fd3a5 commit 5e8adad

File tree: 6 files changed (+252 / -193 lines)

surya/common/surya/__init__.py

Lines changed: 158 additions & 1 deletion
@@ -7,6 +7,8 @@
 import torch.nn.functional as F
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 from surya.common.s3 import S3DownloaderMixin
 from surya.common.surya.config import SuryaModelConfig
@@ -108,6 +110,13 @@ def __init__(
         self.bbox_head = nn.Linear(config.hidden_size, 6)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
 
+        if self.config.multi_output_distance is not None and self.config.multi_output_distance > 0:
+            self.multi_output_embeds = nn.Embedding(
+                config.max_multi_out,
+                config.hidden_size,
+                padding_idx=0,
+            )
+
     def tie_weights(self):
         self._tie_weights()
 
@@ -279,6 +288,7 @@ def forward(
         inputs_embeds=None,
         attention_mask=None,
         position_ids=None,
+        cache_position=None,
         past_key_values=None,
         output_hidden_states=False,
         output_attentions=False,
@@ -309,11 +319,33 @@ def forward(
             kwargs["cu_seqlens_k"] = cu_seqlens_k
             kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
 
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            past_key_values,
+            output_attentions,
+        )
+
+        attention_mask = causal_mask
         outputs = self.decoder(
-            input_ids=None,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             position_ids=position_ids,
+            cache_position=cache_position,
             past_key_values=past_key_values,
             return_dict=True,
             use_cache=use_cache,
@@ -336,3 +368,128 @@ def forward(
             attentions=outputs.attentions if output_attentions else None,
             past_key_values=outputs.past_key_values,
         )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            return attention_mask
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
+        )
+
+        # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        config: SuryaModelConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, does nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the 0 padding (the part of the cache that is not yet filled).
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+            config (`SuryaModelConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class currently being used for generation.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length),
+                fill_value=min_dtype,
+                dtype=dtype,
+                device=device,
+            )
+            diagonal_attend_mask = torch.arange(
+                target_length, device=device
+            ) > cache_position.reshape(-1, 1)
+            # NOTE - Removed sliding window handling here from original impl. since we manage it differently
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = (
+                    causal_mask.clone()
+                )  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+                    :, None, None, :
+                ].to(causal_mask.device)
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
+        return causal_mask
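
Note on the mask handling added above: forward() now derives cache_position from however many tokens the cache already holds, and _update_causal_mask turns the processor's 2D padding mask into a 4D additive causal mask before it reaches the decoder. The snippet below is a standalone toy sketch of that 2D-to-4D step (illustrative only; the tensors and values are made up, not taken from the repository):

import torch

# Toy example (not part of the commit): build a 4D additive causal mask from a
# 2D padding mask, mirroring what _prepare_4d_causal_attention_mask_with_cache_position
# does for a prefill step.
batch_size, seq_len = 2, 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

attention_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])  # second row is left-padded
cache_position = torch.arange(seq_len)  # no tokens cached yet
target_length = attention_mask.shape[-1]

# Start fully masked, then clear the positions each query is allowed to attend to.
causal = torch.full((seq_len, target_length), min_dtype, dtype=dtype)
causal *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal = causal[None, None, :, :].expand(batch_size, 1, -1, -1)

# Re-mask padded key positions using the 2D mask.
padding = causal + attention_mask[:, None, None, :]
causal = causal.masked_fill(padding == 0, min_dtype)

print(causal.shape)  # torch.Size([2, 1, 4, 4]); 0.0 = attend, min_dtype = blocked

In the committed code this mask either feeds SDPA directly or is passed through AttentionMaskConverter._unmask_unattended so that fully padded rows do not break the memory-efficient attention path.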

surya/common/surya/config.py

Lines changed: 16 additions & 3 deletions
@@ -1,3 +1,4 @@
+from typing import Optional
 from transformers import PretrainedConfig
 
 from surya.common.s3 import S3DownloaderMixin
@@ -18,18 +19,24 @@ def __init__(
         eos_token_id=1,
         pad_token_id=2,
         image_token_id=3,
+        register_token_ids=(4, 5, 6, 7),
+        eoi_token_id=8,
+        beacon_token_id=9,
         special_token_count=4,
         max_sequence_length=1536,
         special_ocr_tokens=None,
         vision_encoder=None,
         decoder=None,
         tasks: dict | None = None,
         bbox_embed_size: int = 64,
-        register_token_ids=(4, 5, 6, 7),
-        unmask_image: bool = False,
         num_register_tokens: int = 4,
         image_embed_encoding_size: int = 1024,
         image_embed_encoding_multiplier: int = 256,
+        num_beacon_tokens: int = 1,
+        beacon_token_interval: int = 4096,
+        sliding_window: Optional[int] = None,
+        multi_output_distance: int = 4,
+        max_multi_out: int = 8,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -41,17 +48,23 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
+        self.eoi_token_id = eoi_token_id
+        self.beacon_token_id = beacon_token_id
         self.special_ocr_tokens = special_ocr_tokens
         self.special_token_count = special_token_count  # pad, bos, etc, tokens
         self.max_sequence_length = max_sequence_length
         self.tasks = tasks
         self.tie_word_embeddings = True
         self.bbox_embed_size = bbox_embed_size
-        self.unmask_image = unmask_image
         self.num_register_tokens = num_register_tokens
         self.register_token_ids = register_token_ids
         self.image_embed_encoding_size = image_embed_encoding_size
         self.image_embed_encoding_multiplier = image_embed_encoding_multiplier
+        self.num_beacon_tokens = num_beacon_tokens
+        self.beacon_token_interval = beacon_token_interval
+        self.sliding_window = sliding_window
+        self.multi_output_distance = multi_output_distance
+        self.max_multi_out = max_multi_out
 
         if isinstance(vision_encoder, dict):
             vision_encoder = SuryaEncoderConfig(**vision_encoder)
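
For orientation, the new configuration fields can be exercised directly through the constructor. A minimal sketch, using only parameters visible in this diff and assuming the remaining arguments can stay at their defaults (the values mirror the new defaults and are illustrative, not a recommended setup):

from surya.common.surya.config import SuryaModelConfig

# Illustrative only: set the fields introduced in this commit.
config = SuryaModelConfig(
    eoi_token_id=8,
    beacon_token_id=9,
    num_beacon_tokens=1,
    beacon_token_interval=4096,
    sliding_window=None,       # no sliding-window limit unless explicitly set
    multi_output_distance=4,   # > 0 makes the model build its multi_output_embeds table
    max_multi_out=8,           # number of rows in that embedding table
)
print(config.multi_output_distance, config.max_multi_out)

multi_output_distance > 0 is the condition the model's __init__ checks (first file above) before creating nn.Embedding(config.max_multi_out, config.hidden_size).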
