
Make gradient_checkpointing a training argument #13657


Merged: 12 commits, Sep 22, 2021
Changes from 5 commits
4 changes: 2 additions & 2 deletions docs/source/model_doc/led.rst
@@ -46,8 +46,8 @@ Tips:
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
:class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384 tokens, it is necessary to enable *gradient checkpointing* by setting
``config.gradient_checkpointing = True``.
- To fine-tune LED on all 16384 tokens, it is necessary to enable *gradient checkpointing* by calling
``model.gradient_checkpointing_enable()``.
Member: How about enable_gradient_checkpointing?

- A notebook showing how to evaluate LED can be accessed `here
<https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
- A notebook showing how to fine-tune LED can be accessed `here
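For reference, a minimal sketch of the call the updated tip above recommends; the checkpoint name is illustrative, and any LED checkpoint should behave the same way.

```python
# Minimal sketch of the usage described in the updated LED tip above.
# The checkpoint name is illustrative; any LED checkpoint behaves the same way.
from transformers import LEDForConditionalGeneration

model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
model.gradient_checkpointing_enable()  # replaces the old `config.gradient_checkpointing = True`
model.train()  # checkpointing only takes effect while the model is in training mode
```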
5 changes: 0 additions & 5 deletions examples/pytorch/language-modeling/README.md
@@ -174,8 +174,3 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="
```

This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`.

This feature can also be used to activate gradient checkpointing by passing:
```
--config_overrides "gradient_checkpointing=true,use_cache=False"
```
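The lines above are removed because the switch now lives in the training arguments rather than in the model config. A sketch of the replacement path, assuming the new `gradient_checkpointing` field on `TrainingArguments` that this PR introduces (the output directory is illustrative):

```python
# Sketch of the replacement for the removed --config_overrides trick: gradient
# checkpointing becomes a training argument. `output_dir` is illustrative, and the
# Trainer is expected to call model.gradient_checkpointing_enable() when the flag is set.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="clm-output",
    gradient_checkpointing=True,
)
```

On the command line, this should surface as a `--gradient_checkpointing` flag for the example scripts.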
19 changes: 19 additions & 0 deletions src/transformers/configuration_utils.py
@@ -19,6 +19,7 @@
import copy
import json
import os
import warnings
from typing import Any, Dict, Tuple, Union

from . import __version__
@@ -38,6 +39,9 @@
logger = logging.get_logger(__name__)


NO_SAVE_CONFIG_KEYS = ["_gradient_checkpointing"]


class PretrainedConfig(PushToHubMixin):
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
@@ -330,6 +334,15 @@ def __init__(self, **kwargs):
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)

# Deal with gradient checkpointing
if "gradient_checkpointing" in kwargs:
self._gradient_checkpointing = kwargs.pop("gradient_checkpointing")
warnings.warn(
"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 of "
"Transformers. Use `model.gradient_checkpointing_enable()` instead, or, if you are using the "
"`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
)

# Additional attributes without default values
for key, value in kwargs.items():
try:
@@ -573,6 +586,11 @@ def get_config_dict(
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")

# Backward compatibility: deal with old model files that may have gradient_checkpointing in their config
# online
if "gradient_checkpointing" in config_dict:
config_dict["_gradient_checkpointing"] = config_dict.pop("gradient_checkpointing")

return config_dict, kwargs

@classmethod
@@ -682,6 +700,7 @@ def to_dict(self) -> Dict[str, Any]:
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output = {k: v for k, v in output.items() if k not in NO_SAVE_CONFIG_KEYS}
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type

16 changes: 16 additions & 0 deletions src/transformers/modeling_utils.py
@@ -450,6 +450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_keys_to_ignore_on_save = None

is_parallelizable = False
supports_gradient_checkpointing = False

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -932,6 +933,21 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

self.base_model._prune_heads(heads_to_prune)

def gradient_checkpointing_enable(self, flag: bool = True):
Member: Should there be a disable too?

Member: Ah, I didn't see this had a flag! Maybe toggle then? Or set_gradient_checkpointing to follow traditional boolean setter conventions?

Collaborator Author: @stas00 really wanted the method name to start with gradient_checkpointing to be more easily discoverable.

Collaborator Author: After some discussion with Lysandre, we decided to try gradient_checkpointing_enable and gradient_checkpointing_disable (no args for each).

"""
Activates or deactivates gradient checkpointing for the current model.

Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".

Args:
flag (:obj:`bool`, `optional`, defaults to :obj:`True`):
Will activate gradient checkpointing if :obj:`True`, deactivate it if :obj:`False`.
"""
if not self.supports_gradient_checkpointing and flag:
logger.warn(f"{self.__class__.__name__} does not support gradient checkpointing so nothing will happen.")
Contributor: Any reason not to assert here instead? The user can then change their setup and proceed without problems. It's a clear error to activate this option if a model doesn't support it, IMHO.

Collaborator Author: It's to be consistent with the previous behavior, where we did nothing if the user set gradient_checkpointing for a model that did not support it. I'm not opposed to asserting, but let's see what @LysandreJik and @patrickvonplaten think.

Contributor: Would also be in favor of raising an error here, actually. It's a new function, so I think we can add this behavior here.

Collaborator Author: Will switch then!

self.config._gradient_checkpointing = flag
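A short usage sketch against the method as it stands at this commit (it still takes a `flag` argument; the thread above later settles on a separate `gradient_checkpointing_disable`):

```python
# Usage sketch for the method as written at this commit: it still takes a `flag`
# argument, and unsupported models only log a warning (the review above discusses
# turning that into an error).
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
assert model.supports_gradient_checkpointing  # BertPreTrainedModel opts in later in this diff

model.gradient_checkpointing_enable()        # sets config._gradient_checkpointing = True
model.gradient_checkpointing_enable(False)   # switches it back off
```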

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
4 changes: 0 additions & 4 deletions src/transformers/models/bart/configuration_bart.py
@@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper <see
https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -131,7 +129,6 @@ def __init__(
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
@@ -161,7 +158,6 @@
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True

super().__init__(
8 changes: 4 additions & 4 deletions src/transformers/models/bart/modeling_bart.py
@@ -471,6 +471,7 @@ def forward(self, hidden_states: torch.Tensor):
class BartPretrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

def _init_weights(self, module):
@@ -782,7 +783,7 @@ def forward(
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -1020,12 +1021,11 @@ def forward(

past_key_value = past_key_values[idx] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
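The checkpointing branch is truncated in the hunks above; for reference, a self-contained sketch of the same `create_custom_forward` / `torch.utils.checkpoint` pattern, using a toy layer rather than the actual BART modules:

```python
# Self-contained sketch of the pattern used in the truncated branch above: wrap the
# layer in a closure and let torch.utils.checkpoint recompute its forward pass during
# backward. The toy layer and tensors stand in for the real BART modules and inputs.
import torch
import torch.utils.checkpoint

layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)
gradient_checkpointing = True  # stands in for `config._gradient_checkpointing`
training = True


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


if gradient_checkpointing and training:
    # Activations inside `layer` are not stored; they are recomputed during backward.
    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
else:
    hidden_states = layer(hidden_states)

hidden_states.sum().backward()
```

In the decoder hunk, the same branch also forces `use_cache=False`, since returning cached key/value states does not mix well with recomputing the forward pass.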

2 changes: 0 additions & 2 deletions src/transformers/models/beit/configuration_beit.py
@@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
The size (resolution) of each image.
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
3 changes: 2 additions & 1 deletion src/transformers/models/beit/modeling_beit.py
@@ -450,7 +450,7 @@ def forward(

layer_head_mask = head_mask[i] if head_mask is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -494,6 +494,7 @@ class BeitPreTrainedModel(PreTrainedModel):

config_class = BeitConfig
base_model_prefix = "beit"
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""
4 changes: 0 additions & 4 deletions src/transformers/models/bert/configuration_bert.py
@@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -137,7 +135,6 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
@@ -157,7 +154,6 @@
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
6 changes: 3 additions & 3 deletions src/transformers/models/bert/modeling_bert.py
@@ -555,12 +555,11 @@ def forward(
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

@@ -714,6 +713,7 @@ class BertPreTrainedModel(PreTrainedModel):
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
@@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -96,7 +94,6 @@ def __init__(
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
**kwargs
@@ -114,6 +111,5 @@
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
4 changes: 0 additions & 4 deletions src/transformers/models/big_bird/configuration_big_bird.py
@@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig):
num_random_blocks (:obj:`int`, `optional`, defaults to 3):
Each query is going to attend to this many random blocks. Useful only when :obj:`attention_type ==
"block_sparse"`.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
classifier_dropout (:obj:`float`, `optional`):
The dropout ratio for the classification head.

@@ -127,7 +125,6 @@ def __init__(
rescale_embeddings=False,
block_size=64,
num_random_blocks=3,
gradient_checkpointing=False,
classifier_dropout=None,
**kwargs
):
@@ -153,7 +150,6 @@
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.is_encoder_decoder = is_encoder_decoder
self.gradient_checkpointing = gradient_checkpointing

self.rescale_embeddings = rescale_embeddings
self.attention_type = attention_type
6 changes: 3 additions & 3 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -1598,12 +1598,11 @@ def forward(
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

@@ -1756,6 +1755,7 @@ class BigBirdPreTrainedModel(PreTrainedModel):
config_class = BigBirdConfig
load_tf_weights = load_tf_weights_in_big_bird
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
@@ -94,8 +94,6 @@ class BigBirdPegasusConfig(PretrainedConfig):
"block_sparse"`.
scale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`)
Whether to rescale embeddings with (hidden_size ** 0.5).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.

Example::

@@ -141,7 +139,6 @@ def __init__(
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=True,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
@@ -170,7 +167,6 @@
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True

# extra config
@@ -1567,6 +1567,7 @@ def forward(self, hidden_states: torch.Tensor):
class BigBirdPegasusPreTrainedModel(PreTrainedModel):
config_class = BigBirdPegasusConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True

def _init_weights(self, module):
std = self.config.init_std
@@ -1894,7 +1895,7 @@ def forward(
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -2225,12 +2226,11 @@ def forward(

past_key_value = past_key_values[idx] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if getattr(self.config, "_gradient_checkpointing", False) and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

@@ -78,8 +78,6 @@ class BlenderbotConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper <see
https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -128,7 +126,6 @@ def __init__(
decoder_start_token_id=1,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
@@ -155,7 +152,6 @@
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True

super().__init__(