
Commit efba5c4

add context parallelism
1 parent 2ca55ae commit efba5c4

File tree: 10 files changed (+496, -118 lines)

MaxText/common_types.py (+1)
@@ -36,6 +36,7 @@
 
 BATCH = "activation_batch"
 LENGTH = "activation_length"
+Q_LENGTH = "activation_q_length"
 KV_LENGTH = "activation_kv_length"
 EMBED = "activation_embed"
 HEAD = "activation_heads"

MaxText/configs/base.yml (+19, -17)
@@ -276,6 +276,8 @@ logical_axis_rules: [
   ['activation_length', ['sequence', 'context']],
   ['activation_length', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_q_length', ['context']],
+  ['activation_kv_length', []],
   ['activation_embed', ['tensor', 'tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
@@ -285,7 +287,7 @@ logical_axis_rules: [
   ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
   ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence']],
+  ['activation_vocab', ['sequence','context']],
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
@@ -295,22 +297,22 @@ logical_axis_rules: [
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'context', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'context', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'context', 'sequence', 'expert']],
-  ['embed', ['fsdp', 'context', 'sequence', 'expert']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'context', 'sequence', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'context', 'sequence', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'context', 'sequence']],
-  ['embed_no_exp', ['fsdp', 'context', 'sequence']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_no_exp', ['fsdp', 'sequence', 'context']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['layers', 'stage'],
   ['kv', []],
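Note on the rule table above: these logical-to-physical rules are what the attention code consults (via flax's nn.logical_to_mesh_axes) when it turns logical axis names such as 'activation_q_length' into mesh partition specs. The snippet below is a minimal standalone sketch of that lookup, not part of the commit; the rule subset and axis names are copied from the new entries, and it assumes a mesh that actually has a 'context' axis.

import flax.linen as nn

# Subset of the rules added above: the query length shards over 'context',
# the key/value length has no mesh axes and therefore stays replicated.
rules = (
    ("activation_q_length", ("context",)),
    ("activation_kv_length", ()),
)

# The query sequence dimension resolves to the 'context' mesh axis; a logical
# name with no matching rule (here 'activation_batch') is left unconstrained.
q_spec = nn.logical_to_mesh_axes(("activation_batch", "activation_q_length"), rules)

# The key/value sequence dimension resolves to "no axes", i.e. every context
# shard keeps the full K/V sequence.
kv_spec = nn.logical_to_mesh_axes(("activation_batch", "activation_kv_length"), rules)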

MaxText/configs/inference.yml (+1, -1)
@@ -8,7 +8,7 @@ logical_axis_rules: [
   ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
   ['activation_length', ['context_autoregressive', 'sequence']],
   ['activation_length', ['context_autoregressive']],
-  ['activation_length_q', ['context_autoregressive']],
+  ['activation_q_length', ['context_autoregressive']],
   ['activation_kv_length', ['context_autoregressive']],
   ['activation_norm_length', ['tensor_sequence', 'sequence']],
   ['activation_embed', ['tensor_transpose']],

MaxText/layers/attentions.py (+202, -47)
@@ -42,6 +42,11 @@
 from MaxText.layers import linears
 from MaxText.layers import quantizations
 
+from MaxText import max_logging
+from MaxText import max_utils
+import numpy as np
+
+partial = functools.partial
 
 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
 # pytype: disable=attribute-error
@@ -75,6 +80,7 @@ class AttentionType(enum.Enum):
 DECODE_BATCH = common_types.DECODE_BATCH
 DECODE_LENGTH = common_types.DECODE_LENGTH
 LENGTH = common_types.LENGTH
+Q_LENGTH = common_types.Q_LENGTH
 KV_LENGTH = common_types.KV_LENGTH
 HEAD = common_types.HEAD
 EMBED = common_types.EMBED
@@ -290,7 +296,12 @@ class AttentionOp(nn.Module):
   float32_qk_product: bool = False
   max_prefill_predict_length: int = -1
   float32_logits: bool = False
-  flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV)
+  flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
+  cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
+  cache_scale_logical_axis_names: AxisNames = (CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV)
   ragged_qkv_axis_names: AxisNames = (CACHE_BATCH, CACHE_HEADS, CACHE_SEQUENCE, CACHE_KV)
   ragged_lengths_names: AxisNames = (CACHE_BATCH,)
   compute_axis_order: AxisIdxes = (0, 1, 2, 3)
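The next hunk reads the context-parallel degree straight off the device mesh (self.mesh.shape["context"]) and asserts that the batch is divisible by the data and fsdp axes. A minimal sketch of that mesh lookup, purely for illustration and not part of the commit; the 1x1xN axis sizes are an assumption:

import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical mesh: no data/fsdp sharding, all local devices on the 'context' axis.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "context"))

cp_size = mesh.shape["context"]                                  # context-parallel degree
devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]   # batch must divide this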
@@ -557,15 +568,25 @@ def tpu_flash_attention(
       attn_logits_soft_cap: float | None = None,
   ) -> Array:
     """TPU Flash Attention."""
+
+    cp_size = self.mesh.shape["context"]
+    load_balanced_context_parallel = self.config.context_parallel_load_balance
+
     # Transpose to ('batch', 'heads', 'length', 'kv')
     query = jnp.transpose(query, axes=(0, 2, 1, 3))
     key = jnp.transpose(key, axes=(0, 2, 1, 3))
     value = jnp.transpose(value, axes=(0, 2, 1, 3))
-
+    segment_axis_names_q = None
+    segment_axis_names_kv = None
     if decoder_segment_ids is not None:
-      decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids)
-    axis_names = nn.logical_to_mesh_axes(self.flash_axis_names)
-    segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads"))
+      segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, Q_LENGTH))
+      segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
+    axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
+    axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q)
+    axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv)
+    max_logging.log(f"axis_names_q: {axis_names_q}")
+    max_logging.log(f"axis_names_kv: {axis_names_kv}")
+    max_logging.log(f"axis_names_splash_kernel: {axis_names_splash_kernel}")
 
     global_block_q = self.config.sa_block_q
     global_block_kv = self.config.sa_block_kv
@@ -580,40 +601,46 @@ def tpu_flash_attention(
     global_k_layout = self.config.sa_k_layout
     global_v_layout = self.config.sa_v_layout
 
-    @functools.partial(
-        shard_map,
-        mesh=self.mesh,
-        in_specs=(
-            axis_names,
-            axis_names,
-            axis_names,
-            segment_axis_names,
-        ),
-        out_specs=axis_names,
-        check_rep=False,
+    devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
+    assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
+        "Batch dimension should be shardable among the devices in data and fsdp"
+        " axis"
+        f" got {query.shape[0]=}/{devices_in_data_fsdp=}"
     )
-    def wrap_flash_attention(query, key, value, decoder_segment_ids):
-      if decoder_segment_ids is not None:
-        assert (
-            query.shape[2] == decoder_segment_ids.q.shape[1]
-        ), "Sharding along sequence dimension not allowed in tpu kernel attention"
-      block_sizes = splash_attention_kernel.BlockSizes(
-          block_q=min(global_block_q, query.shape[2]),
-          block_kv=min(global_block_kv, key.shape[2]),
-          block_kv_compute=min(global_block_kv_compute, key.shape[2]),
-          block_q_dkv=min(global_block_q_dkv, query.shape[2]),
-          block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
-          block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
-          block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
-          block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
-          use_fused_bwd_kernel=global_use_fused_bwd_kernel,
-          q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
-          k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
-          v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
-      )
 
-      mask = splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2]))
+    # create_splash_attention kernel
+    block_sizes = splash_attention_kernel.BlockSizes(
+        block_q=min(global_block_q, query.shape[2]),
+        block_kv=min(global_block_kv, key.shape[2]),
+        block_kv_compute=min(global_block_kv_compute, key.shape[2]),
+        block_q_dkv=min(global_block_q_dkv, query.shape[2]),
+        block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
+        block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
+        block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
+        block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
+        use_fused_bwd_kernel=global_use_fused_bwd_kernel,
+        q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
+        k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
+        v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
+    )
+
+    mask_shape = (self.config.max_target_length, self.config.max_target_length)
+    mask = splash_attention_mask.CausalMask(shape=mask_shape)
+
+    # Create LoadBalancedCausalMask if cp and load_balancing
+    if cp_size > 1 and load_balanced_context_parallel:
+      mask = LoadBalancedCausalMask(shape=mask_shape, cp_size=cp_size)
 
+    # TODO: figure out local_sliding attention + load_balancing, default is global
+    # Apply local masking if local sliding attention is enabled.
+    if self.attention_type == AttentionType.LOCAL_SLIDING:
+      if self.sliding_window_size is None:
+        raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
+      mask &= splash_attention_mask.LocalMask(
+          shape=(query.shape[2], key.shape[2]),
+          window_size=(self.sliding_window_size, self.sliding_window_size),
+          offset=0,
+      )
     # Apply local masking if local sliding attention is enabled.
     if self.attention_type == AttentionType.LOCAL_SLIDING:
       if self.sliding_window_size is None:
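The next hunk wraps the kernel call in a shard_map whose in_specs shard the query along the 'context' axis while replicating K and V on every shard. The toy below is a self-contained sketch of that sharding pattern only; it uses plain dot-product attention in place of the splash kernel, omits masking, and the single-axis mesh is an assumption rather than anything in this commit.

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One-dimensional mesh over all local devices, named 'context'.
mesh = Mesh(np.array(jax.devices()), ("context",))

@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(None, "context", None), P(None, None, None), P(None, None, None)),
    out_specs=P(None, "context", None),
    check_rep=False,
)
def sharded_attention(q, k, v):
  # Each device sees a [batch, seq/cp, d] slice of q but full [batch, seq, d]
  # copies of k and v, so its output rows depend only on its local queries.
  scores = jnp.einsum("bqd,bkd->bqk", q, k)
  return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), v)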
@@ -629,26 +656,100 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):
 
     mask &= ChunkedCausalMask(shape=(query.shape[2], key.shape[2]), chunk_size=self.chunk_attn_window_size)
 
-    # Create multi-head mask
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # Create multi-head mask
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    # Create the splash attention kernel object separately, jit it for performance
+    @partial(
+        jax.jit,
+        static_argnames=[
+            "multi_head_mask",
+            "shard_head_size",
+        ],
+    )
+    def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
-          head_shards=1,
-          q_seq_shards=1,
+          head_shards=shard_head_size,  # the size of the axis if sharding over heads
+          q_seq_shards=cp_size,  # axis for sequence sharding
           block_sizes=block_sizes,
           attn_logits_soft_cap=attn_logits_soft_cap,
       )
+      return splash_kernel
 
-      return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
+    logical_axis_rules_head = np.array(
+        [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
+    )
+    shard_head_size = np.prod(logical_axis_rules_head)
+    splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+    named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
+    segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+
+    # Now call the function wrap_flash_attention which does the actual computation.
+    # The splash kernel is passed as a parameter to the function. Since we have the shard map
+    # decorating the wrap_flash_attention function, the data will be correctly sharded,
+    # meaning q will be sharded over sequence aka context length but K and V will be duplicated.
+    # The shardings are specified in the in_specs and out_specs of the shard_map decorator:
+    # 'segment_axis_names_q' maps to ['activation_q_length', ['context']] meaning that q is sharded over the context axis
+    # 'segment_axis_names_kv' maps to ['activation_kv_length', []] meaning that K and V are not sharded
+    # splash_kernel is sharded over (HEAD, LENGTH)
+    @functools.partial(
+        shard_map,
+        mesh=self.mesh,
+        in_specs=(
+            axis_names_q,
+            axis_names_kv,
+            axis_names_kv,
+            segment_axis_names_q,
+            segment_axis_names_kv,
+            segment_axis_names_splash_kernel,
+            None,  # no sharding for cp_size
+            None,  # no sharding for load_balanced_context_parallel
+        ),
+        out_specs=axis_names_q,
+        check_rep=False,
+    )
+    def wrap_flash_attention(
+        query,
+        key,
+        value,
+        decoder_segment_ids_q,
+        decoder_segment_ids_kv,
+        splash_kernel,
+        cp_size,
+        load_balanced_context_parallel,
+    ):
+      # If load_balanced_context_parallel is enabled, reorder the key and value tensors
+      # to ensure that they are contiguous in memory.
+      # This is necessary for the splash attention kernel to work correctly because it expects
+      # the K and V to be contiguous. Note that K and V are not sharded over the sequence aka context axis.
+      # This way we get the unsharded, unpermuted key and value tensors.
+      if cp_size > 1 and load_balanced_context_parallel:
+        key = max_utils.reorder_sequence(tensor=key, cp_size=cp_size, seq_dim=2, to_contiguous=True)
+        value = max_utils.reorder_sequence(tensor=value, cp_size=cp_size, seq_dim=2, to_contiguous=True)
+        decoder_segment_ids_unpermuted = max_utils.reorder_sequence(
+            tensor=decoder_segment_ids_kv, cp_size=cp_size, seq_dim=1, to_contiguous=True
+        )
 
-      devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
-      assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
-          "Batch dimension should be shardable among the devices in data and fsdp"
-          " axis"
-          f" got {query.shape[0]=}/{devices_in_data_fsdp=}"
+      if decoder_segment_ids_q is not None:
+        if cp_size > 1 and load_balanced_context_parallel:
+          decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(
+              decoder_segment_ids_q, decoder_segment_ids_unpermuted
+          )
+        else:
+          decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_q)
+      else:
+        decoder_segment_ids_tuple = None
+      attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple)
+
+      return attention_output
+
+    x = wrap_flash_attention(
+        query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel, cp_size, load_balanced_context_parallel
     )
-    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
+
     x = jnp.transpose(x, axes=(0, 2, 1, 3))
+
     return x
 
   def cudnn_flash_attention(
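wrap_flash_attention above relies on max_utils.reorder_sequence, whose implementation is not part of this diff. Assuming it performs the usual load-balanced ("zigzag") chunk pairing for causal context parallelism, where shard i holds chunks i and 2*cp_size-1-i, a sketch of both directions could look like the hypothetical helper below (illustrative only, not the MaxText implementation; it assumes the sequence length is divisible by 2*cp_size):

import jax.numpy as jnp

def reorder_sequence_sketch(tensor, cp_size, seq_dim=1, to_contiguous=False):
  """Pairs chunk i with chunk 2*cp_size-1-i so each context shard gets similar causal work."""
  chunks = jnp.split(tensor, 2 * cp_size, axis=seq_dim)
  if not to_contiguous:
    # Permuted layout fed to the shards: [0, 2cp-1, 1, 2cp-2, ...].
    order = [c for i in range(cp_size) for c in (i, 2 * cp_size - 1 - i)]
  else:
    # Inverse permutation: where original chunk k lives inside the permuted layout.
    order = [2 * k if k < cp_size else 2 * (2 * cp_size - 1 - k) + 1 for k in range(2 * cp_size)]
  return jnp.concatenate([chunks[k] for k in order], axis=seq_dim)

With cp_size=2, a sequence split into chunks [A, B, C, D] becomes [A, D, B, C], and to_contiguous=True undoes that pairing, which is what the kernel needs to see the unpermuted K and V above.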
@@ -1673,3 +1774,57 @@ def __call__(
     out = nn.with_logical_constraint(out, self.out_axis_names)
     out = self.out_projection(inputs_q.shape[-1], out)
     return out
+
+
+class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
+  """Lazy causal mask, prevents the model from attending to future tokens.
+  Attributes:
+    offset: Offset of q start wrt kv. A positive offset shifts the bottom
+      triangle upward, a negative one shifts it downward. A negative offset
+      makes the first 'offset' rows of the attention matrix all 0s which leads
+      to undefined softmax.
+  """
+
+  offset: int
+  shape: tuple[int, int]
+  cp_size: int
+
+  def __init__(self, shape: tuple[int, int], offset: int = 0, shard_count: int = 1, cp_size: int = 4):
+    self.offset = offset
+
+    def causal_mask_function(q_ids, kv_ids):
+      if self.offset == 0:
+        return q_ids >= kv_ids
+      else:
+        return q_ids + self.offset >= kv_ids
+
+    arr = np.arange(shape[0])
+    # we reorder the mask to be load balanced following the same approach as
+    # used to reorder the input tokens
+    out = max_utils.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, seq_dim=1)
+    q_sequence = out[0, :, 0, 0]
+
+    mask_function = causal_mask_function
+
+    super().__init__(
+        shape=shape,
+        mask_function=mask_function,
+        shard_count=shard_count,
+    )
+    self.q_sequence = q_sequence
+
+  def __eq__(self, other: object):
+    if not isinstance(other, type(self)):
+      return NotImplemented
+
+    return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence)
+
+  def __hash__(self):
+    return hash(
+        (
+            type(self),
+            self.shape,
+            self.offset,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        )
+    )
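LoadBalancedCausalMask reorders its q_sequence with max_utils.reorder_mask_load_balancing, which is also not shown in this diff. Assuming it applies the same chunk pairing as the input-token reordering, the reordered index vector for a toy case (seq_len=8, cp_size=2) works out as in the plain-numpy sketch below; the helper and sizes are illustrative assumptions, not commit code.

import numpy as np

seq_len, cp = 8, 2
chunks = np.split(np.arange(seq_len), 2 * cp)
q_sequence = np.concatenate([chunks[k] for i in range(cp) for k in (i, 2 * cp - 1 - i)])
# q_sequence == [0, 1, 6, 7, 2, 3, 4, 5]: each context shard owns one early and one late
# chunk, so the rows of the causal mask (and the attention work) are balanced across shards.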
