
Commit 38111cc

[cp] set up load balancing testbed
ghstack-source-id: 9c1a7b9
Pull Request resolved: #120
1 parent af82ef0 commit 38111cc

File tree

3 files changed: +210 -0 lines changed


attn_gym/load_balance/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from attn_gym.load_balance.load_balancer import load_balance_algo

__all__ = ["load_balance_algo"]

attn_gym/load_balance/load_balancer.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
from typing import List

import torch


__all__ = ["load_balance_algo"]


def load_balance_algo(S: int, size: int, block_size: int) -> torch.Tensor:
    total_num_blk = S // block_size
    assert S % (size * total_num_blk) == 0
    local_num_blk = total_num_blk // size
    return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk)
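
A quick sanity check (illustrative, not part of this commit): the placeholder above does no sparsity-aware balancing yet; it simply hands each rank a contiguous run of Q blocks. Assuming a CUDA device is available:

import torch
from attn_gym.load_balance import load_balance_algo

# 8192-token sequence, 4 ranks, 128-token blocks -> 64 blocks total, 16 per rank.
assignment = load_balance_algo(S=8192, size=4, block_size=128)
assert assignment.shape == (4, 16)
# Rank 1 owns global Q blocks 16..31, i.e. a plain contiguous split.
assert torch.equal(assignment[1].cpu(), torch.arange(16, 32))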

examples/distributed_benchmark.py

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
from functools import lru_cache
from typing import Optional

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Replicate, Shard


from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    flex_attention,
    _mask_mod_signature,
)

from attn_gym.masks.document_mask import length_to_offsets
from attn_gym.masks import (
    causal_mask,
    generate_doc_mask_mod,
)
from attn_gym.load_balance import load_balance_algo


def get_device_type() -> str:
    return "cuda"


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
    return block_mask


# TODO: re-write it into a wrapper???
def rewrite_mask_mod_for_cp(
    mask_mod: _mask_mod_signature,
    rank: int,
    block_size: int,
    load_balancer_output: torch.Tensor,
) -> _mask_mod_signature:
    def local_q_idx_to_q_idx(local_q_idx) -> int:
        # calculate local block_idx and block_offset
        local_blk_idx, local_blk_offset = (
            local_q_idx // block_size, local_q_idx % block_size
        )
        current_rank_blk_list = load_balancer_output[rank]
        blk_idx = current_rank_blk_list[local_blk_idx]
        return blk_idx * block_size + local_blk_offset

    return lambda b, h, q_idx, kv_idx: mask_mod(
        b, h, local_q_idx_to_q_idx(q_idx), kv_idx
    )
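# Worked example of the mapping above (illustrative, not part of this commit):
# with block_size=128 and load_balancer_output = torch.arange(64).view(4, 16),
# rank 1 owns global Q blocks 16..31. Local index 130 on rank 1 splits into
# local block 1 with offset 2, which resolves to global block 17 and hence
# q_idx = 17 * 128 + 2 = 2178.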

def run_document_masking(device_mesh, max_seq_len, num_docs):
    # initialize the document lengths
    import random

    random.seed(0)
    torch.cuda.manual_seed(0)

    def generate_random_lengths(total_length, num_documents):
        # Initialize all lengths to 1 to ensure each document has at least one token
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents

        # Randomly distribute the remaining length
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1

        return lengths

    lengths = generate_random_lengths(max_seq_len, num_docs)
    offsets = length_to_offsets(lengths, torch.device(f'cuda:{torch.cuda.current_device():d}'))  # TODO: replace with a device mesh call
    document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
    test_mask_with_load_balance(device_mesh, mask_mod=document_causal_mask, S=max_seq_len)


def test_mask_with_load_balance(
    device_mesh: DeviceMesh,
    mask_mod: Optional[_mask_mod_signature] = None,
    B: int = 16,
    H: int = 16,
    S: int = 8192,
    D: int = 64,
    skip_correctness: bool = False,
    print_mask: bool = True,
    device: str = "cuda",
):
    data_type = torch.float16

    # create block mask
    block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
    block_size = _DEFAULT_SPARSE_BLOCK_SIZE  # TODO: get block size from block mask

    # input initialization
    qkv = [
        torch.rand(
            (B, H, S, D),
            device=device_mesh.device_type,
            dtype=data_type,
            requires_grad=True,
        )
        for _ in range(3)
    ]

    # TODO: input sharding with load-balancing
    # sparsity_info = get_sparsity_info_from_block_mask(block_mask)
    # load_balancer_output = load_balance_algo(sparsity_info)
    cp_mesh_size = device_mesh.size()
    load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size)

    seq_dim = 2
    qkv_dist = [
        distribute_tensor(
            t.detach().clone().requires_grad_(), device_mesh, [
                Shard(seq_dim) if i == 0 else Replicate()
            ]
        )
        for (i, t) in enumerate(qkv)
    ]

    q_local, k_full, v_full = (dt.to_local() for dt in qkv_dist)

    # rewrite `block_mask`
    mask_mod: _mask_mod_signature = block_mask.mask_mod
    cp_rank = device_mesh.get_local_rank()
    cp_mask_mod = rewrite_mask_mod_for_cp(
        mask_mod, cp_rank, block_size, load_balancer_output
    )
    cp_block_mask = create_block_mask_cached(
        cp_mask_mod, B=1, H=1, M=S // cp_mesh_size, N=S, device=device
    )

    # Compile the flex_attention function
    compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

    # TODO: this doesn't address the return_lse=True case
    cp_out = compiled_flex_attention(
        q_local,
        k_full,
        v_full,
        score_mod=None,
        block_mask=cp_block_mask,
    )
    assert isinstance(cp_out, torch.Tensor)

    # unshard
    cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)])
    full_cp_out_dist = cp_out_dist.full_tensor()
    # rearrange
    blk_idx_to_origin = load_balancer_output.view(-1)
    num_chunks = blk_idx_to_origin.numel()
    blk_list_rearranged = [None] * num_chunks
    blk_list = torch.chunk(full_cp_out_dist, num_chunks, dim=seq_dim)
    assert len(blk_list) == num_chunks
    for blk_idx, blk in enumerate(blk_list):
        blk_list_rearranged[blk_idx_to_origin[blk_idx].item()] = blk

    full_cp_out_dist = torch.cat(blk_list_rearranged, dim=seq_dim)

    # local flex attention
    expect_out = flex_attention(*qkv, block_mask=block_mask)
    torch.testing.assert_close(full_cp_out_dist, expect_out, atol=1e-1, rtol=1e-2)


def load_balancing_example(world_size: int, rank: int) -> None:
    device_type = get_device_type()
    device_handle = getattr(torch, device_type, None)
    assert device_handle is not None, f"Unsupported device type: {device_type}"
    num_devices_per_host = device_handle.device_count()
    device_handle.set_device(rank % num_devices_per_host)
    torch._dynamo.config.cache_size_limit = 1000

    # init device mesh
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    run_document_masking(device_mesh, max_seq_len=4096, num_docs=12)


if __name__ == "__main__":
    # this script is launched via torchrun which automatically manages ProcessGroup
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # assert world_size == 4 # our example uses 4 worker ranks

    try:
        load_balancing_example(world_size, rank)
    finally:
        dist.barrier()
        dist.destroy_process_group()
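
Since the __main__ block reads RANK and WORLD_SIZE from the environment and tears down the process group itself, the script is meant to be launched with torchrun rather than run directly. A minimal single-host launch sketch, assuming 4 GPUs (the world size hinted at by the commented-out assert; adjust nproc_per_node to your machine):

torchrun --nproc_per_node=4 examples/distributed_benchmark.py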
