Commit 9972f38

Author: sanchit-gandhi
Commit message: make fix-copies
1 parent d6040e0 · commit 9972f38

File tree: 3 files changed (+119, -36 lines)

src/transformers/models/big_bird/modeling_flax_big_bird.py

Lines changed: 45 additions & 12 deletions
@@ -1408,12 +1408,12 @@ def __call__(
             layer_outputs = layer(
                 hidden_states,
                 attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
             )

             hidden_states = layer_outputs[0]
@@ -1444,9 +1444,14 @@ def __call__(
 class FlaxBigBirdEncoder(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.layer = FlaxBigBirdLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxBigBirdLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

     def __call__(
         self,
@@ -1812,9 +1817,14 @@ class FlaxBigBirdModel(FlaxBigBirdPreTrainedModel):
 class FlaxBigBirdForPreTrainingModule(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype)

     def __call__(
@@ -1910,9 +1920,15 @@ class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel):
 class FlaxBigBirdForMaskedLMModule(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)

     def __call__(
@@ -2067,9 +2083,14 @@ class FlaxBigBirdForSequenceClassification(FlaxBigBirdPreTrainedModel):
 class FlaxBigBirdForMultipleChoiceModule(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
         self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -2162,9 +2183,15 @@ def __init__(
 class FlaxBigBirdForTokenClassificationModule(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         classifier_dropout = (
             self.config.classifier_dropout
             if self.config.classifier_dropout is not None
@@ -2414,9 +2441,15 @@ def prepare_question_mask(q_lengths, maxlen: int):
 class FlaxBigBirdForCausalLMModule(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)

     def __call__(

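A note on the positional-argument change above: when gradient checkpointing is enabled, the layers are wrapped with flax.linen's remat using static_argnums=(5, 6, 7) (see the FlaxElectraLayerCollection and FlaxRobertaLayerCollection hunks below). static_argnums refers to positional argument indices, so the boolean flags init_cache, deterministic and output_attentions must be passed positionally for remat to treat them as static. Below is a minimal sketch of the same pattern on a toy Flax module; ToyLayer and ToyCollection are illustrative names, not classes from this commit.

# Sketch only: mirrors the remat + static_argnums pattern from this commit on a
# toy module. ToyLayer and ToyCollection are hypothetical, not transformers classes.
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyLayer(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        # argument indices: 0 = hidden_states, 1 = attention_mask, 2 = deterministic
        hidden_states = nn.Dense(self.features)(hidden_states * attention_mask[..., None])
        return nn.Dropout(rate=0.1, deterministic=deterministic)(hidden_states)


class ToyCollection(nn.Module):
    num_layers: int = 2
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            # remat recomputes layer activations in the backward pass; the boolean
            # at positional index 2 is marked static so it can drive Python control flow.
            CheckpointLayer = nn.remat(ToyLayer, static_argnums=(2,))
            self.layers = [CheckpointLayer(name=str(i)) for i in range(self.num_layers)]
        else:
            self.layers = [ToyLayer(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        for layer in self.layers:
            # positional call, matching static_argnums; keyword arguments would not
            # be covered by the static indices
            hidden_states = layer(hidden_states, attention_mask, deterministic)
        return hidden_states


if __name__ == "__main__":
    model = ToyCollection(gradient_checkpointing=True)
    x = jnp.ones((1, 4, 8))
    mask = jnp.ones((1, 4))
    params = model.init(jax.random.PRNGKey(0), x, mask, True)
    out = model.apply(params, x, mask, True)
    print(out.shape)  # (1, 4, 8)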
src/transformers/models/electra/modeling_flax_electra.py

Lines changed: 24 additions & 10 deletions
@@ -521,11 +521,20 @@ def __call__(
 class FlaxElectraLayerCollection(nn.Module):
     config: ElectraConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.layers = [
-            FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
+            self.layers = [
+                FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]

     def __call__(
         self,
@@ -559,12 +568,12 @@ def __call__(
             layer_outputs = layer(
                 hidden_states,
                 attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
             )

             hidden_states = layer_outputs[0]
@@ -595,9 +604,14 @@ def __call__(
 class FlaxElectraEncoder(nn.Module):
     config: ElectraConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxElectraLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

     def __call__(
         self,

src/transformers/models/roberta/modeling_flax_roberta.py

Lines changed: 50 additions & 14 deletions
@@ -511,11 +511,20 @@ def __call__(
 class FlaxRobertaLayerCollection(nn.Module):
     config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.layers = [
-            FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
+            self.layers = [
+                FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]

     def __call__(
         self,
@@ -549,12 +558,12 @@ def __call__(
             layer_outputs = layer(
                 hidden_states,
                 attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
             )

             hidden_states = layer_outputs[0]
@@ -585,9 +594,14 @@ def __call__(
 class FlaxRobertaEncoder(nn.Module):
     config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxRobertaLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

     def __call__(
         self,
@@ -889,10 +903,15 @@ class FlaxRobertaModule(nn.Module):
     config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     add_pooling_layer: bool = True
+    gradient_checkpointing: bool = False

     def setup(self):
         self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
-        self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxRobertaEncoder(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)

     def __call__(
@@ -1101,9 +1120,14 @@ class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
 class FlaxRobertaForMultipleChoiceModule(nn.Module):
     config: RobertaConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
         self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1181,9 +1205,15 @@ class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
 class FlaxRobertaForTokenClassificationModule(nn.Module):
     config: RobertaConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         classifier_dropout = (
             self.config.classifier_dropout
             if self.config.classifier_dropout is not None
@@ -1255,9 +1285,15 @@ class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
 class FlaxRobertaForQuestionAnsweringModule(nn.Module):
     config: RobertaConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

     def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
         self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

     def __call__(

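For reference, a minimal usage sketch of the new flag, assuming a Flax-enabled transformers install: constructing the RoBERTa module with gradient_checkpointing=True follows directly from the diff above (the flag propagates down to the remat-wrapped layer collection); the config sizes and inputs here are illustrative only, not taken from the commit.

# Usage sketch only: enabling the new flag on the RoBERTa Flax module directly.
# Config sizes are illustrative; real checkpoints use much larger values.
import jax
import jax.numpy as jnp
from transformers import RobertaConfig
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModule

config = RobertaConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
)

# gradient_checkpointing=True makes FlaxRobertaLayerCollection wrap each
# FlaxRobertaLayer in remat, trading recomputation for lower activation memory.
module = FlaxRobertaModule(config=config, gradient_checkpointing=True)

batch, seq_len = 1, 16
input_ids = jnp.ones((batch, seq_len), dtype="i4")
attention_mask = jnp.ones((batch, seq_len), dtype="i4")
token_type_ids = jnp.zeros((batch, seq_len), dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(seq_len)[None, :], (batch, seq_len))
head_mask = None

params = module.init(
    jax.random.PRNGKey(0),
    input_ids,
    attention_mask,
    token_type_ids,
    position_ids,
    head_mask,
)
outputs = module.apply(params, input_ids, attention_mask, token_type_ids, position_ids, head_mask)
print(outputs.last_hidden_state.shape)  # (1, 16, 64)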