|
| 1 | +import itertools |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import List |
| 4 | + |
| 5 | +import torch |
| 6 | +from tabulate import tabulate |
| 7 | +from tqdm import tqdm |
| 8 | +import random |
| 9 | +import sys |
| 10 | +from importlib.util import find_spec |
| 11 | + |
| 12 | +# Check if transformer_nuggets is available |
| 13 | +if find_spec("transformer_nuggets") is None: |
| 14 | + print( |
| 15 | + "Need to install transformer_nuggets for this benchmark. " |
| 16 | + "Run `pip install git+https://github.com/drisspg/transformer_nuggets`" |
| 17 | + ) |
| 18 | + sys.exit(1) |
| 19 | + |
| 20 | +from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, cuda_memory_usage |
| 21 | + |
| 22 | + |
| 23 | +from attn_gym.masks import ( |
| 24 | + causal_mask, |
| 25 | + generate_sliding_window, |
| 26 | + generate_prefix_lm_mask, |
| 27 | + generate_doc_mask_mod, |
| 28 | + generate_dilated_sliding_window, |
| 29 | +) |
| 30 | +from attn_gym.masks.document_mask import length_to_offsets |
| 31 | + |
| 32 | +from torch.nn.attention.flex_attention import create_block_mask, _mask_mod_signature |
| 33 | + |
| 34 | +device = torch.device("cuda") |
| 35 | + |
| 36 | +# Needed since changing args to function causes recompiles |
| 37 | +torch._dynamo.config.cache_size_limit = 1000 |
| 38 | + |
| 39 | + |
| 40 | +MASK_MOD_MAP = { |
| 41 | + "causal": causal_mask, |
| 42 | + "sliding_window": generate_sliding_window, |
| 43 | + "prefix_lm": generate_prefix_lm_mask, |
| 44 | + "doc_mask_mod": generate_doc_mask_mod, |
| 45 | + "dilated_sliding_window": generate_dilated_sliding_window, |
| 46 | +} |
| 47 | + |
| 48 | + |
| 49 | +@dataclass(frozen=True) |
| 50 | +class ExperimentConfig: |
| 51 | + B: int |
| 52 | + H: int |
| 53 | + M: int |
| 54 | + N: int |
| 55 | + mask_mod_name: str |
| 56 | + |
| 57 | + |
| 58 | +@dataclass(frozen=True) |
| 59 | +class ExperimentResult: |
| 60 | + creation_time_ms: float |
| 61 | + memory_bytes: int |
| 62 | + |
| 63 | + |
| 64 | +@dataclass(frozen=True) |
| 65 | +class Experiment: |
| 66 | + config: ExperimentConfig |
| 67 | + result: ExperimentResult |
| 68 | + |
| 69 | + |
| 70 | +def get_mask_mod(config: ExperimentConfig) -> _mask_mod_signature: |
| 71 | + name = config.mask_mod_name |
| 72 | + match name: |
| 73 | + case "sliding_window": |
| 74 | + # Lets have window size be a 1/4 |
| 75 | + window_size = config.M // 4 |
| 76 | + return generate_sliding_window(window_size) |
| 77 | + case "prefix_lm": |
| 78 | + # Same for prefix length |
| 79 | + prefix_length = config.M // 4 |
| 80 | + return generate_prefix_lm_mask(prefix_length) |
| 81 | + case "doc_mask_mod": |
| 82 | + # Kinda random but at least 2 |
| 83 | + doc_count = max(2, config.M // 128) |
| 84 | + |
| 85 | + # Generate random lengths that sum to the sequence length |
| 86 | + def generate_random_lengths(total_length, num_documents): |
| 87 | + # Initialize all lengths to 1 to ensure each document has at least one token |
| 88 | + lengths = [1] * num_documents |
| 89 | + remaining_length = total_length - num_documents |
| 90 | + |
| 91 | + # Randomly distribute the remaining length |
| 92 | + for _ in range(remaining_length): |
| 93 | + index = random.randint(0, num_documents - 1) |
| 94 | + lengths[index] += 1 |
| 95 | + |
| 96 | + return lengths |
| 97 | + |
| 98 | + lengths = generate_random_lengths(config.M, doc_count) |
| 99 | + offsets = length_to_offsets(lengths, device) |
| 100 | + return generate_doc_mask_mod(causal_mask, offsets) |
| 101 | + |
| 102 | + case "dilated_sliding_window": |
| 103 | + window_size = config.M // 4 |
| 104 | + dilation = 4 |
| 105 | + return generate_dilated_sliding_window(window_size, dilation) |
| 106 | + case _: |
| 107 | + mod = MASK_MOD_MAP[name] |
| 108 | + return mod |
| 109 | + |
| 110 | + |
| 111 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 112 | + # Find the mask_mod function by name |
| 113 | + assert config.mask_mod_name in MASK_MOD_MAP, f"Mask mod '{config.mask_mod_name}' not found." |
| 114 | + mask_mod_fn = get_mask_mod(config) |
| 115 | + |
| 116 | + cbm = torch.compile(create_block_mask, dynamic=False) |
| 117 | + # Warmup |
| 118 | + for _ in range(10): |
| 119 | + cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device) |
| 120 | + torch.cuda.synchronize(device) |
| 121 | + |
| 122 | + creation_time_us = benchmark_cuda_function_in_microseconds( |
| 123 | + lambda: cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device), |
| 124 | + ) |
| 125 | + |
| 126 | + torch.cuda.synchronize(device) |
| 127 | + |
| 128 | + with cuda_memory_usage() as mem: |
| 129 | + cbm(mask_mod_fn, config.B, config.H, config.M, config.N, device=device) |
| 130 | + torch.cuda.synchronize(device) |
| 131 | + |
| 132 | + return ExperimentResult( |
| 133 | + creation_time_ms=creation_time_us / 1000, memory_bytes=mem.memory_usage |
| 134 | + ) |
| 135 | + |
| 136 | + |
| 137 | +def print_results(experiments: List[Experiment]): |
| 138 | + headers = [ |
| 139 | + "B", |
| 140 | + "H", |
| 141 | + "M", |
| 142 | + "N", |
| 143 | + "Mask Mod", |
| 144 | + "Creation Time (ms)", |
| 145 | + "Memory (GiB)", |
| 146 | + ] |
| 147 | + rows = [] |
| 148 | + for experiment in experiments: |
| 149 | + rows.append( |
| 150 | + [ |
| 151 | + experiment.config.B, |
| 152 | + experiment.config.H, |
| 153 | + experiment.config.M, |
| 154 | + experiment.config.N, |
| 155 | + experiment.config.mask_mod_name, |
| 156 | + f"{experiment.result.creation_time_ms:.4f}", |
| 157 | + f"{experiment.result.memory_bytes:.2f}", |
| 158 | + ] |
| 159 | + ) |
| 160 | + # Sort rows for better readability (e.g., by B, H, M, N) |
| 161 | + rows.sort(key=lambda x: (x[0], x[1], x[2], x[3])) |
| 162 | + print(tabulate(rows, headers=headers, tablefmt="grid")) |
| 163 | + |
| 164 | + |
| 165 | +def get_configs() -> List[ExperimentConfig]: |
| 166 | + # Define ranges for benchmark parameters |
| 167 | + Bs = [1] |
| 168 | + Hs = [8] |
| 169 | + SeqLens = [8192, 16384, 32768] |
| 170 | + # Map string names to mask functions |
| 171 | + mask_mods_to_run = list(MASK_MOD_MAP.keys()) |
| 172 | + |
| 173 | + configs = [] |
| 174 | + for B, H, S, mask_mod in itertools.product(Bs, Hs, SeqLens, mask_mods_to_run): |
| 175 | + configs.append( |
| 176 | + ExperimentConfig( |
| 177 | + B=B, |
| 178 | + H=H, |
| 179 | + M=S, # Assuming M=N for simplicity |
| 180 | + N=S, |
| 181 | + mask_mod_name=mask_mod, |
| 182 | + ) |
| 183 | + ) |
| 184 | + return configs |
| 185 | + |
| 186 | + |
| 187 | +def main(): |
| 188 | + torch.random.manual_seed(123) |
| 189 | + configs = get_configs() |
| 190 | + results = [] |
| 191 | + print(f"Running {len(configs)} benchmark configurations...") |
| 192 | + for config in tqdm(configs): |
| 193 | + try: |
| 194 | + result = run_experiment(config) |
| 195 | + results.append(Experiment(config=config, result=result)) |
| 196 | + except Exception as e: |
| 197 | + print(f"Failed to run config {config}: {e}") |
| 198 | + |
| 199 | + print_results(results) |
| 200 | + |
| 201 | + |
| 202 | +if __name__ == "__main__": |
| 203 | + main() |
0 commit comments