Make gradient_checkpointing a training argument #13657
Changes from 1 commit: 418f924
```diff
@@ -450,6 +450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     _keys_to_ignore_on_save = None
 
     is_parallelizable = False
+    supports_gradient_checkpointing = False
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
```
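For orientation (this is not part of the PR), `supports_gradient_checkpointing` is a plain class attribute that a model sets to opt in. Below is a minimal, self-contained sketch of the underlying mechanism using a toy module and `torch.utils.checkpoint`, not the real `PreTrainedModel`:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    # A model advertises support by flipping this class attribute (as in the diff above).
    supports_gradient_checkpointing = True

    def __init__(self, hidden_size=64, num_layers=4, gradient_checkpointing=False):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: activations are recomputed during the backward pass.
                hidden_states = checkpoint(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


model = ToyEncoder(gradient_checkpointing=True).train()
out = model(torch.randn(2, 64, requires_grad=True))
out.sum().backward()
```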
```diff
@@ -932,6 +933,18 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
 
         self.base_model._prune_heads(heads_to_prune)
 
+    def gradient_checkpointing_enable(self, flag: bool = True):
```
Review thread on the new method definition:

- Should there be a disable too?
- Ah, I didn't see this had a flag! Maybe …
- @stas00 really wanted the method name to start with `gradient_checkpointing`.
- After some discussion with Lysandre, we decided to try …
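To make the alternative being weighed here concrete, the following is a hypothetical sketch (not code from the PR) of what a separate enable/disable pair could look like; `CheckpointingMixin`, `ToyModel`, and the `SimpleNamespace` config are stand-ins for real Transformers classes:

```python
from types import SimpleNamespace


class CheckpointingMixin:
    """Hypothetical API shape: an explicit method pair instead of one method taking a boolean flag."""

    def gradient_checkpointing_enable(self):
        self.config._gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.config._gradient_checkpointing = False


class ToyModel(CheckpointingMixin):
    def __init__(self):
        # Stand-in for a real PretrainedConfig.
        self.config = SimpleNamespace(_gradient_checkpointing=False)


model = ToyModel()
model.gradient_checkpointing_enable()
assert model.config._gradient_checkpointing is True
model.gradient_checkpointing_disable()
assert model.config._gradient_checkpointing is False
```

With this shape the boolean argument disappears and call sites read more explicitly, at the cost of exposing two methods instead of one.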
```diff
+        """
+        Activates or deactivates gradient checkpointing for the current model.
+
+        Args:
+            flag (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Will activate gradient checkpointing if :obj:`True`, deactivate it if :obj:`False`.
+        """
+        if not self.supports_gradient_checkpointing and flag:
+            logger.warn(f"{self.__class__.__name__} does not support gradient checkpointing so nothing will happen.")
```
Review thread on the `logger.warn` line:

- Any reason not to assert here instead? The user can then change their setup and proceed without problems. It's a clear error to activate this option if a model doesn't support it, IMHO.
- It's to be consistent with the previous behavior, where we did nothing if the user set this option on a model that does not support it. I'm not opposed to asserting, but let's see what @LysandreJik and @patrickvonplaten think.
- Would also be in favor of raising an error here, actually. It's a new function, so I think we can add this behavior here.
- Will switch then!
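Since the thread converges on raising, here is a sketch of what the guard could become, assuming a plain `ValueError` (the diff at this commit still shows the warning version):

```python
def gradient_checkpointing_enable(self, flag: bool = True):
    """Same method as in the diff, but refusing outright instead of warning."""
    if not self.supports_gradient_checkpointing and flag:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
    self.config._gradient_checkpointing = flag
```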
```diff
+        self.config._gradient_checkpointing = flag
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
```
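From the user side, and assuming a Transformers version that includes this PR, usage might look roughly like the following; the `TrainingArguments(gradient_checkpointing=...)` plumbing is what the PR title promises rather than something visible in this particular diff:

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Toggle checkpointing directly on the model via the method added in this diff
# (flag defaults to True, so a bare call enables it).
model.gradient_checkpointing_enable()

# The end goal per the PR title: let the Trainer drive it from a training argument.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
trainer = Trainer(model=model, args=args)
```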
A separate comment on the method name:

- How about `enable_gradient_checkpointing`?