| 1 | +"""Batches input tokens into groups. Attention is only allowed within the same group.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import Tensor |
| 5 | +from torch.nn.attention.flex_attention import _mask_mod_signature, noop_mask |
| 6 | +from attn_gym.masks import causal_mask |
| 7 | + |
| 8 | + |
| 9 | +def batchify_mask_mod(mask_mod: _mask_mod_signature, batchify_size: int) -> _mask_mod_signature: |
| 10 | + """Given arbirary mask_mod, batchify it to only allow attention within the same batch. |
| 11 | +
|
| 12 | + Args: |
| 13 | + mask_mod: The mask mod to apply to the documents |
| 14 | + batch_size: The number of tokens in each batch. |
| 15 | + """ |

    def batched_mask_mod(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor):
        # Get the batch index of the query and key
        q_batch = q_idx // batchify_size
        kv_batch = kv_idx // batchify_size

        # Only allow attention within the same batch
        same_batch = q_batch == kv_batch
        # Apply the original mask_mod with indices re-based to the start of each
        # batch, so that e.g. a causal mask restarts at every batch boundary
        inner_mask = mask_mod(b, h, q_idx % batchify_size, kv_idx % batchify_size)

        return same_batch & inner_mask

    batched_mask_mod.__name__ = f"batched_mask_mod_{mask_mod.__name__}_batch_size_{batchify_size}"
    return batched_mask_mod
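

# A minimal usage sketch (illustrative, not exercised by this file): the
# batchified mask_mod is typically compiled into a BlockMask and passed to
# flex_attention. The sizes and the query/key/value tensors below are assumed
# examples, not values fixed by this module.
#
#   from torch.nn.attention.flex_attention import create_block_mask, flex_attention
#
#   mask_mod = batchify_mask_mod(causal_mask, batchify_size=128)
#   block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
#   out = flex_attention(query, key, value, block_mask=block_mask)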


def main(device: str = "cpu", causal: bool = False):
    """Visualize the attention scores of the batchify mask mod.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
        causal (bool): Whether to batchify a causal mask rather than a no-op mask. Defaults to False.
    """
    from attn_gym import visualize_attention_scores

    seq_len, batchify_size = 12, 4
    B, H, SEQ_LEN, HEAD_DIM = 1, 1, seq_len, 8
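    # With seq_len=12 and batchify_size=4 the tokens form 3 batches of 4, so the
    # plot should show a block-diagonal mask (lower-triangular within each block
    # when run with causal=True)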

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()
    if causal:
        base_mask_mod = causal_mask
    else:
        base_mask_mod = noop_mask

    batched_mask_mod = batchify_mask_mod(base_mask_mod, batchify_size)

    visualize_attention_scores(
        query,
        key,
        mask_mod=batched_mask_mod,
        device=device,
        name="batchify mask_mod",
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .[viz]")

    CLI(main)