Skip to content

Commit 43d2042

Browse files
committed
Add benchmark for measuring create_block_mask creation time
ghstack-source-id: 3165f82 Pull-Request-resolved: #136
1 parent 6a65742 commit 43d2042

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed

benchmarks/bench_block_mask.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import itertools
2+
from dataclasses import dataclass
3+
from typing import List, Callable
4+
5+
import torch
6+
from tabulate import tabulate
7+
from tqdm import tqdm
8+
9+
try:
10+
import transformer_nuggets
11+
except ImportError:
12+
print(
13+
"Need to install transformer_nuggets for this benchmark. "
14+
"Run `pip install git+https://github.com/drisspg/transformer_nuggets`"
15+
)
16+
# Exit if the dependency is missing
17+
sys.exit(1)
18+
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, profiler, cuda_memory_usage
19+
20+
21+
from attn_gym.masks import causal_mask, generate_sliding_window, generate_prefix_lm_mask, generate_doc_mask_mod, generate_dilated_sliding_window
22+
23+
from torch.nn.attention.flex_attention import create_block_mask, _mask_mod_signature
24+
import sys
25+
26+
device = torch.device("cuda")
27+
28+
# Needed since changing args to function causes recompiles
29+
torch._dynamo.config.cache_size_limit = 1000
30+
31+
32+
MASK_MOD_MAP = {
33+
"causal": causal_mask,
34+
"sliding_window": generate_sliding_window,
35+
"prefix_lm": generate_prefix_lm_mask,
36+
"doc_mask_mod": generate_doc_mask_mod,
37+
"dilated_sliding_window": generate_dilated_sliding_window,
38+
}
39+
40+
@dataclass(frozen=True)
41+
class ExperimentConfig:
42+
B: int
43+
H: int
44+
M: int
45+
N: int
46+
mask_mod_name: str
47+
48+
49+
@dataclass(frozen=True)
50+
class ExperimentResult:
51+
creation_time_ms: float
52+
memory_bytes: int
53+
54+
55+
@dataclass(frozen=True)
56+
class Experiment:
57+
config: ExperimentConfig
58+
result: ExperimentResult
59+
60+
61+
def get_mask_mod(name: str) -> _mask_mod_signature:
62+
match name:
63+
case "sliding_window":
64+
return generate_sliding_window()
65+
case "prefix_lm":
66+
return generate_prefix_lm_mask()
67+
case "doc_mask_mod":
68+
return generate_doc_mask_mod()
69+
case "dilated_sliding_window":
70+
return generate_dilated_sliding_window()
71+
case _:
72+
mod = MASK_MOD_MAP[name]
73+
return mod
74+
75+
76+
def get_configs() -> List[ExperimentConfig]:
77+
# Define ranges for benchmark parameters
78+
Bs = [1, 4, 8]
79+
Hs = [8, 16]
80+
# Sequence lengths - adjust as needed
81+
# Using powers of 2 up to a reasonable limit for mask creation
82+
SeqLens = [1024, 2048, 4096, 8192]
83+
# Map string names to mask functions
84+
mask_mods_to_run = list(MASK_MOD_MAP.keys())
85+
86+
configs = []
87+
for B, H, S, mask_mod in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run):
88+
configs.append(
89+
ExperimentConfig(
90+
B=B,
91+
H=H,
92+
M=S, # Assuming M=N for simplicity
93+
N=S,
94+
mask_mod_name=mask_mod
95+
)
96+
)
97+
return configs
98+
99+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
100+
# Find the mask_mod function by name
101+
assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found."
102+
mask_mod_fn = get_mask_mod(config.mask_mod_name)
103+
104+
105+
# --- Time Benchmarking ---
106+
cbm = torch.compile(create_block_mask)
107+
# Warmup
108+
for _ in range(10):
109+
cbm(
110+
mask_mod_fn, config.B, config.H, config.M, config.N, device=device
111+
)
112+
torch.cuda.synchronize(device)
113+
114+
creation_time_ms = benchmark_cuda_function_in_microseconds(
115+
lambda: cbm(
116+
mask_mod_fn, config.B, config.H, config.M, config.N, device=device
117+
),
118+
)
119+
120+
torch.cuda.synchronize(device)
121+
122+
with cuda_memory_usage() as memory_bytes:
123+
cbm(
124+
mask_mod_fn, config.B, config.H, config.M, config.N, device=device
125+
)
126+
127+
128+
return ExperimentResult(
129+
creation_time_ms=creation_time_ms * 1000,
130+
memory_bytes=memory_bytes #
131+
)
132+
133+
134+
def print_results(experiments: List[Experiment]):
135+
headers = [
136+
"B",
137+
"H",
138+
"M",
139+
"N",
140+
"Mask Mod",
141+
"Creation Time (ms)",
142+
"Memory (GiB)",
143+
]
144+
rows = []
145+
for experiment in experiments:
146+
rows.append(
147+
[
148+
experiment.config.B,
149+
experiment.config.H,
150+
experiment.config.M,
151+
experiment.config.N,
152+
experiment.config.mask_mod_name,
153+
f"{experiment.result.creation_time_ms:.4f}",
154+
f"{experiment.result.memory_bytes:.2f}"
155+
]
156+
)
157+
# Sort rows for better readability (e.g., by B, H, M, N)
158+
rows.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
159+
print(tabulate(rows, headers=headers, tablefmt="grid"))
160+
161+
162+
def main():
163+
torch.random.manual_seed(123)
164+
configs = get_configs()
165+
results = []
166+
print(f"Running {len(configs)} benchmark configurations...")
167+
for config in tqdm(configs):
168+
try:
169+
result = run_experiment(config)
170+
results.append(Experiment(config=config, result=result))
171+
except Exception as e:
172+
print(f"Failed to run config {config}: {e}")
173+
# Optionally skip failed configs or handle differently
174+
175+
# Use Tabulate to print results
176+
print_results(results)
177+
178+
179+
if __name__ == "__main__":
180+
main()
181+

0 commit comments

Comments
 (0)