
Commit b809a7b

Modernize MosaicBERT
1 parent 7003793 commit b809a7b

File tree: 6 files changed, +91 -44 lines changed

examples/benchmarks/bert/main.py

Lines changed: 3 additions & 1 deletion
@@ -246,7 +246,9 @@ def main(cfg: DictConfig,
         load_path=cfg.get('load_path', None),
         load_weights_only=cfg.get('load_weights_only', False),
         python_log_level=cfg.get('python_log_level', None),
-    )
+        autoresume=cfg.get('autoresume', None),
+        fsdp_config=cfg.get('fsdp_config', None),
+        compile_config=cfg.get('compile_config', None))
 
     print('Logging config...')
     log_config(cfg)
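
The three new Trainer arguments are all read with cfg.get(..., None), so they are optional keys in the training config. A minimal sketch of how such keys resolve through omegaconf; the config values below are illustrative only and not part of this commit:

    from omegaconf import OmegaConf

    # Hypothetical config: absent keys fall back to the default passed to
    # cfg.get(...), which is None for all three new Trainer arguments.
    cfg = OmegaConf.create({
        'autoresume': True,
        'fsdp_config': {'sharding_strategy': 'FULL_SHARD'},
        # 'compile_config' deliberately omitted
    })

    print(cfg.get('autoresume', None))      # True
    print(cfg.get('fsdp_config', None))     # DictConfig with the FSDP settings
    print(cfg.get('compile_config', None))  # None (key not present in the config)
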
Lines changed: 5 additions & 5 deletions
@@ -1,6 +1,6 @@
 einops==0.5.0
-torch==1.13.1
-mosaicml[nlp,wandb]>=0.14.0,<0.15
-mosaicml-streaming==0.4.1
-omegaconf==2.2.3
-transformers==4.28.1
+torch==2.1.1
+composer[nlp,wandb]>=0.17.0,<0.18
+mosaicml-streaming<=0.7
+omegaconf==2.3.0
+transformers==4.35.2
Lines changed: 8 additions & 6 deletions
@@ -1,8 +1,10 @@
 einops==0.5.0
-torch==1.13.1
-mosaicml[nlp,wandb]>=0.14.0,<0.15
-mosaicml-streaming==0.4.1
-omegaconf==2.2.3
-transformers==4.28.1
+torch==2.1.1
+composer[nlp,wandb]>=0.17.0,<0.18
+mosaicml-streaming<=0.7
+omegaconf==2.3.0
+transformers==4.35.2
+# need a newer version of FA2
+flash_attn>=2.4.2
 # need a newer version of triton
-triton==2.0.0.dev20221103
+#triton==2.0.0.dev20221103

examples/benchmarks/bert/src/bert_layers.py

Lines changed: 74 additions & 22 deletions
@@ -54,11 +54,26 @@
                                            SequenceClassifierOutput)
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
+IMPL_USE_FLASH2 = False
 try:
-    import flash_attn_triton as flash_attn_triton
-    flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
+    import importlib
+
+    from flash_attn import flash_attn_qkvpacked_func
+    installed_version = importlib.metadata.version('flash_attn')
+    if installed_version < '2.4.2':
+        raise ImportError('newer version of flash_attn required (>= 2.4.2)')
+    IMPL_USE_FLASH2 = True
 except ImportError as e:
-    flash_attn_qkvpacked_func = None
+    warnings.warn(
+        f'Failed to import flash_attn. Will try to import triton implementation: {e}',
+        stacklevel=2)
+    try:
+        import flash_attn_triton as flash_attn_triton
+        flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
+    except ImportError as e:
+        flash_attn_qkvpacked_func = None
+        warnings.warn(f'Failed to import flash_attn_triton as a fallback: {e}',
+                      stacklevel=2)
 
 logger = logging.getLogger(__name__)
 
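
The new import block keys FLASH2 usage off importlib.metadata.version('flash_attn') compared as a plain string. A hedged alternative sketch (not part of the commit; it assumes the packaging library is available) that imports the metadata submodule explicitly and compares versions numerically, so that for example '2.10.0' does not sort before '2.4.2':

    import importlib.metadata

    from packaging import version


    def flash_attn_is_new_enough(min_version: str = '2.4.2') -> bool:
        """Return True iff flash_attn is installed and at least min_version."""
        try:
            installed = importlib.metadata.version('flash_attn')
        except importlib.metadata.PackageNotFoundError:
            return False
        return version.parse(installed) >= version.parse(min_version)
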

@@ -183,7 +198,8 @@ def __init__(self, config):
 
     def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 max_seqlen_in_batch: int, indices: torch.Tensor,
-                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+                attn_mask: torch.Tensor, bias: torch.Tensor,
+                slopes: torch.Tensor) -> torch.Tensor:
         """Perform self-attention.
 
         If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
@@ -201,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             indices: (total_nnz,)
             attn_mask: (batch, max_seqlen_in_batch)
             bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: (heads) or (batch, heads)
 
         Returns:
             attention: (total_nnz, dim)
@@ -213,7 +230,8 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                         'b s (t h d) -> b s t h d',
                         t=3,
                         h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        if (not IMPL_USE_FLASH2 and
+                self.p_dropout) or flash_attn_qkvpacked_func is None:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
             k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
@@ -226,19 +244,41 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                  3)  # b s h d
         else:
-            # Triton implementation only supports 0 attention dropout
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
-            if convert_dtype:
-                # Triton implementation only supports fp16 and bf16
-                orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
-                attention = flash_attn_qkvpacked_func(qkv, bias)
-                attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
+            if IMPL_USE_FLASH2:
+                assert 1 <= len(slopes.shape) <= 2, f'{slopes=}'
+                assert slopes.shape[
+                    -1] == self.num_attention_heads, f'{slopes=}'
+
+                # Triton implementation only supports 0 attention dropout
+                convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+                if convert_dtype:
+                    # Triton implementation only supports fp16 and bf16
+                    orig_dtype = qkv.dtype
+                    qkv = qkv.to(torch.float16)
+                    bias_dtype = bias.dtype
+                    bias = bias.to(torch.float16)
+
+                    attention = flash_attn_qkvpacked_func(
+                        qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
+                    attention = attention.to(orig_dtype)
+                    bias = bias.to(bias_dtype)
+                else:
+                    attention = flash_attn_qkvpacked_func(
+                        qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
             else:
-                attention = flash_attn_qkvpacked_func(qkv, bias)
+                # Triton implementation only supports 0 attention dropout
+                convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+                if convert_dtype:
+                    # Triton implementation only supports fp16 and bf16
+                    orig_dtype = qkv.dtype
+                    qkv = qkv.to(torch.float16)
+                    bias_dtype = bias.dtype
+                    bias = bias.to(torch.float16)
+                    attention = flash_attn_qkvpacked_func(qkv, bias)
+                    attention = attention.to(orig_dtype)
+                    bias = bias.to(bias_dtype)
+                else:
+                    attention = flash_attn_qkvpacked_func(qkv, bias)
 
         # attn_mask is 1 for attend and 0 for don't
         attention = bert_padding_module.unpad_input_only(
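
For context, a self-contained sketch of the packed-QKV call made in the FLASH2 branch above. It is illustrative only, assumes a CUDA device with flash-attn >= 2.4.2 installed, and uses made-up sizes; shapes follow the docstring earlier in this file (qkv: (batch, seqlen, 3, heads, head_dim), slopes: (heads,) in fp32):

    import torch
    from flash_attn import flash_attn_qkvpacked_func

    batch, seqlen, n_heads, head_dim = 2, 128, 8, 64
    qkv = torch.randn(batch, seqlen, 3, n_heads, head_dim,
                      device='cuda', dtype=torch.bfloat16)
    # Standard ALiBi slopes for 8 heads: 2^-1, 2^-2, ..., 2^-8.
    slopes = torch.tensor([2.0**-(i + 1) for i in range(n_heads)],
                          device='cuda', dtype=torch.float32)
    out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, alibi_slopes=slopes)
    print(out.shape)  # torch.Size([2, 128, 8, 64])
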
@@ -291,6 +331,7 @@ def forward(
         indices: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for scaled self-attention without padding.
 
@@ -303,9 +344,11 @@ def forward(
             indices: None or (total_nnz,)
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: None or (batch, heads) or (heads,)
         """
+        assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
         self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
-                                attn_mask, bias)
+                                attn_mask, bias, slopes)
         if subset_idx is not None:
             return self.output(
                 bert_padding_module.index_first_axis(self_output, subset_idx),
@@ -379,6 +422,7 @@ def forward(
         indices: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for a BERT layer, including both attention and MLP.
 
@@ -391,9 +435,12 @@ def forward(
             indices: None or (total_nnz,)
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: None or (batch, heads) or (heads,)
         """
+        assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
         attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
-                                          subset_idx, indices, attn_mask, bias)
+                                          subset_idx, indices, attn_mask, bias,
+                                          slopes)
         layer_output = self.mlp(attention_output)
         return layer_output
 
@@ -463,6 +510,7 @@ def get_slopes_power_of_2(n_heads: int) -> List[float]:
         relative_position = relative_position.unsqueeze(0).expand(
             n_heads, -1, -1)
         slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
+        self.slopes = slopes
         alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
         # [1, n_heads, max_token_length, max_token_length]
         alibi = alibi.unsqueeze(0)
@@ -504,6 +552,7 @@ def forward(
         elif self.alibi.device != hidden_states.device:
             # Device catch-up
             self.alibi = self.alibi.to(hidden_states.device)
+            self.slopes = self.slopes.to(hidden_states.device)
         alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
         attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
         alibi_attn_mask = attn_bias + alibi_bias
@@ -517,7 +566,8 @@ def forward(
                                              None,
                                              indices,
                                              attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                                             bias=alibi_attn_mask,
+                                             slopes=self.slopes)
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -536,7 +586,8 @@ def forward(
                                              None,
                                              indices,
                                              attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                                             bias=alibi_attn_mask,
+                                             slopes=self.slopes)
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
@@ -547,7 +598,8 @@ def forward(
                                           subset_idx=subset_idx,
                                           indices=indices,
                                           attn_mask=attention_mask,
-                                          bias=alibi_attn_mask)
+                                          bias=alibi_attn_mask,
+                                          slopes=self.slopes)
 
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
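
The stored self.slopes feed FlashAttention directly, while the dense alibi bias consumed by the PyTorch and Triton paths is derived from the same slopes. A small standalone sketch mirroring the expressions in the hunks above (sizes are illustrative):

    import torch

    n_heads, max_len = 8, 16
    slopes = torch.tensor([2.0**-(i + 1) for i in range(n_heads)])  # (heads,)

    context_position = torch.arange(max_len)[:, None]
    memory_position = torch.arange(max_len)[None, :]
    relative_position = torch.abs(memory_position - context_position)  # (L, L)
    relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1)

    # Same form as in the diff: larger distances receive a more negative bias.
    alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position  # (heads, L, L)
    alibi = alibi.unsqueeze(0)  # (1, heads, L, L), later sliced to the batch seqlen
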

examples/benchmarks/bert/src/mosaic_bert.py

Lines changed: 1 addition & 2 deletions
@@ -119,8 +119,7 @@ def create_mosaic_bert_mlm(pretrained_model_name: str = 'bert-base-uncased',
         pretrained_model_name)
 
     metrics = [
-        LanguageCrossEntropy(ignore_index=-100,
-                             vocab_size=model.config.vocab_size),
+        LanguageCrossEntropy(ignore_index=-100),
         MaskedAccuracy(ignore_index=-100)
     ]
 
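
The metric construction drops the vocab_size argument, matching the LanguageCrossEntropy signature in the Composer version pinned above. A minimal sketch of the resulting metrics list (assuming composer>=0.17 is installed):

    from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy

    metrics = [
        LanguageCrossEntropy(ignore_index=-100),
        MaskedAccuracy(ignore_index=-100),
    ]
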

examples/benchmarks/bert/src/text_data.py

Lines changed: 0 additions & 8 deletions
@@ -69,9 +69,6 @@ class StreamingTextDataset(StreamingDataset):
         keep_zip (bool): Whether to keep or delete the compressed form when decompressing
             downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
             `False``.
-        keep_raw (bool): Whether to keep or delete the decompressed form (or only form)
-            of shards after all their samples have been yielded this epoch. If ``False``, keep iff
-            remote is local or no remote and no compression. Defaults to ``True``.
         samples_per_epoch (int, optional): Provide this field iff you are weighting sub-datasets
             proportionally. Defaults to ``None``.
         predownload (int, optional): Target number of samples ahead to download the shards of while
@@ -99,7 +96,6 @@ def __init__(self,
                  download_timeout: float = 60,
                  validate_hash: Optional[str] = None,
                  keep_zip: bool = False,
-                 keep_raw: bool = True,
                  samples_per_epoch: Optional[int] = None,
                  predownload: int = 100_000,
                  partition_algo: str = 'orig',
@@ -140,7 +136,6 @@ def __init__(self,
                         download_timeout=download_timeout,
                         validate_hash=validate_hash,
                         keep_zip=keep_zip,
-                        keep_raw=keep_raw,
                         samples_per_epoch=samples_per_epoch,
                         predownload=predownload,
                         partition_algo=partition_algo,
@@ -266,8 +261,6 @@ def build_text_dataloader(
                     cfg.dataset.get('validate_hash', None),
                     keep_zip=stream.get('keep_zip', None) or
                     cfg.dataset.get('keep_zip', False),
-                    keep_raw=stream.get('keep_raw', None) or
-                    cfg.dataset.get('keep_raw', True),
                 ))
 
     # build dataset potentially with streams
@@ -282,7 +275,6 @@ def build_text_dataloader(
         download_timeout=cfg.dataset.get('download_timeout', 60),
         validate_hash=cfg.dataset.get('validate_hash', None),
         keep_zip=cfg.dataset.get('keep_zip', False),
-        keep_raw=cfg.dataset.get('keep_raw', True),
         samples_per_epoch=cfg.dataset.get('samples_per_epoch', None),
        predownload=cfg.dataset.get('predownload', 100_000),
         partition_algo=cfg.dataset.get('partition_algo', 'orig'),
