
Commit 084239f

Attention mask as a buffer instead of parameter
1 parent 2b4ede1 commit 084239f

File tree

1 file changed (+1, -2 lines)

surya/foundation/cache.py

Lines changed: 1 addition & 2 deletions
@@ -29,8 +29,7 @@ def __init__(
         self.text_sliding_window = text_sliding_window
         self.num_layers = config.num_hidden_layers
 
-        # TODO Setup these as buffers since its a nn.Module
-        self.attention_mask = torch.zeros((self.batch_size, self.max_cache_len), device=device, dtype=torch.int)
+        self.register_buffer(f"attention_mask", torch.zeros((self.batch_size, self.max_cache_len), device=device, dtype=torch.int))
         self.text_token_counts = [torch.zeros(self.batch_size) for _ in range(self.num_layers)]
 
     def _shift_attention_mask_left(self, batch_idx: int, shift_amount: int):
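
For context on what the change buys: register_buffer ties the tensor to the nn.Module, so it follows .to()/.cuda() device and dtype moves and is included in state_dict(), while still not being returned by parameters() (it is not trainable). The snippet below is a minimal, hypothetical sketch (CacheSketch is not the actual surya cache class) contrasting a plain tensor attribute with a registered buffer.

import torch
import torch.nn as nn

# Minimal sketch, assuming a toy module with the same shape arguments as the
# surya cache; names here are illustrative only.
class CacheSketch(nn.Module):
    def __init__(self, batch_size: int = 2, max_cache_len: int = 8):
        super().__init__()
        # Plain attribute: invisible to the nn.Module machinery.
        self.plain_mask = torch.zeros((batch_size, max_cache_len), dtype=torch.int)
        # Buffer: tracked by the module, moved by .to(), saved in state_dict(),
        # but never a trainable parameter.
        self.register_buffer(
            "attention_mask",
            torch.zeros((batch_size, max_cache_len), dtype=torch.int),
        )

cache = CacheSketch()
print("attention_mask" in cache.state_dict())  # True: buffer is persisted
print("plain_mask" in cache.state_dict())      # False: plain attribute is not
print(len(list(cache.parameters())))           # 0: buffers are not parameters

if torch.cuda.is_available():
    cache = cache.to("cuda")
    print(cache.attention_mask.device)  # cuda:0 (buffer moved with the module)
    print(cache.plain_mask.device)      # cpu (plain attribute left behind)

Since the attention mask is bookkeeping state rather than a learned weight, a buffer (rather than an nn.Parameter) is the usual PyTorch idiom for this kind of tensor.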
