
Commit 64c243f

Add example for showcasing how to do multi-latent Attention
stack-info: PR: #113, branch: drisspg/stack/6
1 parent 1d8ab59 commit 64c243f

File tree

2 files changed: +536 −0 lines changed


attn_gym/mods/latent_attention.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
"""Implementation of the Multi-head Latent Attention (MLA) RoPE score modification from DeepSeek-V2.

Reference: https://arxiv.org/pdf/2405.04434 - DeepSeek-V2: A Strong, Economical, and
Efficient Mixture-of-Experts Language Model
"""

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _score_mod_signature

def generate_mla_rope_score_mod(
    query_rope: Tensor,
    key_rope: Tensor,
    num_heads: int,
    scale: float = 1.0,
) -> _score_mod_signature:
    """Returns an MLA RoPE score modification function to be used with FlexAttention.

    Args:
        query_rope: RoPE positional embeddings for queries [batch, num_heads, seq_len, rope_head_dim]
        key_rope: RoPE positional embeddings for keys [batch, 1, seq_len, rope_head_dim]
        num_heads: The number of query heads
        scale: Scaling factor for the positional-embedding contribution

    Returns:
        mla_rope_score_mod: Score modification function for FlexAttention
    """

    def mla_rope_score_mod(
        score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor
    ) -> Tensor:
        # Add the scaled dot product of the decoupled RoPE parts on top of the latent-space score.
        # Keys carry a single shared RoPE head, so h // num_heads maps every query head to head 0.
        return score + (
            scale * torch.dot(query_rope[b, h, q_idx], key_rope[b, h // num_heads, kv_idx])
        )

    mla_rope_score_mod.__name__ = f"mla_rope_score_mod_scale_{scale}"
    return mla_rope_score_mod

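For context on how the generated score mod plugs into FlexAttention outside of the visualization below, here is a minimal sketch (not part of the committed file). It assumes a recent PyTorch where flex_attention exposes enable_gqa (so the single shared latent KV head broadcasts across all query heads), that the module is importable as attn_gym.mods.latent_attention, and it uses a stand-in value tensor; in practice flex_attention is usually wrapped in torch.compile.

import torch
from torch.nn.attention.flex_attention import flex_attention

from attn_gym.mods.latent_attention import generate_mla_rope_score_mod

B, H, SEQ_LEN, LATENT_HEAD_DIM, ROPE_HEAD_DIM = 1, 128, 8, 512, 64

# Latent-space queries, a single shared latent KV head, and a stand-in value tensor.
query = torch.rand(B, H, SEQ_LEN, LATENT_HEAD_DIM)
key = torch.rand(B, 1, SEQ_LEN, LATENT_HEAD_DIM)
value = torch.rand(B, 1, SEQ_LEN, LATENT_HEAD_DIM)

# Decoupled RoPE parts; the key side has one shared RoPE head.
query_pe = torch.rand(B, H, SEQ_LEN, ROPE_HEAD_DIM)
key_pe = torch.rand(B, 1, SEQ_LEN, ROPE_HEAD_DIM)

score_mod = generate_mla_rope_score_mod(query_rope=query_pe, key_rope=key_pe, num_heads=H)

# enable_gqa lets the single KV head be shared across all H query heads.
out = flex_attention(query, key, value, score_mod=score_mod, enable_gqa=True)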
def main(device: str = "cuda"):
    """Visualize the attention scores with MLA RoPE modification.

    Args:
        device: Device to use for computation
    """
    from attn_gym import visualize_attention_scores

    # Example dimensions
    B, H, SEQ_LEN, LATENT_HEAD_DIM = 1, 128, 8, 512
    ROPE_HEAD_DIM = 64

    # Create random tensors for visualization
    query = torch.rand(B, H, SEQ_LEN, LATENT_HEAD_DIM, device=device)
    key = torch.rand(B, 1, SEQ_LEN, LATENT_HEAD_DIM, device=device)

    # Create positional embeddings
    query_pe = torch.rand(B, H, SEQ_LEN, ROPE_HEAD_DIM, device=device)
    key_pe = torch.rand(B, 1, SEQ_LEN, ROPE_HEAD_DIM, device=device)

    # Generate the score modification function
    mla_rope_score_mod = generate_mla_rope_score_mod(
        query_rope=query_pe, key_rope=key_pe, num_heads=H
    )

    # Visualize attention scores with MLA RoPE modification
    visualize_attention_scores(
        query, key, score_mod=mla_rope_score_mod, device=device, name="mla_rope_score_mod"
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
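Assuming the package is installed with the viz extra, as the ImportError message suggests, jsonargparse's CLI exposes main's argument as a flag, so the example can be run as, e.g., python attn_gym/mods/latent_attention.py --device cuda.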
