
Commit 1b04fb3

Update
[ghstack-poisoned]

1 parent cc7d867


benchmarks/bench_block_mask.py

Lines changed: 51 additions & 16 deletions
@@ -1,24 +1,12 @@
 import itertools
 from dataclasses import dataclass
-from typing import List
+from typing import List, Callable
 
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
 import random
-import sys
-from importlib.util import find_spec
-
-# Check if transformer_nuggets is available
-if find_spec("transformer_nuggets") is None:
-    print(
-        "Need to install transformer_nuggets for this benchmark. "
-        "Run `pip install git+https://github.com/drisspg/transformer_nuggets`"
-    )
-    sys.exit(1)
-
-from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, cuda_memory_usage
-
+from triton.testing import do_bench
 
 from attn_gym.masks import (
     causal_mask,
@@ -37,6 +25,45 @@
 torch._dynamo.config.cache_size_limit = 1000
 
 
+def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
+    """Thin wrapper around triton's do_bench; returns the time in microseconds."""
+    no_args = lambda: func(*args, **kwargs)
+    time = do_bench(no_args)
+    return time * 1e3
+
+
+class cuda_memory_usage:
+    """Prints the difference in CUDA memory usage at the end of a context manager
+
+    Args:
+        log (bool): Whether to print the memory usage to the console
+        precision (int): The number of decimal places to print
+
+    Usage:
+    ```
+    with cuda_memory_usage() as mem:
+        # code to profile
+    print(mem.memory_usage)
+    ```
+
+    """
+
+    def __init__(self, log=False, precision=2):
+        self.log = log
+        self.precision = precision
+        self.memory_usage = 0
+
+    def __enter__(self):
+        self.initial_memory = torch.cuda.memory_allocated()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.memory_usage = torch.cuda.memory_allocated() - self.initial_memory
+        if self.log:
+            memory_usage_gib = self.memory_usage / (1024**3)
+            print(f"CUDA memory usage: {memory_usage_gib:.{self.precision}f} GiB")
+
+
 MASK_MOD_MAP = {
     "causal": causal_mask,
     "sliding_window": generate_sliding_window,
@@ -53,6 +80,7 @@ class ExperimentConfig:
     M: int
     N: int
     mask_mod_name: str
+    dynamic: bool
 
 
 @dataclass(frozen=True)
@@ -113,7 +141,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found."
     mask_mod_fn = get_mask_mod(config)
 
-    cbm = torch.compile(create_block_mask, dynamic=False)
+    cbm = torch.compile(create_block_mask, dynamic=config.dynamic)
     # Warmup
     for _ in range(10):
         cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device)
@@ -141,6 +169,7 @@ def print_results(experiments: List[Experiment]):
         "M",
         "N",
         "Mask Mod",
+        "Dynamic",
         "Creation Time (ms)",
         "Memory (GiB)",
     ]
@@ -153,6 +182,7 @@ def print_results(experiments: List[Experiment]):
             experiment.config.M,
             experiment.config.N,
             experiment.config.mask_mod_name,
+            experiment.config.dynamic,
             f"{experiment.result.creation_time_ms:.4f}",
             f"{experiment.result.memory_bytes:.2f}",
         ]
@@ -169,23 +199,28 @@ def get_configs() -> List[ExperimentConfig]:
     SeqLens = [8192, 16384, 32768]
     # Map string names to mask functions
     mask_mods_to_run = list(MASK_MOD_MAP.keys())
+    dynamic = [
+        False,
+    ]
 
     configs = []
-    for B, H, S, mask_mod in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run):
+    for B, H, S, mask_mod, dyn in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run, dynamic):
         configs.append(
             ExperimentConfig(
                 B=B,
                 H=H,
                 M=S,  # Assuming M=N for simplicity
                 N=S,
                 mask_mod_name=mask_mod,
+                dynamic=dyn,
             )
         )
     return configs
 
 
 def main():
     torch.random.manual_seed(123)
+    random.seed(123)
     configs = get_configs()
     results = []
     print(f"Running {len(configs)} benchmark configurations...")
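The two helpers added in this commit replace the `transformer_nuggets` dependency with local utilities built on triton's `do_bench`. A minimal sketch of the timing path, assuming a CUDA device and using a local causal `mask_mod` in place of the attn_gym import:

```python
# Sketch of the new timing path (assumes a CUDA device; causal_mask
# below is a local stand-in for the attn_gym import in the benchmark).
import torch
from triton.testing import do_bench
from torch.nn.attention.flex_attention import create_block_mask


def benchmark_cuda_function_in_microseconds(func, *args, **kwargs) -> float:
    # do_bench reports milliseconds, so scale by 1e3 for microseconds.
    return do_bench(lambda: func(*args, **kwargs)) * 1e3


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


cbm = torch.compile(create_block_mask, dynamic=False)
time_us = benchmark_cuda_function_in_microseconds(
    cbm, causal_mask, 1, 8, 8192, 8192, device="cuda"
)
print(f"BlockMask creation: {time_us / 1e3:.4f} ms")
```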

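The memory tracker is used the same way its docstring shows; a sketch, assuming the `cuda_memory_usage` class from the diff above is in scope along with `causal_mask` and `create_block_mask` from the previous example:

```python
# Sketch of the memory tracker (assumes cuda_memory_usage from the diff
# above is in scope, plus the imports from the previous example).
with cuda_memory_usage(log=True) as mem:
    block_mask = create_block_mask(causal_mask, 1, 8, 8192, 8192, device="cuda")

# mem.memory_usage holds the delta in bytes still allocated on exit.
print(mem.memory_usage)
```

Since the class diffs `torch.cuda.memory_allocated()` on entry and exit, the reported number is the memory still held when the block ends, i.e. what the BlockMask itself retains, not the peak allocation during creation.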
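The new `dynamic` field is plumbed through to `torch.compile`, though `get_configs()` currently only exercises `dynamic=False`. A sketch of what flipping it would change, under the usual shape-specialization behavior of `torch.compile` (the benchmark's raised `cache_size_limit` suggests many per-shape recompiles in the static case):

```python
# Sketch: static vs. dynamic compilation of create_block_mask across
# sequence lengths (assumes a CUDA device; causal_mask as defined above).
for dynamic in (False, True):
    cbm = torch.compile(create_block_mask, dynamic=dynamic)
    for seq_len in (8192, 16384, 32768):
        # dynamic=False: each distinct (M, N) triggers a fresh compile.
        # dynamic=True: the first compile generalizes over seq_len.
        cbm(causal_mask, 1, 8, seq_len, seq_len, device="cuda")
```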