|
| 1 | +import itertools |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import List, Callable |
| 4 | + |
| 5 | +import torch |
| 6 | +from tabulate import tabulate |
| 7 | +from tqdm import tqdm |
| 8 | +import random |
| 9 | +from triton.testing import do_bench |
| 10 | + |
| 11 | +from attn_gym.masks import ( |
| 12 | + causal_mask, |
| 13 | + generate_sliding_window, |
| 14 | + generate_prefix_lm_mask, |
| 15 | + generate_doc_mask_mod, |
| 16 | + generate_dilated_sliding_window, |
| 17 | +) |
| 18 | +from attn_gym.masks.document_mask import length_to_offsets |
| 19 | + |
| 20 | +from torch.nn.attention.flex_attention import create_block_mask, _mask_mod_signature |
| 21 | + |
| 22 | +device = torch.device("cuda") |
| 23 | + |
| 24 | +# Needed since changing args to function causes recompiles |
| 25 | +torch._dynamo.config.cache_size_limit = 1000 |
| 26 | + |
| 27 | + |
| 28 | +def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: |
| 29 | + """Thin wrapper around do_bench_using_profiling""" |
| 30 | + no_args = lambda: func(*args, **kwargs) |
| 31 | + time = do_bench(no_args) |
| 32 | + return time * 1e3 |
| 33 | + |
| 34 | + |
| 35 | +class cuda_memory_usage: |
| 36 | + """Prints the difference CUDA memory usage at the end of a context manager |
| 37 | +
|
| 38 | + Args: |
| 39 | + log (bool): Whether to print the memory usage to the console |
| 40 | + precision (int): The number of decimal places to print |
| 41 | +
|
| 42 | + Usage: |
| 43 | + ``` |
| 44 | + with cuda_memory_usage() as mem: |
| 45 | + # code to profile |
| 46 | + print(mem.memory_usage) |
| 47 | + ``` |
| 48 | +
|
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self, log=False, precision=2): |
| 52 | + self.log = log |
| 53 | + self.precision = precision |
| 54 | + self.memory_usage = 0 |
| 55 | + |
| 56 | + def __enter__(self): |
| 57 | + self.initial_memory = torch.cuda.memory_allocated() |
| 58 | + return self |
| 59 | + |
| 60 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 61 | + self.memory_usage = torch.cuda.memory_allocated() - self.initial_memory |
| 62 | + if self.log: |
| 63 | + memory_usage_gib = self.memory_usage / (1024**3) |
| 64 | + print(f"CUDA memory usage: {memory_usage_gib:.{self.precision}f} GiB") |
| 65 | + |
| 66 | + |
| 67 | +MASK_MOD_MAP = { |
| 68 | + "causal": causal_mask, |
| 69 | + "sliding_window": generate_sliding_window, |
| 70 | + "prefix_lm": generate_prefix_lm_mask, |
| 71 | + "doc_mask_mod": generate_doc_mask_mod, |
| 72 | + "dilated_sliding_window": generate_dilated_sliding_window, |
| 73 | +} |
| 74 | + |
| 75 | + |
| 76 | +@dataclass(frozen=True) |
| 77 | +class ExperimentConfig: |
| 78 | + B: int |
| 79 | + H: int |
| 80 | + M: int |
| 81 | + N: int |
| 82 | + mask_mod_name: str |
| 83 | + dynamic: bool |
| 84 | + |
| 85 | + |
| 86 | +@dataclass(frozen=True) |
| 87 | +class ExperimentResult: |
| 88 | + creation_time_ms: float |
| 89 | + memory_bytes: int |
| 90 | + |
| 91 | + |
| 92 | +@dataclass(frozen=True) |
| 93 | +class Experiment: |
| 94 | + config: ExperimentConfig |
| 95 | + result: ExperimentResult |
| 96 | + |
| 97 | + |
| 98 | +def get_mask_mod(config: ExperimentConfig) -> _mask_mod_signature: |
| 99 | + name = config.mask_mod_name |
| 100 | + match name: |
| 101 | + case "sliding_window": |
| 102 | + # Lets have window size be a 1/4 |
| 103 | + window_size = config.M // 4 |
| 104 | + return generate_sliding_window(window_size) |
| 105 | + case "prefix_lm": |
| 106 | + # Same for prefix length |
| 107 | + prefix_length = config.M // 4 |
| 108 | + return generate_prefix_lm_mask(prefix_length) |
| 109 | + case "doc_mask_mod": |
| 110 | + # Kinda random but at least 2 |
| 111 | + doc_count = max(2, config.M // 128) |
| 112 | + |
| 113 | + # Generate random lengths that sum to the sequence length |
| 114 | + def generate_random_lengths(total_length, num_documents): |
| 115 | + # Initialize all lengths to 1 to ensure each document has at least one token |
| 116 | + lengths = [1] * num_documents |
| 117 | + remaining_length = total_length - num_documents |
| 118 | + |
| 119 | + # Randomly distribute the remaining length |
| 120 | + for _ in range(remaining_length): |
| 121 | + index = random.randint(0, num_documents - 1) |
| 122 | + lengths[index] += 1 |
| 123 | + |
| 124 | + return lengths |
| 125 | + |
| 126 | + lengths = generate_random_lengths(config.M, doc_count) |
| 127 | + offsets = length_to_offsets(lengths, device) |
| 128 | + return generate_doc_mask_mod(causal_mask, offsets) |
| 129 | + |
| 130 | + case "dilated_sliding_window": |
| 131 | + window_size = config.M // 4 |
| 132 | + dilation = 4 |
| 133 | + return generate_dilated_sliding_window(window_size, dilation) |
| 134 | + case _: |
| 135 | + mod = MASK_MOD_MAP[name] |
| 136 | + return mod |
| 137 | + |
| 138 | + |
| 139 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 140 | + # Find the mask_mod function by name |
| 141 | + assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found." |
| 142 | + mask_mod_fn = get_mask_mod(config) |
| 143 | + |
| 144 | + cbm = torch.compile(create_block_mask, dynamic=config.dynamic) |
| 145 | + # Warmup |
| 146 | + for _ in range(10): |
| 147 | + cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device) |
| 148 | + torch.cuda.synchronize(device) |
| 149 | + |
| 150 | + creation_time_us = benchmark_cuda_function_in_microseconds( |
| 151 | + lambda: cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device), |
| 152 | + ) |
| 153 | + |
| 154 | + torch.cuda.synchronize(device) |
| 155 | + |
| 156 | + with cuda_memory_usage() as mem: |
| 157 | + cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device) |
| 158 | + torch.cuda.synchronize(device) |
| 159 | + |
| 160 | + return ExperimentResult( |
| 161 | + creation_time_ms=creation_time_us / 1000, memory_bytes=mem.memory_usage |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +def print_results(experiments: List[Experiment]): |
| 166 | + headers = [ |
| 167 | + "B", |
| 168 | + "H", |
| 169 | + "M", |
| 170 | + "N", |
| 171 | + "Mask Mod", |
| 172 | + "Dynamic", |
| 173 | + "Creation Time (ms)", |
| 174 | + "Memory (GiB)", |
| 175 | + ] |
| 176 | + rows = [] |
| 177 | + for experiment in experiments: |
| 178 | + rows.append( |
| 179 | + [ |
| 180 | + experiment.config.B, |
| 181 | + experiment.config.H, |
| 182 | + experiment.config.M, |
| 183 | + experiment.config.N, |
| 184 | + experiment.config.mask_mod_name, |
| 185 | + experiment.config.dynamic, |
| 186 | + f"{experiment.result.creation_time_ms:.4f}", |
| 187 | + f"{experiment.result.memory_bytes:.2f}", |
| 188 | + ] |
| 189 | + ) |
| 190 | + # Sort rows for better readability (e.g., by B, H, M, N) |
| 191 | + rows.sort(key=lambda x: (x[0], x[1], x[2], x[3])) |
| 192 | + print(tabulate(rows, headers=headers, tablefmt="grid")) |
| 193 | + |
| 194 | + |
| 195 | +def get_configs() -> List[ExperimentConfig]: |
| 196 | + # Define ranges for benchmark parameters |
| 197 | + Bs = [1] |
| 198 | + Hs = [8] |
| 199 | + SeqLens = [8192, 16384, 32768] |
| 200 | + # Map string names to mask functions |
| 201 | + mask_mods_to_run = list(MASK_MOD_MAP.keys()) |
| 202 | + dynamic = [ |
| 203 | + False, |
| 204 | + ] |
| 205 | + |
| 206 | + configs = [] |
| 207 | + for B, H, S, mask_mod, dyn in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run, dynamic): |
| 208 | + configs.append( |
| 209 | + ExperimentConfig( |
| 210 | + B=B, |
| 211 | + H=H, |
| 212 | + M=S, # Assuming M=N for simplicity |
| 213 | + N=S, |
| 214 | + mask_mod_name=mask_mod, |
| 215 | + dynamic=dyn, |
| 216 | + ) |
| 217 | + ) |
| 218 | + return configs |
| 219 | + |
| 220 | + |
| 221 | +def main(): |
| 222 | + torch.random.manual_seed(123) |
| 223 | + random.seed(123) |
| 224 | + configs = get_configs() |
| 225 | + results = [] |
| 226 | + print(f"Running {len(configs)} benchmark configurations...") |
| 227 | + for config in tqdm(configs): |
| 228 | + try: |
| 229 | + result = run_experiment(config) |
| 230 | + results.append(Experiment(config=config, result=result)) |
| 231 | + except Exception as e: |
| 232 | + print(f"Failed to run config {config}: {e}") |
| 233 | + |
| 234 | + print_results(results) |
| 235 | + |
| 236 | + |
| 237 | +if __name__ == "__main__": |
| 238 | + main() |
0 commit comments