@@ -25,6 +25,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@
 _CONFIG_FOR_DOC = "LongT5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -1356,7 +1359,6 @@ def __call__(
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -1377,13 +1379,31 @@ def __call__(
 class FlaxLongT5BlockCollection(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxLongT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxLongT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
 
     def __call__(
         self,
@@ -1409,14 +1429,14 @@ def __call__(
 
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
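
Note (not part of the diff): `flax.linen.partitioning.remat` wraps a Module class so the wrapped layer recomputes its intermediate activations during the backward pass instead of storing them. `static_argnums` indexes *positional* arguments of the wrapped `__call__` (with `hidden_states` at index 0), so indices 6, 7 and 8 cover the `output_attentions`, `deterministic` and `init_cache` flags; that is why the call above switches from keyword to positional arguments. A minimal, self-contained sketch of the same pattern on a toy layer, with every name below purely illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning


class ToyLayer(nn.Module):
    """Stand-in for FlaxLongT5LayerCollection: one dense layer plus dropout."""

    features: int = 8

    @nn.compact
    def __call__(self, hidden_states, deterministic):  # `deterministic` is positional index 1
        hidden_states = nn.Dense(self.features)(hidden_states)
        return nn.Dropout(rate=0.1)(hidden_states, deterministic=deterministic)


# Wrap the class and mark the boolean flag as static: it changes the traced
# computation, so it must stay a Python value and be passed positionally.
CheckpointedToyLayer = nn_partitioning.remat(ToyLayer, static_argnums=(1,))

layer = CheckpointedToyLayer()
x = jnp.ones((2, 8))
params = layer.init(jax.random.PRNGKey(0), x, True)  # flag passed positionally
out = layer.apply(params, x, True)
```
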
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
     config: LongT5Config
     embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
 
-        self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxLongT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxLongT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -1989,6 +2012,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 class FlaxLongT5Module(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2005,12 +2029,22 @@ def setup(self):
 
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -2104,6 +2138,7 @@ class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
 class FlaxLongT5ForConditionalGenerationModule(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2124,13 +2159,17 @@ def setup(self):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
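
Note (not part of the diff): with the flag threaded from the top-level modules down to `FlaxLongT5BlockCollection`, each layer is wrapped in `remat` when `gradient_checkpointing` is enabled. A hedged usage sketch, assuming the generic `FlaxPreTrainedModel.enable_gradient_checkpointing()` helper from the companion Flax gradient-checkpointing PRs is available; the checkpoint name is only an example:

```python
from transformers import FlaxLongT5ForConditionalGeneration

# Example checkpoint; any LongT5 checkpoint would work the same way.
model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

# Rebuilds the underlying Flax module with gradient_checkpointing=True, so the
# encoder/decoder blocks recompute activations in the backward pass, lowering
# peak memory during fine-tuning at the cost of extra forward compute.
model.enable_gradient_checkpointing()
```
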