
Commit 27d4639

sgugger and stas00 authored
Make gradient_checkpointing a training argument (#13657)
* Make gradient_checkpointing a training argument
* Update src/transformers/modeling_utils.py
  Co-authored-by: Stas Bekman <[email protected]>
* Update src/transformers/configuration_utils.py
  Co-authored-by: Stas Bekman <[email protected]>
* Fix tests
* Style
* document Gradient Checkpointing as a performance feature
* Small rename
* PoC for not using the config
* Adapt BC to new PoC
* Forgot to save
* Rollout changes to all other models
* Fix typo

Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
1 parent 75f6641 commit 27d4639

File tree

96 files changed, +531 −309 lines changed


docs/source/model_doc/led.rst
Lines changed: 2 additions & 2 deletions

@@ -46,8 +46,8 @@ Tips:
 - LED makes use of *global attention* by means of the ``global_attention_mask`` (see
   :class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
   ``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
-- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting
-  ``config.gradient_checkpointing = True``.
+- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
+  ``model.gradient_checkpointing_enable()``.
 - A notebook showing how to evaluate LED, can be accessed `here
   <https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
 - A notebook showing how to fine-tune LED, can be accessed `here
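
A hedged usage sketch of the updated tip. `allenai/led-base-16384` is only the familiar LED checkpoint name used for illustration; the part taken from this commit is the switch from `config.gradient_checkpointing` to the model-level method.

```python
from transformers import LEDForConditionalGeneration

# Load LED and turn on gradient checkpointing the new way (needed to
# fine-tune on the full 16384-token window without running out of memory).
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
model.gradient_checkpointing_enable()
```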

docs/source/performance.md
Lines changed: 16 additions & 0 deletions

@@ -53,6 +53,7 @@ Software:
 - Tensor Parallelism
 - Low-memory Optimizers
 - fp16/bf16 (smaller data)
+- Gradient checkpointing



@@ -226,6 +227,21 @@ pytorch `autocast` which performs AMP include a caching feature, which speed thi

 Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16list ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.)

+
+### Gradient Checkpointing
+
+One way to use significantly less GPU memory is to enable "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of a small decrease in training speed due to recomputing parts of the graph during back-propagation.
+
+This technique was first shared in the paper: [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper will also give you the exact details on the savings, but it's in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers.
+
+To activate this feature in 🤗 Transformers for models that support it, use:
+
+```python
+model.gradient_checkpointing_enable()
+```
+or add `--gradient_checkpointing` to the Trainer arguments.
+
+
 ### Batch sizes

 One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model.
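
Both routes the new doc section mentions, as a hedged sketch. The checkpoint name is a placeholder, and `gradient_checkpointing=True` on `TrainingArguments` is the training-argument form this commit introduces (the Trainer is then presumably the one calling `gradient_checkpointing_enable()` on the model).

```python
from transformers import AutoModelForSequenceClassification, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Route 1: flip the switch directly on the model.
model.gradient_checkpointing_enable()

# Route 2: let the Trainer handle it through the new training argument
# (the `--gradient_checkpointing` CLI flag maps to this field).
training_args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
```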

examples/pytorch/language-modeling/README.md
Lines changed: 0 additions & 5 deletions

@@ -174,8 +174,3 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="
 ```

 This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`.
-
-This feature can also be used to activate gradient checkpointing by passing:
-```
---config_overrides "gradient_checkpointing=true,use_cache=False"
-```
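
The removed tip is superseded by the Trainer-level flag: the example scripts build their `TrainingArguments` through `HfArgumentParser`, so once `gradient_checkpointing` is a training argument it is exposed as `--gradient_checkpointing` directly, with no `--config_overrides` needed. A minimal, hedged sketch of that parsing path, assuming the new field added by this commit:

```python
from transformers import HfArgumentParser, TrainingArguments

# Parse CLI-style arguments the same way run_clm.py / run_mlm.py do.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--gradient_checkpointing", "True"]
)
print(training_args.gradient_checkpointing)  # True
```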

src/transformers/configuration_utils.py
Lines changed: 9 additions & 0 deletions

@@ -19,6 +19,7 @@
 import copy
 import json
 import os
+import warnings
 from typing import Any, Dict, Tuple, Union

 from . import __version__
@@ -330,6 +331,14 @@ def __init__(self, **kwargs):
         # Drop the transformers version info
         self.transformers_version = kwargs.pop("transformers_version", None)

+        # Deal with gradient checkpointing
+        if "gradient_checkpointing" in kwargs:
+            warnings.warn(
+                "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
+                "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
+                "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
+            )
+
         # Additional attributes without default values
         for key, value in kwargs.items():
             try:
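
A hedged sketch of what this deprecation path looks like from user code, using `BertConfig` as the example. The flag is only warned about here, not popped, so it still lands on the config as a plain attribute and is picked up by the backward-compatibility handling in `modeling_utils.py` below.

```python
import warnings

from transformers import BertConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = BertConfig(gradient_checkpointing=True)  # old style, now deprecated

print(any("gradient_checkpointing" in str(w.message) for w in caught))  # True: warning emitted
print(config.gradient_checkpointing)  # True: still stored for backward compatibility
```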

src/transformers/modeling_utils.py
Lines changed: 27 additions & 0 deletions

@@ -20,6 +20,7 @@
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     _keys_to_ignore_on_save = None

     is_parallelizable = False
+    supports_gradient_checkpointing = False

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -469,6 +471,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
+        if getattr(self.config, "gradient_checkpointing", False):
+            self.gradient_checkpointing_enable()
+            # Remove the attribute now that it has been consumed, so it's not saved in the config.
+            delattr(self.config, "gradient_checkpointing")

     @classmethod
     def _from_config(cls, config, **kwargs):
@@ -932,6 +938,27 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

         self.base_model._prune_heads(heads_to_prune)

+    def gradient_checkpointing_enable(self):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+    def gradient_checkpointing_disable(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self.supports_gradient_checkpointing:
+            self.apply(partial(self._set_gradient_checkpointing, value=False))
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
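
A toy sketch (not part of the commit) of the opt-in contract these new methods rely on: a model declares `supports_gradient_checkpointing = True` and implements `_set_gradient_checkpointing`, which `gradient_checkpointing_enable()` applies to every submodule via `nn.Module.apply`. The `Toy*` names are made up; the per-model hooks added in the Bart/BERT/BEiT diffs below follow the same shape.

```python
import torch
import torch.utils.checkpoint
from torch import nn

from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy"


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.gradient_checkpointing = False  # flipped by _set_gradient_checkpointing

    def forward(self, hidden_states):
        if self.gradient_checkpointing and self.training:
            # Recompute this block's activations during backward instead of storing them.
            return torch.utils.checkpoint.checkpoint(self.dense, hidden_states)
        return self.dense(hidden_states)


class ToyModel(PreTrainedModel):
    config_class = ToyConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.encoder = ToyEncoder()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing = value

    def forward(self, hidden_states):
        return self.encoder(hidden_states)


model = ToyModel(ToyConfig())
model.gradient_checkpointing_enable()
assert model.encoder.gradient_checkpointing  # the hook reached the submodule
model.gradient_checkpointing_disable()
assert not model.encoder.gradient_checkpointing
```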

src/transformers/models/bart/configuration_bart.py
Lines changed: 0 additions & 4 deletions

@@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig):
         decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
             The LayerDrop probability for the decoder. See the `LayerDrop paper <see
             https://arxiv.org/abs/1909.11556>`__ for more details.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Scale embeddings by diving by sqrt(d_model).
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -131,7 +129,6 @@ def __init__(
         init_std=0.02,
         classifier_dropout=0.0,
         scale_embedding=False,
-        gradient_checkpointing=False,
         use_cache=True,
         num_labels=3,
         pad_token_id=1,
@@ -161,7 +158,6 @@ def __init__(
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

         super().__init__(

src/transformers/models/bart/modeling_bart.py
Lines changed: 10 additions & 4 deletions

@@ -471,6 +471,7 @@ def forward(self, hidden_states: torch.Tensor):
 class BartPretrainedModel(PreTrainedModel):
     config_class = BartConfig
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

     def _init_weights(self, module):
@@ -484,6 +485,10 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (BartDecoder, BartEncoder)):
+            module.gradient_checkpointing = value
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -687,6 +692,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
         self.layernorm_embedding = nn.LayerNorm(embed_dim)

         self.init_weights()
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -782,7 +788,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -849,6 +855,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
         self.layernorm_embedding = nn.LayerNorm(config.d_model)

         self.init_weights()
+        self.gradient_checkpointing = False

     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1020,12 +1027,11 @@

             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
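
A standalone sketch of the recomputation pattern that the new `self.gradient_checkpointing` flag guards in `BartEncoder`/`BartDecoder`. The layer here is a stand-in `nn.Linear`, not a real `BartEncoderLayer`; the point is that the layer call is wrapped and routed through `torch.utils.checkpoint.checkpoint`, so its activations are recomputed during backward rather than kept in memory.

```python
import torch
import torch.utils.checkpoint
from torch import nn

layer = nn.Linear(16, 16)  # stand-in for a transformer layer
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
gradient_checkpointing, training = True, True

if gradient_checkpointing and training:

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    # Activations inside this call are freed right after the forward pass...
    layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
else:
    layer_outputs = layer(hidden_states)

# ...and recomputed here, trading a bit of compute for memory.
layer_outputs.sum().backward()
```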

src/transformers/models/beit/configuration_beit.py
Lines changed: 0 additions & 2 deletions

@@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
             The size (resolution) of each image.
         patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):

src/transformers/models/beit/modeling_beit.py
Lines changed: 7 additions & 1 deletion

@@ -432,6 +432,7 @@ def __init__(self, config, window_size=None):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -450,7 +451,7 @@ def forward(

             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -494,6 +495,7 @@ class BeitPreTrainedModel(PreTrainedModel):

     config_class = BeitConfig
     base_model_prefix = "beit"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -511,6 +513,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BeitEncoder):
+            module.gradient_checkpointing = value
+

 BEIT_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
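
A hedged sketch showing that the same switch applies to the vision model: `BeitPreTrainedModel` now opts in, so a BEiT head model can toggle checkpointing like the text models. The tiny config values are hypothetical, chosen only to keep the example cheap to instantiate.

```python
import torch
from transformers import BeitConfig, BeitForImageClassification

config = BeitConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=64, image_size=32, patch_size=8, num_labels=3,
)
model = BeitForImageClassification(config)
model.gradient_checkpointing_enable()
model.train()

pixel_values = torch.randn(1, 3, 32, 32)
logits = model(pixel_values).logits  # forward runs through the checkpointed BeitEncoder
print(logits.shape)                  # torch.Size([1, 3])
```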

src/transformers/models/bert/configuration_bert.py
Lines changed: 0 additions & 4 deletions

@@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
             Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
             :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -137,7 +135,6 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
-        gradient_checkpointing=False,
         position_embedding_type="absolute",
         use_cache=True,
         classifier_dropout=None,
@@ -157,7 +154,6 @@ def __init__(
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout

src/transformers/models/bert/modeling_bert.py
Lines changed: 8 additions & 3 deletions

@@ -529,6 +529,7 @@ def __init__(self, config):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -555,12 +556,11 @@
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False

@@ -714,6 +714,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -732,6 +733,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+

 @dataclass
 class BertForPreTrainingOutput(ModelOutput):
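
A hedged sketch of the interaction the rewritten warning describes: in training mode with checkpointing enabled, a `use_cache=True` forward through `BertEncoder` logs the warning and proceeds as if `use_cache=False`, since cached key/value states do not mix with recomputation. The tiny config values are hypothetical, just to keep the model small.

```python
import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64, is_decoder=True,
)
model = BertLMHeadModel(config)
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, 100, (1, 8))
outputs = model(input_ids, use_cache=True)  # logs the warning, caching is turned off
print(bool(outputs.past_key_values))        # False: no key/value cache was kept
```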

src/transformers/models/bert_generation/configuration_bert_generation.py
Lines changed: 0 additions & 4 deletions

@@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
         position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
             Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
             :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -96,7 +94,6 @@ def __init__(
         pad_token_id=0,
         bos_token_id=2,
         eos_token_id=1,
-        gradient_checkpointing=False,
         position_embedding_type="absolute",
         use_cache=True,
         **kwargs
@@ -114,6 +111,5 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache

src/transformers/models/big_bird/configuration_big_bird.py
Lines changed: 0 additions & 4 deletions

@@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig):
         num_random_blocks (:obj:`int`, `optional`, defaults to 3)
             Each query is going to attend these many number of random blocks. Useful only when :obj:`attention_type ==
             "block_sparse"`.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         classifier_dropout (:obj:`float`, `optional`):
             The dropout ratio for the classification head.

@@ -127,7 +125,6 @@ def __init__(
         rescale_embeddings=False,
         block_size=64,
         num_random_blocks=3,
-        gradient_checkpointing=False,
         classifier_dropout=None,
         **kwargs
     ):
@@ -153,7 +150,6 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.use_cache = use_cache
         self.is_encoder_decoder = is_encoder_decoder
-        self.gradient_checkpointing = gradient_checkpointing

         self.rescale_embeddings = rescale_embeddings
         self.attention_type = attention_type