
Commit 4a96541

Add ModernBERT (#598)
1 parent d4c8d8f commit 4a96541

4 files changed: +330 -167 lines

ch06/02_bonus_additional-experiments/README.md

+1 -1
@@ -74,4 +74,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
9. **Padding vs no padding (Row 1 vs. 14 & 15, and 16)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy. In row 16, padding is applied, but the token position is selected based on the last non-padding token. Row 16 should be mathematically similar to row 15, which uses gradient accumulation. However, due to some challenges with gradient accumulation in cases of unequal token counts, there may be small discrepancies (this is discussed in [this](https://unsloth.ai/blog/gradient) blog post).
10. **Disabling the causal attention mask (Row 1 vs. 17)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend to all other tokens. The model accuracy is slightly improved compared to the GPT model with the causal mask.
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 18)**: Setting `--ignore_index 50256` excludes the `<|endoftext|>` padding tokens from the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7. (A short code sketch of this setting and the next one follows after this diff.)
-13. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
+12. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
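
Below is that sketch: a minimal, self-contained PyTorch illustration (not taken from the chapter's training script; the tensor shapes and the toy classifier are assumptions) of selecting the last-token embedding versus mean-pooling with `--average_embeddings`, and of excluding padding labels from `cross_entropy` via `ignore_index`:

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration: 2 sequences, 5 tokens, 8-dimensional hidden states
hidden_states = torch.randn(2, 5, 8)       # outputs of the (GPT) backbone
classifier = torch.nn.Linear(8, 2)         # binary classification head

# Default: feed only the embedding at the chosen token position (e.g., the last token)
logits_last = classifier(hidden_states[:, -1, :])       # shape: (2, 2)

# --average_embeddings: mean-pool the embeddings of all tokens instead
logits_pooled = classifier(hidden_states.mean(dim=1))   # shape: (2, 2)

# --ignore_index 50256: targets equal to the <|endoftext|> ID (50256) would be
# excluded from the loss; with binary labels 0/1 this has no effect, as noted above
targets = torch.tensor([1, 0])
loss = F.cross_entropy(logits_pooled, targets, ignore_index=50256)
```

Mean-pooling lets every token contribute to the classification signal, which is consistent with the small accuracy gain reported for row 19 above.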

ch06/03_bonus_imdb-classification/README.md

+102 -2
@@ -1,5 +1,28 @@
# Additional Experiments Classifying the Sentiment of 50k IMDB Movie Reviews

+## Overview
+
+This folder contains additional experiments to compare the (decoder-style) GPT-2 (2018) model from chapter 6 to encoder-style LLMs like [BERT (2018)](https://arxiv.org/abs/1810.04805), [RoBERTa (2019)](https://arxiv.org/abs/1907.11692), and [ModernBERT (2024)](https://arxiv.org/abs/2412.13663). Instead of using the small SPAM dataset from Chapter 6, we are using the 50k movie review dataset from IMDb ([dataset source](https://ai.stanford.edu/~amaas/data/sentiment/)) with a binary classification objective, predicting whether a reviewer liked the movie or not. This is a balanced dataset, so a random prediction should yield 50% accuracy.
+
+| | Model | Test accuracy |
+| ----- | ---------------------------- | ------------- |
+| **1** | 124 M GPT-2 Baseline | 91.88% |
+| **2** | 340 M BERT | 90.89% |
+| **3** | 66 M DistilBERT | 91.40% |
+| **4** | 355 M RoBERTa | 92.95% |
+| **5** | 149 M ModernBERT Base | 93.79% |
+| **6** | 395 M ModernBERT Large | 95.07% |
+| **7** | Logistic Regression Baseline | 88.85% |
+

## Step 1: Install Dependencies

@@ -24,7 +47,10 @@ python download_prepare_dataset.py

## Step 3: Run Models

-The 124M GPT-2 model used in the main chapter, starting with pretrained weights, and finetuning all weights:
+### 1) 124 M GPT-2 Baseline
+
+The 124M GPT-2 model used in chapter 6, starting with pretrained weights, and finetuning all weights:

```bash
python train_gpt.py --trainable_layers "all" --num_epochs 1
@@ -53,6 +79,10 @@ Test accuracy: 91.88%

<br>

+&nbsp;
+### 2) 340 M BERT
+
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:

```bash
@@ -81,6 +111,9 @@ Test accuracy: 90.89%

<br>

+&nbsp;
+### 3) 66 M DistilBERT
+
A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) model (distilled down from a 340M parameter BERT model), starting from the pretrained weights and only training the last transformer block plus output layers:

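
Several of the runs in this README train only the last transformer block plus the output layers. As a rough, hedged sketch of what that freezing pattern can look like with the Hugging Face `transformers` API (the checkpoint name `distilbert-base-uncased` and the attribute paths are illustrative assumptions; the repository's training scripts may implement this differently):

```python
# Sketch: freeze everything except the last transformer block and the output layers.
# Assumes the standard DistilBERT checkpoint; not the repository's actual script.
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

# 1) Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# 2) Unfreeze the last transformer block
for param in model.distilbert.transformer.layer[-1].parameters():
    param.requires_grad = True

# 3) Unfreeze the output (classification) layers
for param in model.pre_classifier.parameters():
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True
```

The same pattern applies to the BERT and RoBERTa runs; only the attribute names of the encoder layers differ (e.g., `model.bert.encoder.layer[-1]` or `model.roberta.encoder.layer[-1]`).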

@@ -110,6 +143,9 @@ Test accuracy: 91.40%

<br>

+&nbsp;
+### 4) 355 M RoBERTa
+
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers:

@@ -133,14 +169,78 @@ Validation accuracy: 93.02%
Test accuracy: 92.95%
```

+<br>
+
+---
+
+<br>
+
+&nbsp;
+### 5) 149 M ModernBERT Base
+
+[ModernBERT (2024)](https://arxiv.org/abs/2412.13663) is an optimized reimplementation of BERT that incorporates architectural improvements like parallel residual connections and gated linear units (GLUs) to boost efficiency and performance. It maintains BERT’s original pretraining objectives while achieving faster inference and better scalability on modern hardware.
+
+```
+Ep 1 (Step 000000): Train loss 0.699, Val loss 0.698
+Ep 1 (Step 000050): Train loss 0.564, Val loss 0.606
+...
+Ep 1 (Step 004300): Train loss 0.086, Val loss 0.168
+Ep 1 (Step 004350): Train loss 0.160, Val loss 0.131
+Training accuracy: 95.62% | Validation accuracy: 93.75%
+Training completed in 10.27 minutes.
+
+Evaluating on the full datasets ...
+
+Training accuracy: 95.72%
+Validation accuracy: 94.00%
+Test accuracy: 93.79%
+```

<br>

---

<br>

-A scikit-learn logistic regression classifier as a baseline:
+
+&nbsp;
+### 6) 395 M ModernBERT Large
+
+Same as above but using the larger ModernBERT variant.
+
+```
+Ep 1 (Step 000000): Train loss 0.666, Val loss 0.662
+Ep 1 (Step 000050): Train loss 0.548, Val loss 0.556
+...
+Ep 1 (Step 004300): Train loss 0.083, Val loss 0.115
+Ep 1 (Step 004350): Train loss 0.154, Val loss 0.116
+Training accuracy: 96.88% | Validation accuracy: 95.62%
+Training completed in 27.69 minutes.
+
+Evaluating on the full datasets ...
+
+Training accuracy: 97.04%
+Validation accuracy: 95.30%
+Test accuracy: 95.07%
+```
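
The two ModernBERT hunks above only show the training logs. As a hedged sketch of how ModernBERT can be loaded for this binary sentiment task with the Hugging Face `transformers` library (a recent release is required for ModernBERT support; the checkpoint IDs below are the publicly released `answerdotai/ModernBERT-base` and `answerdotai/ModernBERT-large` models, and the commit's actual training script may differ):

```python
# Illustrative sketch only; the optimizer setup, batching, and evaluation loop
# behind the numbers above are omitted here.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "answerdotai/ModernBERT-base"  # "answerdotai/ModernBERT-large" for the larger run
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

# Tokenize a toy review and compute the two class logits
inputs = tokenizer("A surprisingly touching movie.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = logits.argmax(dim=-1).item()
```

Swapping in the large checkpoint corresponds to row 6 in the results table above.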
+
+<br>
+
+---
+
+<br>
+
+&nbsp;
+### 7) Logistic Regression Baseline
+
+A scikit-learn [logistic regression](https://sebastianraschka.com/blog/2022/losses-learned-part1.html) classifier as a baseline:


```bash
