
Commit d6eeb87

KMFODA and sanchit-gandhi authored
Flax Remat for LongT5 (#17994)
* [Flax] Add remat (gradient checkpointing)
* fix variable naming in test
* flip: checkpoint using a method
* fix naming
* fix class naming
* apply PVP's suggestions from code review
* add gradient_checkpointing to examples
* Add gradient_checkpointing to run_mlm_flax
* Add remat to longt5
* Add gradient checkpointing test longt5
* Fix args errors
* Fix remaining tests
* Make fixup & quality fixes
* replace kwargs
* remove unecessary kwargs
* Make fixup changes
* revert long_t5_flax changes
* Remove return_dict and copy to LongT5
* Remove test_gradient_checkpointing

Co-authored-by: sanchit-gandhi <[email protected]>
1 parent 1ccd251 commit d6eeb87

4 files changed, +149 -39 lines changed

examples/flax/language-modeling/run_mlm_flax.py

Lines changed: 9 additions & 0 deletions
@@ -107,6 +107,12 @@ class TrainingArguments:
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
 
     def __post_init__(self):
         if self.output_dir is not None:
@@ -640,6 +646,9 @@ def group_texts(examples):
         dtype=getattr(jnp, model_args.dtype),
     )
 
+    if training_args.gradient_checkpointing:
+        model.enable_gradient_checkpointing()
+
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
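Note on the change above: the new TrainingArguments field simply gates a call to model.enable_gradient_checkpointing() after the model is created. A minimal sketch of the same toggle outside the script, assuming the chosen architecture's Flax implementation supports gradient checkpointing (the checkpoint name and the standalone flag are illustrative only):

# Illustrative sketch, not part of this diff: mirrors what run_mlm_flax.py does
# when --gradient_checkpointing is passed. Assumes the selected architecture's
# Flax implementation accepts gradient checkpointing.
import jax.numpy as jnp
from transformers import FlaxAutoModelForMaskedLM

model = FlaxAutoModelForMaskedLM.from_pretrained("bert-base-uncased", dtype=jnp.float32)

gradient_checkpointing = True  # stands in for training_args.gradient_checkpointing
if gradient_checkpointing:
    # Recompute intermediate activations during the backward pass:
    # lower peak memory at the cost of extra compute per step.
    model.enable_gradient_checkpointing()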

examples/flax/summarization/run_summarization_flax.py

Lines changed: 9 additions & 0 deletions
@@ -121,6 +121,12 @@ class TrainingArguments:
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
 
     def __post_init__(self):
         if self.output_dir is not None:
@@ -535,6 +541,9 @@ def main():
         dtype=getattr(jnp, model_args.dtype),
    )
 
+    if training_args.gradient_checkpointing:
+        model.enable_gradient_checkpointing()
+
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
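The summarization script gets the identical switch, and it is the entry point most relevant to this PR since LongT5 is typically fine-tuned there. A minimal sketch of enabling checkpointing on a LongT5 model directly (the checkpoint name is an example; from_pt=True may be required if only PyTorch weights are hosted):

# Sketch only: what run_summarization_flax.py now does for LongT5 when
# --gradient_checkpointing is set. The checkpoint name is an example.
import jax.numpy as jnp
from transformers import FlaxLongT5ForConditionalGeneration

model = FlaxLongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-local-base", dtype=jnp.float32
)
# Rebuilds the underlying module with gradient_checkpointing=True, so the
# encoder/decoder stacks wrap their layers in remat (see the modeling diff below).
model.enable_gradient_checkpointing()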

src/transformers/models/longt5/modeling_flax_longt5.py

Lines changed: 57 additions & 18 deletions
@@ -25,6 +25,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@
 _CONFIG_FOR_DOC = "LongT5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -1356,7 +1359,6 @@ def __call__(
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -1377,13 +1379,31 @@ def __call__(
 class FlaxLongT5BlockCollection(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxLongT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxLongT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
 
     def __call__(
         self,
@@ -1409,14 +1429,14 @@ def __call__(
 
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
     config: LongT5Config
     embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
 
-        self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxLongT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxLongT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -1989,6 +2012,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 class FlaxLongT5Module(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2005,12 +2029,22 @@ def setup(self):
 
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -2104,6 +2138,7 @@ class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
 class FlaxLongT5ForConditionalGenerationModule(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2124,13 +2159,17 @@ def setup(self):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
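The core of the modeling change is remat from flax.linen.partitioning: wrapping FlaxLongT5LayerCollection makes Flax recompute each block's activations during the backward pass instead of storing them, and static_argnums=(6, 7, 8) marks output_attentions, deterministic and init_cache as static Python values (indices count from the first argument after self), which is also why FlaxLongT5BlockCollection now calls its layers with positional rather than keyword arguments. A self-contained sketch of the same pattern on a toy module (Layer, Stack and the shapes below are made up for illustration):

# Toy sketch of the remat pattern used above; not transformers code.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class Layer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states)
        return nn.relu(hidden_states)


class Stack(nn.Module):
    features: int
    num_layers: int
    gradient_checkpointing: bool = False

    def setup(self):
        layer_cls = Layer
        if self.gradient_checkpointing:
            # `deterministic` is a Python bool that changes tracing, so it must be
            # static: argument 0 is `hidden_states`, argument 1 is `deterministic`.
            layer_cls = remat(Layer, static_argnums=(1,))
        self.layers = [layer_cls(self.features, name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, deterministic: bool = True):
        for layer in self.layers:
            # Rematted modules are called with positional arguments so the
            # static_argnums indices line up, mirroring the LongT5 change above.
            hidden_states = layer(hidden_states, deterministic)
        return hidden_states


model = Stack(features=16, num_layers=4, gradient_checkpointing=True)
x = jnp.ones((2, 16))
params = model.init(jax.random.PRNGKey(0), x)
# The backward pass recomputes each layer's activations instead of keeping them.
grads = jax.grad(lambda p: model.apply(p, x).sum())(params)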
