Commit 3a489a9

Add benchmark for measuring create_block_mask creation time
ghstack-source-id: ac76af6
Pull Request resolved: #136
ghstack-source-id: ac76af6
Pull Request resolved: #137
1 parent 6a65742 commit 3a489a9
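For context, the call whose setup cost this benchmark measures is torch.nn.attention.flex_attention.create_block_mask. A minimal sketch (not part of this commit, assuming a CUDA device and a PyTorch build with flex_attention) of building a causal BlockMask; the mask_mod signature and the B/H/Q_LEN/KV_LEN arguments mirror the ones used in the benchmark below:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    # Allow attention only to keys at or before the query position
    return q_idx >= kv_idx

# B=1, H=8, Q_LEN=KV_LEN=8192 match the smallest config swept by the benchmark
block_mask = create_block_mask(causal_mask, B=1, H=8, Q_LEN=8192, KV_LEN=8192, device="cuda")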

File tree

1 file changed: +203 -0 lines changed

benchmarks/bench_block_mask.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
import random
import sys
from importlib.util import find_spec

# Check if transformer_nuggets is available
if find_spec("transformer_nuggets") is None:
    print(
        "Need to install transformer_nuggets for this benchmark. "
        "Run `pip install git+https://github.com/drisspg/transformer_nuggets`"
    )
    sys.exit(1)

from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, cuda_memory_usage


from attn_gym.masks import (
    causal_mask,
    generate_sliding_window,
    generate_prefix_lm_mask,
    generate_doc_mask_mod,
    generate_dilated_sliding_window,
)
from attn_gym.masks.document_mask import length_to_offsets

from torch.nn.attention.flex_attention import create_block_mask, _mask_mod_signature

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


MASK_MOD_MAP = {
    "causal": causal_mask,
    "sliding_window": generate_sliding_window,
    "prefix_lm": generate_prefix_lm_mask,
    "doc_mask_mod": generate_doc_mask_mod,
    "dilated_sliding_window": generate_dilated_sliding_window,
}


@dataclass(frozen=True)
class ExperimentConfig:
    B: int
    H: int
    M: int
    N: int
    mask_mod_name: str


@dataclass(frozen=True)
class ExperimentResult:
    creation_time_ms: float
    memory_bytes: int


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_mask_mod(config: ExperimentConfig) -> _mask_mod_signature:
    name = config.mask_mod_name
    match name:
        case "sliding_window":
            # Let's have the window size be 1/4 of the sequence length
            window_size = config.M // 4
            return generate_sliding_window(window_size)
        case "prefix_lm":
            # Same for the prefix length
            prefix_length = config.M // 4
            return generate_prefix_lm_mask(prefix_length)
        case "doc_mask_mod":
            # Somewhat arbitrary, but use at least 2 documents
            doc_count = max(2, config.M // 128)

            # Generate random lengths that sum to the sequence length
            def generate_random_lengths(total_length, num_documents):
                # Initialize all lengths to 1 to ensure each document has at least one token
                lengths = [1] * num_documents
                remaining_length = total_length - num_documents

                # Randomly distribute the remaining length
                for _ in range(remaining_length):
                    index = random.randint(0, num_documents - 1)
                    lengths[index] += 1

                return lengths

            lengths = generate_random_lengths(config.M, doc_count)
            offsets = length_to_offsets(lengths, device)
            return generate_doc_mask_mod(causal_mask, offsets)

        case "dilated_sliding_window":
            window_size = config.M // 4
            dilation = 4
            return generate_dilated_sliding_window(window_size, dilation)
        case _:
            mod = MASK_MOD_MAP[name]
            return mod


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # Find the mask_mod function by name
    assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found."
    mask_mod_fn = get_mask_mod(config)

    cbm = torch.compile(create_block_mask, dynamic=False)
    # Warmup
    for _ in range(10):
        cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device)
    torch.cuda.synchronize(device)

    creation_time_us = benchmark_cuda_function_in_microseconds(
        lambda: cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device),
    )

    torch.cuda.synchronize(device)

    with cuda_memory_usage() as mem:
        cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device)
    torch.cuda.synchronize(device)

    return ExperimentResult(
        creation_time_ms=creation_time_us / 1000, memory_bytes=mem.memory_usage
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "B",
        "H",
        "M",
        "N",
        "Mask Mod",
        "Creation Time (ms)",
        "Memory (GiB)",
    ]
    rows = []
    for experiment in experiments:
        rows.append(
            [
                experiment.config.B,
                experiment.config.H,
                experiment.config.M,
                experiment.config.N,
                experiment.config.mask_mod_name,
                f"{experiment.result.creation_time_ms:.4f}",
                # Convert bytes to GiB to match the header
                f"{experiment.result.memory_bytes / (1024**3):.2f}",
            ]
        )
    # Sort rows for better readability (e.g., by B, H, M, N)
    rows.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
    print(tabulate(rows, headers=headers, tablefmt="grid"))


def get_configs() -> List[ExperimentConfig]:
    # Define ranges for benchmark parameters
    Bs = [1]
    Hs = [8]
    SeqLens = [8192, 16384, 32768]
    # Map string names to mask functions
    mask_mods_to_run = list(MASK_MOD_MAP.keys())

    configs = []
    for B, H, S, mask_mod in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run):
        configs.append(
            ExperimentConfig(
                B=B,
                H=H,
                M=S,  # Assuming M=N for simplicity
                N=S,
                mask_mod_name=mask_mod,
            )
        )
    return configs


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    print(f"Running {len(configs)} benchmark configurations...")
    for config in tqdm(configs):
        try:
            result = run_experiment(config)
            results.append(Experiment(config=config, result=result))
        except Exception as e:
            print(f"Failed to run config {config}: {e}")

    print_results(results)


if __name__ == "__main__":
    main()
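The BlockMask produced by create_block_mask is what flex_attention consumes at runtime, which is why its creation time and memory footprint are worth tracking. A minimal usage sketch, not part of this commit, assuming a CUDA device, the flex_attention API from torch.nn.attention.flex_attention, and the causal_mask helper from attn_gym.masks:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from attn_gym.masks import causal_mask

B, H, S, D = 1, 8, 8192, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Build the block mask once (the step this benchmark times), then reuse it
block_mask = create_block_mask(causal_mask, B, H, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)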
