Skip to content

Commit 8cc6fef

Browse files
committed
Add benchmark for measuring create_block_mask creation time
ghstack-source-id: e8fdd99 Pull-Request-resolved: #136 ghstack-source-id: e8fdd99 Pull Request resolved: #137
1 parent 6a65742 commit 8cc6fef

File tree

1 file changed

+238
-0
lines changed

1 file changed

+238
-0
lines changed

benchmarks/bench_block_mask.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import itertools
2+
from dataclasses import dataclass
3+
from typing import List, Callable
4+
5+
import torch
6+
from tabulate import tabulate
7+
from tqdm import tqdm
8+
import random
9+
from triton.testing import do_bench
10+
11+
from attn_gym.masks import (
12+
causal_mask,
13+
generate_sliding_window,
14+
generate_prefix_lm_mask,
15+
generate_doc_mask_mod,
16+
generate_dilated_sliding_window,
17+
)
18+
from attn_gym.masks.document_mask import length_to_offsets
19+
20+
from torch.nn.attention.flex_attention import create_block_mask, _mask_mod_signature
21+
22+
device = torch.device("cuda")
23+
24+
# Needed since changing args to function causes recompiles
25+
torch._dynamo.config.cache_size_limit = 1000
26+
27+
28+
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Time ``func(*args, **kwargs)`` with ``do_bench`` and scale the result.

    Args:
        func: The callable to benchmark.
        *args: Positional arguments forwarded to ``func`` on every call.
        **kwargs: Keyword arguments forwarded to ``func`` on every call.

    Returns:
        The value reported by ``do_bench`` multiplied by 1e3. ``do_bench`` is
        presumed to report milliseconds (per Triton's documentation), which
        makes the returned value microseconds.
    """

    def bound_call():
        # do_bench expects a zero-argument callable, so bind the arguments here.
        return func(*args, **kwargs)

    elapsed = do_bench(bound_call)
    return elapsed * 1e3
33+
34+
35+
class cuda_memory_usage:
    """Context manager recording the change in allocated CUDA memory.

    On exit, ``memory_usage`` is set to ``torch.cuda.memory_allocated()`` at
    exit minus its value at entry. When ``log`` is true the difference is also
    printed, converted to GiB.

    Args:
        log (bool): Whether to print the memory usage to the console
        precision (int): The number of decimal places to print

    Usage:
        ```
        with cuda_memory_usage() as mem:
            # code to profile
        print(mem.memory_usage)
        ```
    """

    def __init__(self, log=False, precision=2):
        self.memory_usage = 0
        self.precision = precision
        self.log = log

    def __enter__(self):
        # Baseline against which the exit-time reading is compared.
        self.initial_memory = torch.cuda.memory_allocated()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        final_memory = torch.cuda.memory_allocated()
        self.memory_usage = final_memory - self.initial_memory
        if not self.log:
            return
        gib = self.memory_usage / (1024**3)
        print(f"CUDA memory usage: {gib:.{self.precision}f} GiB")
65+
66+
67+
# Registry of benchmarkable masks, keyed by the name stored in
# ExperimentConfig.mask_mod_name. "causal" maps to a mask_mod that is used
# as-is; the other entries are generator functions that get_mask_mod calls
# with configuration-derived arguments.
MASK_MOD_MAP: dict[str, Callable] = {
    "causal": causal_mask,
    "sliding_window": generate_sliding_window,
    "prefix_lm": generate_prefix_lm_mask,
    "doc_mask_mod": generate_doc_mask_mod,
    "dilated_sliding_window": generate_dilated_sliding_window,
}
74+
75+
76+
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of a single benchmark configuration.

    ``B``, ``H``, ``M`` and ``N`` are passed, in that order, to
    ``create_block_mask`` after the mask_mod argument.
    """

    B: int
    H: int
    M: int
    N: int
    # Key into MASK_MOD_MAP selecting which mask to benchmark.
    mask_mod_name: str
    # Forwarded as the ``dynamic`` argument of torch.compile.
    dynamic: bool
84+
85+
86+
@dataclass(frozen=True)
class ExperimentResult:
    """Immutable measurements produced by one benchmark run."""

    # Time to create the block mask, in milliseconds.
    creation_time_ms: float
    # Change in allocated CUDA memory around one block-mask creation, in bytes.
    memory_bytes: int
90+
91+
92+
@dataclass(frozen=True)
class Experiment:
    """Immutable pairing of a benchmark configuration with its result."""

    # The configuration that was run.
    config: ExperimentConfig
    # The measurements obtained for that configuration.
    result: ExperimentResult
96+
97+
98+
def get_mask_mod(config: ExperimentConfig) -> _mask_mod_signature:
    """Build the mask_mod selected by ``config.mask_mod_name``.

    Parameterized masks derive their arguments from ``config.M``; any name
    without a dedicated branch is looked up directly in ``MASK_MOD_MAP``.

    Args:
        config: The benchmark configuration; only ``mask_mod_name`` and ``M``
            are read.

    Returns:
        The mask_mod callable to hand to ``create_block_mask``.

    Raises:
        KeyError: If the name has no dedicated branch and is not a key of
            ``MASK_MOD_MAP``.
    """
    mask_name = config.mask_mod_name
    seq_len = config.M

    if mask_name == "sliding_window":
        # The window spans a quarter of the sequence length.
        return generate_sliding_window(seq_len // 4)

    if mask_name == "prefix_lm":
        # The prefix likewise spans a quarter of the sequence length.
        return generate_prefix_lm_mask(seq_len // 4)

    if mask_name == "doc_mask_mod":

        def split_into_random_lengths(total_length, num_documents):
            # Start every document at one token so none is empty, then hand
            # out the remaining tokens one at a time to random documents.
            doc_lengths = [1] * num_documents
            for _ in range(total_length - num_documents):
                doc_lengths[random.randint(0, num_documents - 1)] += 1
            return doc_lengths

        # One document per 128 tokens, but never fewer than two.
        num_docs = max(2, seq_len // 128)
        doc_offsets = length_to_offsets(split_into_random_lengths(seq_len, num_docs), device)
        return generate_doc_mask_mod(causal_mask, doc_offsets)

    if mask_name == "dilated_sliding_window":
        quarter_window = seq_len // 4
        dilation_factor = 4
        return generate_dilated_sliding_window(quarter_window, dilation_factor)

    # Remaining names (e.g. "causal") map straight to a ready-made mask_mod.
    return MASK_MOD_MAP[mask_name]
137+
138+
139+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark compiled ``create_block_mask`` for one configuration.

    The function is compiled with ``torch.compile``, warmed up, timed, and
    then run once more inside ``cuda_memory_usage`` to record how much CUDA
    memory the resulting block mask occupies.

    Args:
        config: The configuration to run; ``mask_mod_name`` must be a key of
            ``MASK_MOD_MAP``.

    Returns:
        An ``ExperimentResult`` with the creation time in milliseconds and the
        block mask's memory footprint in bytes.

    Raises:
        AssertionError: If ``config.mask_mod_name`` is not in ``MASK_MOD_MAP``.
    """
    # Find the mask_mod function by name
    assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found."
    mask_mod_fn = get_mask_mod(config)

    cbm = torch.compile(create_block_mask, dynamic=config.dynamic)

    def build_block_mask():
        return cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device)

    # Warmup, so compilation is not included in the timed region.
    for _ in range(10):
        build_block_mask()
    torch.cuda.synchronize(device)

    creation_time_us = benchmark_cuda_function_in_microseconds(build_block_mask)

    torch.cuda.synchronize(device)

    with cuda_memory_usage() as mem:
        # Hold a reference until the context exits: a discarded result is freed
        # immediately, and the measured allocation delta would then be ~0.
        block_mask = build_block_mask()
        torch.cuda.synchronize(device)
    del block_mask

    return ExperimentResult(
        creation_time_ms=creation_time_us / 1000, memory_bytes=mem.memory_usage
    )
163+
164+
165+
def print_results(experiments: List[Experiment]):
    """Print the benchmark results as a grid table sorted by (B, H, M, N).

    Args:
        experiments: The completed experiments to report. An empty list prints
            a table with no rows.
    """
    headers = [
        "B",
        "H",
        "M",
        "N",
        "Mask Mod",
        "Dynamic",
        "Creation Time (ms)",
        "Memory (GiB)",
    ]
    bytes_per_gib = 1024**3
    rows = []
    for experiment in experiments:
        config = experiment.config
        result = experiment.result
        rows.append(
            [
                config.B,
                config.H,
                config.M,
                config.N,
                config.mask_mod_name,
                config.dynamic,
                f"{result.creation_time_ms:.4f}",
                # The result stores bytes; convert so the value matches the
                # "Memory (GiB)" header. Four decimals keep MB-scale block
                # masks from all rounding to 0.00.
                f"{result.memory_bytes / bytes_per_gib:.4f}",
            ]
        )
    # Sort rows for better readability (e.g., by B, H, M, N)
    rows.sort(key=lambda row: (row[0], row[1], row[2], row[3]))
    print(tabulate(rows, headers=headers, tablefmt="grid"))
193+
194+
195+
def get_configs() -> List[ExperimentConfig]:
    """Enumerate every benchmark configuration in the sweep.

    The sweep is the Cartesian product of the batch sizes, head counts,
    sequence lengths, mask names (all keys of ``MASK_MOD_MAP``) and dynamic
    flags defined below.

    Returns:
        One ``ExperimentConfig`` per combination, in ``itertools.product``
        order.
    """
    batch_sizes = [1]
    head_counts = [8]
    seq_lens = [8192, 16384, 32768]
    mask_mod_names = list(MASK_MOD_MAP.keys())
    dynamic_flags = [False]

    sweep = itertools.product(batch_sizes, head_counts, seq_lens, mask_mod_names, dynamic_flags)
    return [
        # M and N are both set to the sequence length (M == N for simplicity).
        ExperimentConfig(B=b, H=h, M=seq_len, N=seq_len, mask_mod_name=name, dynamic=dyn)
        for b, h, seq_len, name, dyn in sweep
    ]
219+
220+
221+
def main():
    """Run every benchmark configuration and print a summary table.

    Both the torch and the ``random`` generators are seeded with 123 first. A
    configuration that raises is reported on the console and skipped, so one
    failure does not abort the sweep.
    """
    torch.random.manual_seed(123)
    random.seed(123)

    configs = get_configs()
    print(f"Running {len(configs)} benchmark configurations...")

    experiments = []
    for cfg in tqdm(configs):
        try:
            outcome = run_experiment(cfg)
            experiments.append(Experiment(config=cfg, result=outcome))
        except Exception as err:
            # Best-effort sweep: report the failure and move on.
            print(f"Failed to run config {cfg}: {err}")

    print_results(experiments)
235+
236+
237+
# Run the benchmark sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)