
Commit b32608e

ArneBinder and dirkgr authored

add gradient checkpointing for transformer token embedders (#4544)

* add gradient checkpointing for transformer token embedders
* Adds test for gradient checkpointing

Co-authored-by: Dirk Groeneveld <[email protected]>

1 parent f639336 · commit b32608e

File tree

4 files changed (+29, -4 lines changed)

CHANGELOG.md (+1)

@@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added the option to specify `requires_grad: false` within an optimizer's parameter groups.
 - Added the `file-friendly-logging` flag back to the `train` command. Also added this flag to the `predict`, `evaluate`, and `find-learning-rate` commands.
 - Added an `EpochCallback` to track current epoch as a model class member.
+- Added the option to enable or disable gradient checkpointing for transformer token embedders via boolean parameter `gradient_checkpointing`.

 ### Removed

allennlp/modules/token_embedders/pretrained_transformer_embedder.py (+8, -1)

@@ -40,6 +40,8 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
         When `True` (the default), only the final layer of the pretrained transformer is taken
         for the embeddings. But if set to `False`, a scalar mix of all of the layers
         is used.
+    gradient_checkpointing: `bool`, optional (default = `None`)
+        Enable or disable gradient checkpointing.
     """

     def __init__(
@@ -51,14 +53,19 @@ def __init__(
         train_parameters: bool = True,
         last_layer_only: bool = True,
         override_weights_file: Optional[str] = None,
-        override_weights_strip_prefix: Optional[str] = None
+        override_weights_strip_prefix: Optional[str] = None,
+        gradient_checkpointing: Optional[bool] = None,
     ) -> None:
         super().__init__()
         from allennlp.common import cached_transformers

         self.transformer_model = cached_transformers.get(
             model_name, True, override_weights_file, override_weights_strip_prefix
         )
+
+        if gradient_checkpointing is not None:
+            self.transformer_model.config.update({"gradient_checkpointing": gradient_checkpointing})
+
         self.config = self.transformer_model.config
         if sub_module:
             assert hasattr(self.transformer_model, sub_module)
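The change above relies on the Hugging Face side reading `gradient_checkpointing` from the model config. A rough sketch of that mechanism (an assumption about transformers releases of this era, not something the diff adds):

    # Sketch of the mechanism the diff relies on: updating the config of an
    # already-loaded transformers model. BERT-style encoders of this era check
    # config.gradient_checkpointing at forward time; newer transformers
    # releases expose model.gradient_checkpointing_enable() instead.
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")
    model.config.update({"gradient_checkpointing": True})
    print(model.config.gradient_checkpointing)  # True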

allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py (+4)

@@ -31,6 +31,8 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
         When `True` (the default), only the final layer of the pretrained transformer is taken
         for the embeddings. But if set to `False`, a scalar mix of all of the layers
         is used.
+    gradient_checkpointing: `bool`, optional (default = `None`)
+        Enable or disable gradient checkpointing.
     """

     def __init__(
@@ -39,6 +41,7 @@ def __init__(
         max_length: int = None,
         train_parameters: bool = True,
         last_layer_only: bool = True,
+        gradient_checkpointing: Optional[bool] = None,
     ) -> None:
         super().__init__()
         # The matched version v.s. mismatched
@@ -47,6 +50,7 @@ def __init__(
             max_length=max_length,
             train_parameters=train_parameters,
             last_layer_only=last_layer_only,
+            gradient_checkpointing=gradient_checkpointing,
         )

     @overrides
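As the diff shows, the mismatched embedder simply forwards the flag to the matched embedder it wraps. A short hypothetical construction, for completeness (not from the commit):

    # The mismatched embedder passes gradient_checkpointing straight through
    # to the PretrainedTransformerEmbedder it constructs internally.
    from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

    embedder = PretrainedTransformerMismatchedEmbedder(
        model_name="bert-base-uncased",
        gradient_checkpointing=True,
    )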

tests/modules/token_embedders/pretrained_transformer_embedder_test.py (+16, -3)

@@ -26,10 +26,22 @@ def test_forward_runs_when_initialized_from_params(self):
         assert tuple(output.size()) == (1, 4, 768)

     @pytest.mark.parametrize(
-        "train_parameters, last_layer_only",
-        [(True, True), (False, True), (True, False), (False, False)],
+        "train_parameters, last_layer_only, gradient_checkpointing",
+        [
+            (True, True, False),
+            (False, True, False),
+            (True, False, False),
+            (False, False, False),
+            (
+                True,
+                False,
+                True,
+            ),  # checkpointing only makes sense when we're actually training the layers
+        ],
     )
-    def test_end_to_end(self, train_parameters: bool, last_layer_only: bool):
+    def test_end_to_end(
+        self, train_parameters: bool, last_layer_only: bool, gradient_checkpointing: bool
+    ):
         tokenizer = PretrainedTransformerTokenizer(model_name="bert-base-uncased")
         token_indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased")

@@ -53,6 +65,7 @@ def test_end_to_end(self, train_parameters: bool, last_layer_only: bool):
                     "model_name": "bert-base-uncased",
                     "train_parameters": train_parameters,
                     "last_layer_only": last_layer_only,
+                    "gradient_checkpointing": gradient_checkpointing,
                 }
             }
         }
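Background for the trailing comment in the parametrization above: gradient checkpointing trades extra compute for memory during the backward pass, which is why the test only enables it together with train_parameters=True. A generic PyTorch illustration (not tied to this test):

    # Activations inside `block` are not cached on the forward pass; they are
    # recomputed during backward, saving memory at the cost of extra compute.
    import torch
    from torch.utils.checkpoint import checkpoint

    block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    x = torch.randn(4, 16, requires_grad=True)

    y = checkpoint(block, x)
    y.sum().backward()  # re-runs the block's forward to obtain gradients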
