Commit 0bc1197

add batchify mask-mod

1 parent 001b36d commit 0bc1197

File tree

2 files changed (+75, -0 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 .vscode
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

attn_gym/masks/batchify.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
"""Batches input tokens into groups. Attention is only allowed within the same group."""

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _mask_mod_signature, noop_mask
from attn_gym.masks import causal_mask


def batchify_mask_mod(mask_mod: _mask_mod_signature, batchify_size: int) -> _mask_mod_signature:
    """Given an arbitrary mask_mod, batchify it to only allow attention within the same batch.

    Args:
        mask_mod: The mask mod to apply within each batch.
        batchify_size: The number of tokens in each batch.
    """

    def batched_mask_mod(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor):
        # Get the batch index of the query and key
        q_batch = q_idx // batchify_size
        kv_batch = kv_idx // batchify_size

        # Only allow attention within the same batch
        same_batch = q_batch == kv_batch

        # Apply the original mask mod, using indices relative to the start of each batch
        inner_mask = mask_mod(b, h, q_idx % batchify_size, kv_idx % batchify_size)

        return same_batch & inner_mask

    batched_mask_mod.__name__ = f"batched_mask_mod_{mask_mod.__name__}_batch_size_{batchify_size}"
    return batched_mask_mod


def main(device: str = "cpu", causal: bool = False):
    """Visualize the attention scores of the batchify mask mod.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
        causal (bool): If True, batchify the causal mask; otherwise the no-op mask. Defaults to False.
    """
    from attn_gym import visualize_attention_scores
    import random

    random.seed(0)

    seq_len, batchify_size = 12, 4
    B, H, SEQ_LEN, HEAD_DIM = 1, 1, seq_len, 8

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()
    if causal:
        base_mask_mod = causal_mask
    else:
        base_mask_mod = noop_mask

    batched_mask_mod = batchify_mask_mod(base_mask_mod, batchify_size)

    visualize_attention_scores(
        query,
        key,
        mask_mod=batched_mask_mod,
        device=device,
        name="batchify mask_mod",
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .[viz]")

    CLI(main)
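
Beyond the visualization in main, the returned mask mod plugs into FlexAttention the usual way: build a BlockMask with create_block_mask and pass it to flex_attention. The sketch below is illustrative and not part of this commit; it assumes a PyTorch build (2.5+) where torch.nn.attention.flex_attention exposes create_block_mask and flex_attention, and the shapes are hypothetical, chosen so the sequence length divides evenly into groups.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, noop_mask

from attn_gym.masks.batchify import batchify_mask_mod

# Illustrative shapes (not from the commit): 128 tokens split into
# four independent groups of 32.
B, H, SEQ_LEN, HEAD_DIM = 1, 1, 128, 16
batchify_size = 32

# Token i belongs to group i // batchify_size, so e.g. tokens 31 and 32
# land in adjacent groups and may not attend to each other.
mask_mod = batchify_mask_mod(noop_mask, batchify_size)

# Precompute the sparse block structure once; reuse it across forward calls.
block_mask = create_block_mask(mask_mod, B, H, SEQ_LEN, SEQ_LEN, device="cpu")

q, k, v = (torch.randn(B, H, SEQ_LEN, HEAD_DIM) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 1, 128, 16])

Since create_block_mask evaluates the mask ahead of time, changing batchify_size requires rebuilding the BlockMask.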
