Skip to content

Commit dcb69b4

Browse files
committed
Add benchmark for measuring create_block_mask creation time
ghstack-source-id: 3165f82 Pull-Request-resolved: #136 ghstack-source-id: 08fd434 Pull Request resolved: #137
1 parent 6a65742 commit dcb69b4

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed

benchmarks/bench_block_mask.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import itertools
2+
from dataclasses import dataclass
3+
from typing import List
4+
5+
import torch
6+
from tabulate import tabulate
7+
from tqdm import tqdm
8+
import random
9+
import sys
10+
from importlib.util import find_spec
11+
12+
# The timing and memory helpers imported below come from transformer_nuggets.
# Bail out early with an install hint when the package cannot be found.
if find_spec("transformer_nuggets") is None:
    install_hint = (
        "Need to install transformer_nuggets for this benchmark. "
        "Run `pip install git+https://github.com/drisspg/transformer_nuggets`"
    )
    print(install_hint)
    sys.exit(1)
19+
20+
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, cuda_memory_usage
21+
22+
23+
from attn_gym.masks import (
24+
causal_mask,
25+
generate_sliding_window,
26+
generate_prefix_lm_mask,
27+
generate_doc_mask_mod,
28+
generate_dilated_sliding_window,
29+
)
30+
from attn_gym.masks.document_mask import length_to_offsets
31+
32+
from torch.nn.attention.flex_attention import create_block_mask, _mask_mod_signature
33+
34+
device = torch.device("cuda")
35+
36+
# Needed since changing args to function causes recompiles
37+
torch._dynamo.config.cache_size_limit = 1000
38+
39+
40+
# Benchmark name -> mask_mod, or the generator that builds one.
# `get_mask_mod` calls the generators with size-dependent arguments and
# returns any other entry (here only "causal") unchanged.
MASK_MOD_MAP = dict(
    causal=causal_mask,
    sliding_window=generate_sliding_window,
    prefix_lm=generate_prefix_lm_mask,
    doc_mask_mod=generate_doc_mask_mod,
    dilated_sliding_window=generate_dilated_sliding_window,
)
47+
48+
49+
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one block-mask benchmark case.

    ``B``, ``H``, ``M`` and ``N`` are passed positionally, in that order, to
    ``create_block_mask``; ``mask_mod_name`` is a key of ``MASK_MOD_MAP``.
    """

    # Forwarded to create_block_mask as its B argument.
    B: int
    # Forwarded to create_block_mask as its H argument.
    H: int
    # Forwarded as M; also used to size windows, prefixes and document counts.
    M: int
    # Forwarded to create_block_mask as its N argument.
    N: int
    # Name of the mask_mod to benchmark.
    mask_mod_name: str
56+
57+
58+
@dataclass(frozen=True)
class ExperimentResult:
    """Immutable measurements collected for a single benchmark case."""

    # Time taken to create the block mask, in milliseconds.
    creation_time_ms: float
    # Memory figure taken from `cuda_memory_usage().memory_usage`.
    # NOTE(review): the results table labels this column "Memory (GiB)" while
    # the field is named *_bytes -- confirm which unit the helper reports.
    memory_bytes: int
62+
63+
64+
@dataclass(frozen=True)
class Experiment:
    """Immutable pairing of a benchmark configuration with its measurements."""

    # The case that was run.
    config: ExperimentConfig
    # What was measured for that case.
    result: ExperimentResult
68+
69+
70+
def _random_document_lengths(total_length, num_documents):
    """Randomly split ``total_length`` tokens across ``num_documents`` documents.

    Every document starts with a single token so none is empty; each of the
    remaining tokens is then given to a document picked with
    ``random.randint``.
    """
    doc_lengths = [1] * num_documents
    for _ in range(total_length - num_documents):
        doc_lengths[random.randint(0, num_documents - 1)] += 1
    return doc_lengths


def get_mask_mod(config: ExperimentConfig) -> _mask_mod_signature:
    """Build the mask_mod selected by ``config.mask_mod_name``.

    Parameterised masks are sized relative to ``config.M``: window and prefix
    lengths are a quarter of it, and the document mask uses one document per
    128 positions (at least two). Any other name is looked up directly in
    ``MASK_MOD_MAP``.

    Raises:
        KeyError: if the name is neither special-cased nor in ``MASK_MOD_MAP``.
    """
    name = config.mask_mod_name

    if name == "sliding_window":
        # Window covers a quarter of the sequence.
        return generate_sliding_window(config.M // 4)

    if name == "prefix_lm":
        # Prefix covers a quarter of the sequence.
        return generate_prefix_lm_mask(config.M // 4)

    if name == "doc_mask_mod":
        # Somewhat arbitrary document count, but never fewer than two.
        num_docs = max(2, config.M // 128)
        doc_lengths = _random_document_lengths(config.M, num_docs)
        doc_offsets = length_to_offsets(doc_lengths, device)
        return generate_doc_mask_mod(causal_mask, doc_offsets)

    if name == "dilated_sliding_window":
        quarter_window = config.M // 4
        dilation_factor = 4
        return generate_dilated_sliding_window(quarter_window, dilation_factor)

    # Everything else is stored in the map ready to use.
    return MASK_MOD_MAP[name]
109+
110+
111+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Measure creation time and memory of a compiled ``create_block_mask``.

    The compiled function is called ten times as warmup, then timed with
    ``benchmark_cuda_function_in_microseconds``, and finally called once more
    under ``cuda_memory_usage`` to record memory.

    Args:
        config: the benchmark case to run; its ``mask_mod_name`` must be a key
            of ``MASK_MOD_MAP``.

    Returns:
        The creation time converted from microseconds to milliseconds, plus
        the value reported by the memory context manager.
    """
    assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found."
    mask_mod = get_mask_mod(config)
    compiled_create = torch.compile(create_block_mask)

    def build_block_mask():
        # Single definition of the call under test, shared by every phase.
        return compiled_create(mask_mod, config.B, config.H, config.M, config.N, device=device)

    # --- Warmup ---
    warmup_iterations = 10
    for _ in range(warmup_iterations):
        build_block_mask()
    torch.cuda.synchronize(device)

    # --- Time benchmarking ---
    elapsed_us = benchmark_cuda_function_in_microseconds(build_block_mask)
    torch.cuda.synchronize(device)

    # --- Memory benchmarking ---
    with cuda_memory_usage() as mem:
        build_block_mask()
        torch.cuda.synchronize(device)

    return ExperimentResult(
        creation_time_ms=elapsed_us / 1000,
        memory_bytes=mem.memory_usage,
    )
