-
Notifications
You must be signed in to change notification settings - Fork 29.5k
Make gradient_checkpointing a training argument #13657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
418f924
3438429
fc703a3
80debb4
0b0ff32
9cad3e0
dd842b3
5a89ff4
ab8c6ca
e286cea
bafe1e0
2ea2a52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
import warnings | ||
from contextlib import contextmanager | ||
from dataclasses import dataclass | ||
from functools import partial | ||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union | ||
|
||
import torch | ||
|
@@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix | |
_keys_to_ignore_on_save = None | ||
|
||
is_parallelizable = False | ||
supports_gradient_checkpointing = False | ||
|
||
@property | ||
def dummy_inputs(self) -> Dict[str, torch.Tensor]: | ||
|
@@ -469,6 +471,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): | |
# Save config and origin of the pretrained weights if given in model | ||
self.config = config | ||
self.name_or_path = config.name_or_path | ||
if getattr(self.config, "gradient_checkpointing", False): | ||
self.gradient_checkpointing_enable() | ||
# Remove the attribute now that is has been consumed, so it's no saved in the config. | ||
delattr(self.config, "gradient_checkpointing") | ||
|
||
@classmethod | ||
def _from_config(cls, config, **kwargs): | ||
|
@@ -932,6 +938,27 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): | |
|
||
self.base_model._prune_heads(heads_to_prune) | ||
|
||
def gradient_checkpointing_enable(self, flag: bool = True): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be a disable too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I didn't see this had a flag! Maybe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @stas00 really wanted the method name to start with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After some discussion, with Lysandre, we decided to try |
||
""" | ||
Activates gradient checkpointing for the current model. | ||
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint | ||
activations". | ||
""" | ||
if not self.supports_gradient_checkpointing: | ||
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") | ||
self.apply(partial(self._set_gradient_checkpointing, value=True)) | ||
|
||
def gradient_checkpointing_disable(self, flag: bool = True): | ||
""" | ||
Deactivates gradient checkpointing for the current model. | ||
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint | ||
activations". | ||
""" | ||
if self.supports_gradient_checkpointing: | ||
self.apply(partial(self._set_gradient_checkpointing, value=False)) | ||
|
||
def save_pretrained( | ||
self, | ||
save_directory: Union[str, os.PathLike], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
enable_gradient_checkpointing
?