                                            SequenceClassifierOutput)
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
+IMPL_USE_FLASH2 = False
 try:
-    import flash_attn_triton as flash_attn_triton
-    flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
+    import importlib.metadata
+
+    from flash_attn import flash_attn_qkvpacked_func
+    installed_version = importlib.metadata.version('flash_attn')
+    # Compare parsed versions; a plain string comparison misorders e.g. '2.10.0' < '2.4.2'.
+    from packaging import version
+    if version.parse(installed_version) < version.parse('2.4.2'):
+        raise ImportError('newer version of flash_attn required (>= 2.4.2)')
+    IMPL_USE_FLASH2 = True
 except ImportError as e:
-    flash_attn_qkvpacked_func = None
+    warnings.warn(
+        f'Failed to import flash_attn. Will try the Triton implementation instead: {e}',
+        stacklevel=2)
+    try:
+        import flash_attn_triton as flash_attn_triton
+        flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
+    except ImportError as e:
+        flash_attn_qkvpacked_func = None
+        warnings.warn(f'Failed to import flash_attn_triton as a fallback: {e}',
+                      stacklevel=2)
 
 logger = logging.getLogger(__name__)
 
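
For context, a hedged sketch (not part of the patch): the guarded import above picks one of three attention backends, and the rest of the diff branches on that choice. The probe below is hypothetical, its function name is illustrative, and it assumes only the modules the patch itself touches.

```python
# Hypothetical probe showing the three-way backend selection encoded by the guarded import.
import warnings


def probe_attention_backend() -> str:
    try:
        from flash_attn import flash_attn_qkvpacked_func  # noqa: F401
        return 'flash_attn 2: fused kernel, ALiBi via slopes, dropout supported'
    except ImportError as e:
        warnings.warn(f'flash_attn unavailable: {e}', stacklevel=2)
    try:
        import flash_attn_triton  # noqa: F401
        return 'Triton fallback: explicit ALiBi bias tensor, zero attention dropout only'
    except ImportError:
        return 'pure PyTorch attention'


print(probe_attention_backend())
```
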
@@ -183,7 +198,8 @@ def __init__(self, config):
 
     def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 max_seqlen_in_batch: int, indices: torch.Tensor,
-                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+                attn_mask: torch.Tensor, bias: torch.Tensor,
+                slopes: torch.Tensor) -> torch.Tensor:
         """Perform self-attention.
 
         If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
@@ -201,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             indices: (total_nnz,)
             attn_mask: (batch, max_seqlen_in_batch)
             bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: (heads,) or (batch, heads)
 
         Returns:
             attention: (total_nnz, dim)
@@ -213,7 +230,8 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                          'b s (t h d) -> b s t h d',
                          t=3,
                          h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        if (not IMPL_USE_FLASH2 and
+                self.p_dropout) or flash_attn_qkvpacked_func is None:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
             k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
@@ -226,19 +244,41 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                  3)  # b s h d
         else:
-            # Triton implementation only supports 0 attention dropout
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
-            if convert_dtype:
-                # Triton implementation only supports fp16 and bf16
-                orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
-                attention = flash_attn_qkvpacked_func(qkv, bias)
-                attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
+            if IMPL_USE_FLASH2:
+                assert 1 <= len(slopes.shape) <= 2, f'{slopes=}'
+                assert slopes.shape[
+                    -1] == self.num_attention_heads, f'{slopes=}'
+
+                # Flash Attention 2 supports nonzero attention dropout
+                convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+                if convert_dtype:
+                    # Flash Attention 2 only supports fp16 and bf16
+                    orig_dtype = qkv.dtype
+                    qkv = qkv.to(torch.float16)
+                    bias_dtype = bias.dtype
+                    bias = bias.to(torch.float16)
+
+                    attention = flash_attn_qkvpacked_func(
+                        qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
+                    attention = attention.to(orig_dtype)
+                    bias = bias.to(bias_dtype)
+                else:
+                    attention = flash_attn_qkvpacked_func(
+                        qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
             else:
-                attention = flash_attn_qkvpacked_func(qkv, bias)
+                # Triton implementation only supports 0 attention dropout
+                convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+                if convert_dtype:
+                    # Triton implementation only supports fp16 and bf16
+                    orig_dtype = qkv.dtype
+                    qkv = qkv.to(torch.float16)
+                    bias_dtype = bias.dtype
+                    bias = bias.to(torch.float16)
+                    attention = flash_attn_qkvpacked_func(qkv, bias)
+                    attention = attention.to(orig_dtype)
+                    bias = bias.to(bias_dtype)
+                else:
+                    attention = flash_attn_qkvpacked_func(qkv, bias)
 
         # attn_mask is 1 for attend and 0 for don't
         attention = bert_padding_module.unpad_input_only(
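
For context, a hedged usage sketch (not part of the patch): the FLASH2 branch above hands the packed qkv plus per-head ALiBi slopes straight to `flash_attn_qkvpacked_func` instead of materializing the full bias tensor. The sizes below are illustrative only and assume a CUDA device with flash_attn >= 2.4.2 installed.

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

batch, seqlen, heads, head_dim = 2, 128, 12, 64
# Packed (q, k, v) exactly as produced by the rearrange above: (b, s, 3, h, d); fp16/bf16 on GPU.
qkv = torch.randn(batch, seqlen, 3, heads, head_dim,
                  dtype=torch.float16, device='cuda')
# One ALiBi slope per head in fp32; a (batch, heads) tensor is also accepted.
slopes = torch.rand(heads, dtype=torch.float32, device='cuda')

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.1, alibi_slopes=slopes)
print(out.shape)  # torch.Size([2, 128, 12, 64]); nonzero dropout is fine here, unlike the Triton kernel
```
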
@@ -291,6 +331,7 @@ def forward(
         indices: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for scaled self-attention without padding.
 
@@ -303,9 +344,11 @@ def forward(
             indices: None or (total_nnz,)
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: None or (batch, heads) or (heads,)
         """
+        assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
         self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
-                                attn_mask, bias)
+                                attn_mask, bias, slopes)
         if subset_idx is not None:
             return self.output(
                 bert_padding_module.index_first_axis(self_output, subset_idx),
@@ -379,6 +422,7 @@ def forward(
         indices: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for a BERT layer, including both attention and MLP.
 
@@ -391,9 +435,12 @@ def forward(
             indices: None or (total_nnz,)
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: None or (batch, heads) or (heads,)
         """
+        assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
         attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
-                                          subset_idx, indices, attn_mask, bias)
+                                          subset_idx, indices, attn_mask, bias,
+                                          slopes)
         layer_output = self.mlp(attention_output)
         return layer_output
 
@@ -463,6 +510,7 @@ def get_slopes_power_of_2(n_heads: int) -> List[float]:
         relative_position = relative_position.unsqueeze(0).expand(
             n_heads, -1, -1)
         slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
+        self.slopes = slopes
         alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
         # [1, n_heads, max_token_length, max_token_length]
         alibi = alibi.unsqueeze(0)
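
For context, a hedged standalone sketch (not part of the patch): the cached `self.slopes` are the standard ALiBi head slopes, and the `alibi` tensor built above is just `slopes * -|i - j|` broadcast per head. The helper name below is hypothetical; for a power-of-two head count the slopes reduce to the closed form shown.

```python
import torch


def alibi_slopes_power_of_2(n_heads: int) -> torch.Tensor:
    # Closed form for power-of-two head counts: slope_h = 2 ** (-8 * (h + 1) / n_heads)
    start = 2 ** (-8.0 / n_heads)
    return torch.tensor([start ** (h + 1) for h in range(n_heads)])


n_heads, seqlen = 8, 4
slopes = alibi_slopes_power_of_2(n_heads)                # (heads,), what gets cached as self.slopes
pos = torch.arange(seqlen)
relative_position = (pos[None, :] - pos[:, None]).abs()  # (seqlen, seqlen)
alibi = slopes[:, None, None] * -relative_position       # (heads, seqlen, seqlen), before the final unsqueeze(0)
print(slopes)       # 0.5, 0.25, ..., 0.0039 for 8 heads
print(alibi.shape)  # torch.Size([8, 4, 4])
```
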
@@ -504,6 +552,7 @@ def forward(
         elif self.alibi.device != hidden_states.device:
             # Device catch-up
             self.alibi = self.alibi.to(hidden_states.device)
+            self.slopes = self.slopes.to(hidden_states.device)
         alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
         attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
         alibi_attn_mask = attn_bias + alibi_bias
@@ -517,7 +566,8 @@ def forward(
                                              None,
                                              indices,
                                              attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                                             bias=alibi_attn_mask,
+                                             slopes=self.slopes)
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -536,7 +586,8 @@ def forward(
                                              None,
                                              indices,
                                              attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                                             bias=alibi_attn_mask,
+                                             slopes=self.slopes)
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
@@ -547,7 +598,8 @@ def forward(
                                            subset_idx=subset_idx,
                                            indices=indices,
                                            attn_mask=attention_mask,
-                                           bias=alibi_attn_mask)
+                                           bias=alibi_attn_mask,
+                                           slopes=self.slopes)
 
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)