|
| 1 | +"""Implementation of Multi-head Latent Attention (MLA) RoPE score modification from DeepSeek-V2. |
| 2 | +
|
| 3 | +Reference: https://arxiv.org/pdf/2405.04434 - DeepSeek-V2: A Strong, Economical, and |
| 4 | +Efficient Mixture-of-Experts Language Model |
| 5 | +""" |
| 6 | + |
| 7 | +import torch |
| 8 | +from torch import Tensor |
| 9 | +from torch.nn.attention.flex_attention import _score_mod_signature |
| 10 | + |
| 11 | + |
| 12 | +def generate_mla_rope_score_mod( |
| 13 | + query_rope: Tensor, |
| 14 | + key_rope: Tensor, |
| 15 | + num_heads: int, |
| 16 | + scale: float = 1.0, |
| 17 | +) -> _score_mod_signature: |
| 18 | + """Returns an MLA RoPE score modification function to be used w/ FlexAttention |
| 19 | +
|
| 20 | + Args: |
| 21 | + query_pe: Positional embeddings for queries [batch, num_heads, seq_len, head_dim] |
| 22 | + key_pe: Positional embeddings for keys [batch, num_heads//128, seq_len, head_dim] |
| 23 | + num_heads: The number of query heads |
| 24 | + scale: Scaling factor for the positional embedding contribution |
| 25 | + use_vmap: Whether to use vectorized operations (recommended for training) |
| 26 | +
|
| 27 | + Returns: |
| 28 | + mla_rope_score_mod: Score modification function for FlexAttention |
| 29 | + """ |
| 30 | + |
| 31 | + def mla_rope_score_mod( |
| 32 | + score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor |
| 33 | + ) -> Tensor: |
| 34 | + return score + ( |
| 35 | + scale * torch.dot(query_rope[b, h, q_idx], key_rope[b, h // num_heads, kv_idx]) |
| 36 | + ) |
| 37 | + |
| 38 | + mla_rope_score_mod.__name__ = f"mla_rope_score_mod_scale_{scale}" |
| 39 | + return mla_rope_score_mod |
| 40 | + |
| 41 | + |
def main(device: str = "cuda"):
    """Render attention scores with the MLA RoPE score modification applied.

    Args:
        device: Device on which the example tensors are created
    """
    from attn_gym import visualize_attention_scores

    # Sizes used for the example
    batch, n_heads, seq_len = 1, 128, 8
    latent_dim, rope_dim = 512, 64

    # Random query/key tensors; the key has a head dimension of size 1
    query = torch.rand((batch, n_heads, seq_len, latent_dim), device=device)
    key = torch.rand((batch, 1, seq_len, latent_dim), device=device)

    # Random positional embeddings laid out like the query/key tensors
    query_pe = torch.rand((batch, n_heads, seq_len, rope_dim), device=device)
    key_pe = torch.rand((batch, 1, seq_len, rope_dim), device=device)

    # Build the score modification from the positional embeddings
    score_mod = generate_mla_rope_score_mod(
        query_rope=query_pe,
        key_rope=key_pe,
        num_heads=n_heads,
    )

    # Hand everything to the visualizer
    visualize_attention_scores(
        query,
        key,
        score_mod=score_mod,
        device=device,
        name="mla_rope_score_mod",
    )
| 72 | + |
| 73 | + |
if __name__ == "__main__":
    # jsonargparse is imported lazily so the module can be imported without it;
    # presumably it is installed by the "viz" extra named in the message below.
    try:
        from jsonargparse import CLI
    except ImportError as err:
        # Chain explicitly so the original import failure is reported as the cause.
        raise ImportError("Be sure to run: pip install -e .'[viz]'") from err
    CLI(main)
0 commit comments