Commit e82e05a

Pass through more inputs for cache support
1 parent 4fd7598 commit e82e05a

2 files changed: +24 −2 lines changed

surya/common/surya/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -296,6 +296,8 @@ def forward(
         logits_to_keep=None,
         encoder_chunk_size=None,
         cache_idxs=None,
+        valid_tokens=None,
+        prefill=False,
         **kwargs: KwargsForCausalLM,
     ):
         # Process the mixed batch if provided
@@ -351,6 +353,8 @@ def forward(
             return_dict=True,
             use_cache=use_cache,
             cache_idxs=cache_idxs,
+            valid_tokens=valid_tokens,
+            prefill=prefill,
             **kwargs,
         )
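
As context for how these new arguments might be used once they reach the top-level `forward()`, here is a hypothetical usage sketch: a prefill pass that populates the KV cache, followed by decode passes that reuse it. The model handle, the `input_ids` parameter, the `outputs.logits` attribute, and the reading of `valid_tokens` as per-sequence counts of real (non-padding) tokens are assumptions for illustration, not taken from this commit.

```python
# Hypothetical usage sketch (assumed names: `model`, `input_ids`, `outputs.logits`;
# `valid_tokens` is assumed to hold per-sequence counts of non-padding tokens).
import torch

@torch.inference_mode()
def prefill_then_decode(model, input_ids, cache_idxs, valid_tokens, steps=8):
    # Prefill: run the full prompt once so the cache slots in cache_idxs get populated.
    outputs = model(
        input_ids=input_ids,
        use_cache=True,
        cache_idxs=cache_idxs,      # which cache slots this batch occupies
        valid_tokens=valid_tokens,  # assumed: how many tokens per sequence are real
        prefill=True,               # mark this as the prompt/prefill pass
    )
    next_token = outputs.logits[:, -1:].argmax(dim=-1)

    # Decode: feed one new token at a time against the already-populated cache.
    for _ in range(steps):
        outputs = model(
            input_ids=next_token,
            use_cache=True,
            cache_idxs=cache_idxs,
            valid_tokens=valid_tokens,
            prefill=False,
        )
        next_token = outputs.logits[:, -1:].argmax(dim=-1)
    return next_token
```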

surya/common/surya/decoder/__init__.py

Lines changed: 20 additions & 2 deletions

@@ -161,6 +161,8 @@ def forward(
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         cache_idxs: Optional[List[int]] = None,
+        valid_tokens: Optional[List[int]] = None,
+        prefill: bool = False,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -185,7 +187,15 @@ def forward(

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "cache_idxs": cache_idxs}
+            # cache_idxs, valid_tokens, and prefill add support for our new caching mechanism
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "cache_position": cache_position,
+                "cache_idxs": cache_idxs,
+                "valid_tokens": valid_tokens,
+                "prefill": prefill
+            }
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx, cache_kwargs
             )
@@ -279,6 +289,8 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         cache_idxs: Optional[List[int]] = None,
+        valid_tokens: Optional[List[int]] = None,
+        prefill: bool = False,
         position_embeddings: Optional[
             Tuple[torch.Tensor, torch.Tensor]
         ] = None,  # necessary, but kept here for BC
@@ -300,7 +312,9 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
-            cache_idxs=cache_idxs
+            cache_idxs=cache_idxs,
+            valid_tokens=valid_tokens,
+            prefill=prefill,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -461,6 +475,8 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         cache_idxs: Optional[List[int]] = None,
+        valid_tokens: Optional[List[int]] = None,
+        prefill: bool = False,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -501,6 +517,8 @@ def forward(
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 cache_idxs=cache_idxs,
+                valid_tokens=valid_tokens,
+                prefill=prefill,
                 **flash_attn_kwargs,
            )
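
The attention layer only forwards `valid_tokens` and `prefill` to `past_key_value.update()` via `cache_kwargs`; the cache class that consumes them is not part of this diff. Below is a minimal sketch, assuming a transformers-style cache interface (`update(key_states, value_states, layer_idx, cache_kwargs)`), of how a cache might use the extra kwargs: on a prefill pass it records how many cached tokens per slot are real, and on decode passes it appends the new keys/values and bumps those counts. The class name, storage layout, and the exact meaning of `valid_tokens` are illustrative assumptions, not the commit's actual implementation.

```python
# Sketch of a cache that consumes the extra cache_kwargs (hypothetical; the real
# cache class used here is not shown in this commit).
from typing import Any, Dict, List, Optional, Tuple
import torch


class SlottedKVCache:
    """Minimal whole-batch KV cache that also tracks, per cache slot, how many
    cached tokens are real (non-padding). Illustrative only."""

    def __init__(self, num_layers: int):
        self.key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.value_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.valid_tokens: Dict[int, int] = {}  # cache slot -> count of real tokens

    def update(
        self,
        key_states: torch.Tensor,      # [batch, heads, seq_len, head_dim]
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_idxs: List[int] = cache_kwargs.get("cache_idxs") or []
        valid_tokens = cache_kwargs.get("valid_tokens")
        prefill: bool = cache_kwargs.get("prefill", False)

        if prefill or self.key_cache[layer_idx] is None:
            # Prefill pass: store the prompt keys/values wholesale and record how
            # many of the cached tokens are real for each slot in the batch.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
            if valid_tokens is not None:
                for slot, n_valid in zip(cache_idxs, valid_tokens):
                    self.valid_tokens[slot] = int(n_valid)
        else:
            # Decode pass: append the new token's keys/values along the sequence
            # dimension and bump the per-slot count of real tokens.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
            for slot in cache_idxs:
                self.valid_tokens[slot] = self.valid_tokens.get(slot, 0) + 1

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```

Keying the bookkeeping by cache slot rather than batch position is one plausible reason the indices are passed down explicitly, since it stays valid if batches are reordered or partially replaced between calls.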
