@@ -1,24 +1,12 @@
 import itertools
 from dataclasses import dataclass
-from typing import List
+from typing import List, Callable
 
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
 import random
-import sys
-from importlib.util import find_spec
-
-# Check if transformer_nuggets is available
-if find_spec("transformer_nuggets") is None:
-    print(
-        "Need to install transformer_nuggets for this benchmark. "
-        "Run `pip install git+https://github.com/drisspg/transformer_nuggets`"
-    )
-    sys.exit(1)
-
-from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, cuda_memory_usage
-
+from triton.testing import do_bench
 
 from attn_gym.masks import (
     causal_mask,
@@ -37,6 +25,45 @@
 torch._dynamo.config.cache_size_limit = 1000
 
 
+def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
+    """Thin wrapper around triton's do_bench; scales its millisecond result to microseconds."""
+    no_args = lambda: func(*args, **kwargs)
+    time = do_bench(no_args)
+    return time * 1e3
+
+
+class cuda_memory_usage:
+    """Prints the difference in CUDA memory usage at the end of the context manager.
+
+    Args:
+        log (bool): Whether to print the memory usage to the console
+        precision (int): The number of decimal places to print
+
+    Usage:
+    ```
+    with cuda_memory_usage() as mem:
+        # code to profile
+        print(mem.memory_usage)
+    ```
+
+    """
+
+    def __init__(self, log=False, precision=2):
+        self.log = log
+        self.precision = precision
+        self.memory_usage = 0
+
+    def __enter__(self):
+        self.initial_memory = torch.cuda.memory_allocated()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.memory_usage = torch.cuda.memory_allocated() - self.initial_memory
+        if self.log:
+            memory_usage_gib = self.memory_usage / (1024**3)
+            print(f"CUDA memory usage: {memory_usage_gib:.{self.precision}f} GiB")
+
+
 MASK_MOD_MAP = {
     "causal": causal_mask,
     "sliding_window": generate_sliding_window,
@@ -53,6 +80,7 @@ class ExperimentConfig:
     M: int
     N: int
     mask_mod_name: str
+    dynamic: bool
 
 
 @dataclass(frozen=True)
@@ -113,7 +141,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found."
     mask_mod_fn = get_mask_mod(config)
 
-    cbm = torch.compile(create_block_mask, dynamic=False)
+    cbm = torch.compile(create_block_mask, dynamic=config.dynamic)
     # Warmup
     for _ in range(10):
         cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device)
@@ -141,6 +169,7 @@ def print_results(experiments: List[Experiment]):
         "M",
         "N",
         "Mask Mod",
+        "Dynamic",
         "Creation Time (ms)",
         "Memory (GiB)",
     ]
@@ -153,6 +182,7 @@ def print_results(experiments: List[Experiment]):
             experiment.config.M,
             experiment.config.N,
             experiment.config.mask_mod_name,
+            experiment.config.dynamic,
             f"{experiment.result.creation_time_ms:.4f}",
             f"{experiment.result.memory_bytes:.2f}",
         ]
@@ -169,23 +199,28 @@ def get_configs() -> List[ExperimentConfig]:
     SeqLens = [8192, 16384, 32768]
     # Map string names to mask functions
     mask_mods_to_run = list(MASK_MOD_MAP.keys())
+    dynamic = [
+        False,
+    ]
 
     configs = []
-    for B, H, S, mask_mod in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run):
+    for B, H, S, mask_mod, dyn in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run, dynamic):
         configs.append(
             ExperimentConfig(
                 B=B,
                 H=H,
                 M=S,  # Assuming M=N for simplicity
                 N=S,
                 mask_mod_name=mask_mod,
+                dynamic=dyn,
            )
        )
     return configs
 
 
 def main():
     torch.random.manual_seed(123)
+    random.seed(123)
     configs = get_configs()
     results = []
     print(f"Running {len(configs)} benchmark configurations...")
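For context, here is a minimal sketch of how the two new helpers compose, mirroring what `run_experiment` in this file does. This is a sketch only: it assumes a CUDA device and the helper definitions above; the batch/head sizes are illustrative (only the sequence lengths appear in `get_configs` here), and `causal_mask` is one of the mask mods the benchmark imports.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask
from attn_gym.masks import causal_mask

device = "cuda"
# dynamic=False specializes the compiled graph to the exact shapes it sees,
# which is what the benchmark measures by default.
cbm = torch.compile(create_block_mask, dynamic=False)

# Change in allocated CUDA memory across one mask creation.
with cuda_memory_usage(log=True) as mem:
    cbm(causal_mask, 1, 8, 8192, 8192, device=device)

# do_bench performs its own warmup and averaging under the hood.
time_us = benchmark_cuda_function_in_microseconds(
    cbm, causal_mask, 1, 8, 8192, 8192, device=device
)
print(f"creation: {time_us / 1000:.4f} ms, memory: {mem.memory_usage / 1024**3:.2f} GiB")
```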