136+
137+
138+
def print_results(experiments: List[Experiment]):
    """Print the benchmark results as a grid table sorted by (B, H, M, N).

    Args:
        experiments: completed experiments; one table row is printed per item.
    """
    column_names = [
        "B",
        "H",
        "M",
        "N",
        "Mask Mod",
        "Creation Time (ms)",
        # NOTE(review): labelled GiB, but the value below is printed straight
        # from `memory_bytes` with no conversion -- confirm the actual unit.
        "Memory (GiB)",
    ]

    unsorted_rows = (
        [
            exp.config.B,
            exp.config.H,
            exp.config.M,
            exp.config.N,
            exp.config.mask_mod_name,
            f"{exp.result.creation_time_ms:.4f}",
            f"{exp.result.memory_bytes:.2f}",
        ]
        for exp in experiments
    )
    # Order by the size columns for readability; the sort is stable, so rows
    # with equal sizes keep their input order.
    table = sorted(unsorted_rows, key=lambda row: tuple(row[:4]))
    print(tabulate(table, headers=column_names, tablefmt="grid"))
164+
165+
166+
def get_configs() -> List[ExperimentConfig]:
    """Return the cross product of all benchmark parameters as configs.

    Sequence lengths are powers of two, and each one is used for both ``M``
    and ``N``. Every mask registered in ``MASK_MOD_MAP`` is included.
    """
    batch_sizes = [1]
    head_counts = [8]
    # Adjust as needed; kept to a reasonable limit for mask creation.
    sequence_lengths = [8192, 16384, 32768]
    mask_names = list(MASK_MOD_MAP)

    return [
        # M == N for simplicity.
        ExperimentConfig(B=batch, H=heads, M=seq_len, N=seq_len, mask_mod_name=mask_name)
        for batch, heads, seq_len, mask_name in itertools.product(
            batch_sizes, head_counts, sequence_lengths, mask_names
        )
    ]
188+
189+
190+
def main():
    """Run every benchmark configuration and print a summary table.

    A configuration that raises is reported on stdout and skipped, so one
    failing case does not abort the whole sweep.
    """
    # The "doc_mask_mod" case draws its document lengths with Python's
    # `random` module, so seed it alongside torch to keep runs reproducible.
    random.seed(123)
    torch.random.manual_seed(123)

    configs = get_configs()
    results = []
    print(f"Running {len(configs)} benchmark configurations...")
    for config in tqdm(configs):
        try:
            result = run_experiment(config)
        except Exception as e:
            # Deliberate best effort: report the failure and move on.
            print(f"Failed to run config {config}: {e}")
        else:
            results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)