from MaxText.layers import linears
from MaxText.layers import quantizations

+ from MaxText import max_logging
+ from MaxText import max_utils
+ import numpy as np
+
+ partial = functools.partial

# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error
@@ -75,6 +80,7 @@ class AttentionType(enum.Enum):
DECODE_BATCH = common_types.DECODE_BATCH
DECODE_LENGTH = common_types.DECODE_LENGTH
LENGTH = common_types.LENGTH
+ Q_LENGTH = common_types.Q_LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
EMBED = common_types.EMBED
@@ -290,7 +296,12 @@ class AttentionOp(nn.Module):
  float32_qk_product: bool = False
  max_prefill_predict_length: int = -1
  float32_logits: bool = False
-   flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+   flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV)
+   flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+   prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
+   cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
+   cache_scale_logical_axis_names: AxisNames = (CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV)
  ragged_qkv_axis_names: AxisNames = (CACHE_BATCH, CACHE_HEADS, CACHE_SEQUENCE, CACHE_KV)
  ragged_lengths_names: AxisNames = (CACHE_BATCH,)
  compute_axis_order: AxisIdxes = (0, 1, 2, 3)
@@ -557,15 +568,25 @@ def tpu_flash_attention(
      attn_logits_soft_cap: float | None = None,
  ) -> Array:
    """TPU Flash Attention."""
+
+     cp_size = self.mesh.shape["context"]
+     load_balanced_context_parallel = self.config.context_parallel_load_balance
+
    # Transpose to ('batch', 'heads', 'length', 'kv')
    query = jnp.transpose(query, axes=(0, 2, 1, 3))
    key = jnp.transpose(key, axes=(0, 2, 1, 3))
    value = jnp.transpose(value, axes=(0, 2, 1, 3))
-
+     segment_axis_names_q = None
+     segment_axis_names_kv = None
    if decoder_segment_ids is not None:
-       decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids)
-     axis_names = nn.logical_to_mesh_axes(self.flash_axis_names)
-     segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads"))
+       segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, Q_LENGTH))
+       segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
+     axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
+     axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q)
+     axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv)
+     max_logging.log(f"axis_names_q: {axis_names_q}")
+     max_logging.log(f"axis_names_kv: {axis_names_kv}")
+     max_logging.log(f"axis_names_splash_kernel: {axis_names_splash_kernel}")

    global_block_q = self.config.sa_block_q
    global_block_kv = self.config.sa_block_kv
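For readers unfamiliar with flax's logical axis machinery used in this hunk, here is a minimal, standalone sketch of what `nn.logical_to_mesh_axes` does. The rule values below are invented for illustration (MaxText resolves its real `logical_axis_rules` from the config), and the printed specs are only indicative.

```python
from flax import linen as nn

# Hypothetical rule set, for illustration only.
rules = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_q_length", "context"),  # query sequence axis: sharded over the context axis
    ("activation_kv_length", None),      # key/value sequence axis: replicated
    ("activation_heads", "tensor"),
    ("activation_kv", None),
)

# Resolve logical axis names to a PartitionSpec over physical mesh axes.
q_spec = nn.logical_to_mesh_axes(("activation_batch", "activation_heads", "activation_q_length", "activation_kv"), rules)
kv_spec = nn.logical_to_mesh_axes(("activation_batch", "activation_heads", "activation_kv_length", "activation_kv"), rules)
print(q_spec)   # roughly PartitionSpec(('data', 'fsdp'), 'tensor', 'context', None)
print(kv_spec)  # the kv sequence dimension resolves to None, i.e. unsharded
```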
@@ -580,40 +601,46 @@ def tpu_flash_attention(
    global_k_layout = self.config.sa_k_layout
    global_v_layout = self.config.sa_v_layout

-     @functools.partial(
-         shard_map,
-         mesh=self.mesh,
-         in_specs=(
-             axis_names,
-             axis_names,
-             axis_names,
-             segment_axis_names,
-         ),
-         out_specs=axis_names,
-         check_rep=False,
+     devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
+     assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
+         "Batch dimension should be shardable among the devices in data and fsdp"
+         " axis"
+         f" got {query.shape[0]=}/{devices_in_data_fsdp=}"
    )
-     def wrap_flash_attention(query, key, value, decoder_segment_ids):
-       if decoder_segment_ids is not None:
-         assert (
-             query.shape[2] == decoder_segment_ids.q.shape[1]
-         ), "Sharding along sequence dimension not allowed in tpu kernel attention"
-       block_sizes = splash_attention_kernel.BlockSizes(
-           block_q=min(global_block_q, query.shape[2]),
-           block_kv=min(global_block_kv, key.shape[2]),
-           block_kv_compute=min(global_block_kv_compute, key.shape[2]),
-           block_q_dkv=min(global_block_q_dkv, query.shape[2]),
-           block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
-           block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
-           block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
-           block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
-           use_fused_bwd_kernel=global_use_fused_bwd_kernel,
-           q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
-           k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
-           v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
-       )

-       mask = splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2]))
+     # Create the splash attention kernel inputs: block sizes and mask
+     block_sizes = splash_attention_kernel.BlockSizes(
+         block_q=min(global_block_q, query.shape[2]),
+         block_kv=min(global_block_kv, key.shape[2]),
+         block_kv_compute=min(global_block_kv_compute, key.shape[2]),
+         block_q_dkv=min(global_block_q_dkv, query.shape[2]),
+         block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
+         block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
+         block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
+         block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
+         use_fused_bwd_kernel=global_use_fused_bwd_kernel,
+         q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
+         k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
+         v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
+     )
+
+     mask_shape = (self.config.max_target_length, self.config.max_target_length)
+     mask = splash_attention_mask.CausalMask(shape=mask_shape)
+
+     # Use LoadBalancedCausalMask when context parallelism with load balancing is enabled
+     if cp_size > 1 and load_balanced_context_parallel:
+       mask = LoadBalancedCausalMask(shape=mask_shape, cp_size=cp_size)

+     # TODO: figure out local_sliding attention + load_balancing; the default is global
+     # Apply local masking if local sliding attention is enabled.
+     if self.attention_type == AttentionType.LOCAL_SLIDING:
+       if self.sliding_window_size is None:
+         raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
+       mask &= splash_attention_mask.LocalMask(
+           shape=(query.shape[2], key.shape[2]),
+           window_size=(self.sliding_window_size, self.sliding_window_size),
+           offset=0,
+       )
    # Apply local masking if local sliding attention is enabled.
    if self.attention_type == AttentionType.LOCAL_SLIDING:
      if self.sliding_window_size is None:
@@ -629,26 +656,100 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):

    mask &= ChunkedCausalMask(shape=(query.shape[2], key.shape[2]), chunk_size=self.chunk_attn_window_size)

-       # Create multi-head mask
-       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+     # Create multi-head mask
+     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+     # Create the splash attention kernel object separately, jit it for performance
+     @partial(
+         jax.jit,
+         static_argnames=[
+             "multi_head_mask",
+             "shard_head_size",
+         ],
+     )
+     def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
      splash_kernel = splash_attention_kernel.make_splash_mha(
          mask=multi_head_mask,
-           head_shards=1,
-           q_seq_shards=1,
+           head_shards=shard_head_size,  # the size of the axis if sharding over heads
+           q_seq_shards=cp_size,  # axis for sequence sharding
          block_sizes=block_sizes,
          attn_logits_soft_cap=attn_logits_soft_cap,
      )
+       return splash_kernel

-       return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
+     logical_axis_rules_head = np.array(
+         [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
+     )
+     shard_head_size = np.prod(logical_axis_rules_head)
+     splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+     named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
+     segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+
+     # Now call wrap_flash_attention, which does the actual computation. The splash kernel is
+     # passed as a parameter to the function. Since shard_map decorates wrap_flash_attention,
+     # the data is sharded correctly: q is sharded over the sequence (context) axis, while K and V
+     # are replicated. The shardings are specified in the in_specs and out_specs of the shard_map decorator:
+     # 'segment_axis_names_q' maps to ['activation_q_length', ['context']], meaning q is sharded over the context axis
+     # 'segment_axis_names_kv' maps to ['activation_kv_length', []], meaning K and V are not sharded
+     # splash_kernel is sharded over (HEAD, LENGTH)
+     @functools.partial(
+         shard_map,
+         mesh=self.mesh,
+         in_specs=(
+             axis_names_q,
+             axis_names_kv,
+             axis_names_kv,
+             segment_axis_names_q,
+             segment_axis_names_kv,
+             segment_axis_names_splash_kernel,
+             None,  # no sharding for cp_size
+             None,  # no sharding for load_balanced_context_parallel
+         ),
+         out_specs=axis_names_q,
+         check_rep=False,
+     )
+     def wrap_flash_attention(
+         query,
+         key,
+         value,
+         decoder_segment_ids_q,
+         decoder_segment_ids_kv,
+         splash_kernel,
+         cp_size,
+         load_balanced_context_parallel,
+     ):
+       # If load_balanced_context_parallel is enabled, reorder the key and value tensors back
+       # into their original contiguous sequence order. This is necessary because the splash
+       # attention kernel expects K and V to be contiguous along the sequence. Note that K and V
+       # are not sharded over the sequence (context) axis, so this yields the unsharded,
+       # unpermuted key and value tensors.
+       if cp_size > 1 and load_balanced_context_parallel:
+         key = max_utils.reorder_sequence(tensor=key, cp_size=cp_size, seq_dim=2, to_contiguous=True)
+         value = max_utils.reorder_sequence(tensor=value, cp_size=cp_size, seq_dim=2, to_contiguous=True)
+         decoder_segment_ids_unpermuted = max_utils.reorder_sequence(
+             tensor=decoder_segment_ids_kv, cp_size=cp_size, seq_dim=1, to_contiguous=True
+         )

-     devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
-     assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
-         "Batch dimension should be shardable among the devices in data and fsdp"
-         " axis"
-         f" got {query.shape[0]=}/{devices_in_data_fsdp=}"
+       if decoder_segment_ids_q is not None:
+         if cp_size > 1 and load_balanced_context_parallel:
+           decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(
+               decoder_segment_ids_q, decoder_segment_ids_unpermuted
+           )
+         else:
+           decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_q)
+       else:
+         decoder_segment_ids_tuple = None
+       attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple)
+
+       return attention_output
+
+     x = wrap_flash_attention(
+         query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel, cp_size, load_balanced_context_parallel
    )
-     x = wrap_flash_attention(query, key, value, decoder_segment_ids)
+
    x = jnp.transpose(x, axes=(0, 2, 1, 3))
+
    return x

  def cudnn_flash_attention(
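The new `wrap_flash_attention` body un-permutes K, V, and the kv segment ids with `max_utils.reorder_sequence(..., to_contiguous=True)`. As a rough mental model, here is a self-contained sketch of the usual zig-zag load-balancing reorder and its inverse; the chunk-pairing scheme shown is an assumption about what the MaxText helper does, not a copy of it.

```python
import numpy as np

def reorder_sequence_sketch(x: np.ndarray, cp_size: int, seq_dim: int = 1, to_contiguous: bool = False) -> np.ndarray:
  """Zig-zag reorder: split into 2*cp_size chunks and give shard r chunks (r, 2*cp_size-1-r)."""
  chunks = np.split(x, 2 * cp_size, axis=seq_dim)
  fwd = [c for r in range(cp_size) for c in (r, 2 * cp_size - 1 - r)]
  # to_contiguous=True undoes the permutation, recovering the original contiguous order.
  order = np.argsort(fwd) if to_contiguous else fwd
  return np.concatenate([chunks[i] for i in order], axis=seq_dim)

tokens = np.arange(8)[None, :]                                    # batch=1, seq=8
balanced = reorder_sequence_sketch(tokens, cp_size=2)             # [[0 1 6 7 2 3 4 5]]
restored = reorder_sequence_sketch(balanced, cp_size=2, to_contiguous=True)
assert (restored == tokens).all()
```

Each context-parallel shard then holds one early and one late chunk of the sequence, which is what makes the causal-attention work per shard roughly equal.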
@@ -1673,3 +1774,57 @@ def __call__(
    out = nn.with_logical_constraint(out, self.out_axis_names)
    out = self.out_projection(inputs_q.shape[-1], out)
    return out
+
+
+ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
+   """Lazy causal mask, reordered for load-balanced context parallelism.
+
+   Prevents the model from attending to future tokens; the query rows are
+   permuted so that each context-parallel shard receives a roughly equal
+   amount of unmasked work.
+
+   Attributes:
+     offset: Offset of q start wrt kv. A positive offset shifts the bottom
+       triangle upward, a negative one shifts it downward. A negative offset
+       makes the first 'offset' rows of the attention matrix all 0s which leads
+       to undefined softmax.
+   """
+
+   offset: int
+   shape: tuple[int, int]
+   cp_size: int
+
+   def __init__(self, shape: tuple[int, int], offset: int = 0, shard_count: int = 1, cp_size: int = 4):
+     self.offset = offset
+
+     def causal_mask_function(q_ids, kv_ids):
+       if self.offset == 0:
+         return q_ids >= kv_ids
+       else:
+         return q_ids + self.offset >= kv_ids
+
+     arr = np.arange(shape[0])
+     # Reorder the mask to be load balanced, following the same approach used
+     # to reorder the input tokens.
+     out = max_utils.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, seq_dim=1)
+     q_sequence = out[0, :, 0, 0]
+
+     mask_function = causal_mask_function
+
+     super().__init__(
+         shape=shape,
+         mask_function=mask_function,
+         shard_count=shard_count,
+     )
+     self.q_sequence = q_sequence
+
+   def __eq__(self, other: object):
+     if not isinstance(other, type(self)):
+       return NotImplemented
+
+     return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence)
+
+   def __hash__(self):
+     return hash(
+         (
+             type(self),
+             self.shape,
+             self.offset,
+             self.q_sequence.tobytes() if self.q_sequence is not None else None,
+         )
+     )
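To see why the permuted `q_sequence` balances the causal workload, here is a small self-contained sketch. It assumes `reorder_mask_load_balancing` applies the same zig-zag chunk permutation as the token reordering; that pairing scheme is an assumption for illustration, not taken from the source.

```python
import numpy as np

def zigzag_order(seq_len: int, cp_size: int) -> np.ndarray:
  """Permuted query positions: shard r owns chunks (r, 2*cp_size-1-r)."""
  chunks = np.split(np.arange(seq_len), 2 * cp_size)
  order = [c for r in range(cp_size) for c in (r, 2 * cp_size - 1 - r)]
  return np.concatenate([chunks[i] for i in order])

cp_size = 2
q_ids = zigzag_order(8, cp_size)           # [0 1 6 7 2 3 4 5]
kv_ids = np.arange(8)                      # keys/values stay in contiguous order
mask = q_ids[:, None] >= kv_ids[None, :]   # causal_mask_function with offset == 0
work_per_row = mask.sum(axis=1)            # [1 2 7 8 3 4 5 6]
# Split rows evenly across cp shards: each shard gets the same total work.
print(work_per_row.reshape(cp_size, -1).sum(axis=1))  # [18 18]
```

With an unpermuted causal mask, the same even split of rows would give [10, 26], which is the imbalance the reordering avoids.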