From 7dd4c58af3eec20c39cc87ce1ca3a250ff5fb0e5 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 23 Jun 2022 14:28:13 +0100 Subject: [PATCH 01/11] [Flax] Add remat (gradient checkpointing) --- src/transformers/modeling_flax_utils.py | 7 +- .../models/bert/modeling_flax_bert.py | 84 +++++++++++++++---- tests/models/bert/test_modeling_flax_bert.py | 22 +++++ 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 74124209cb68..fd55d63c2c38 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -190,6 +190,7 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, ): if config is None: raise ValueError("config cannot be None") @@ -205,6 +206,7 @@ def __init__( self.key = PRNGKey(seed) self.dtype = dtype self.input_shape = input_shape + self.gradient_checkpointing = gradient_checkpointing # To check if the model was intialized automatically. self._is_initialized = _do_init @@ -594,6 +596,7 @@ def from_pretrained( from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) _do_init = kwargs.pop("_do_init", True) + gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -774,7 +777,9 @@ def from_pretrained( ) # init random models - model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) + model = cls( + config, *model_args, _do_init=_do_init, gradient_checkpointing=gradient_checkpointing, **model_kwargs + ) if from_pt: state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 902d6cca3d13..b0e50a8b6362 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -56,6 +57,8 @@ _CONFIG_FOR_DOC = "BertConfig" _TOKENIZER_FOR_DOC = "BertTokenizer" +remat = nn_partitioning.remat + @flax.struct.dataclass class FlaxBertForPreTrainingOutput(ModelOutput): @@ -544,10 +547,15 @@ def __call__( class FlaxBertLayerCollection(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): + FlaxBertRematLayer = ( + remat(FlaxBertLayer, static_argnums=(5, 6, 7)) if self.gradient_checkpointing else FlaxBertLayer + ) self.layers = [ - FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + FlaxBertRematLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) ] def __call__( @@ -582,12 +590,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - 
output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -617,9 +625,12 @@ def __call__( class FlaxBertEncoder(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxBertLayerCollection( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) def __call__( self, @@ -925,10 +936,13 @@ class FlaxBertModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxBertEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) def __call__( @@ -1003,9 +1017,12 @@ class FlaxBertModel(FlaxBertPreTrainedModel): class FlaxBertForPreTrainingModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) def __call__( @@ -1099,9 +1116,15 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel): class FlaxBertForMaskedLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1161,9 +1184,12 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): class FlaxBertForNextSentencePredictionModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) def __call__( @@ -1248,9 +1274,12 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): class FlaxBertForSequenceClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1324,9 +1353,12 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): class FlaxBertForMultipleChoiceModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - 
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1399,9 +1431,15 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel): class FlaxBertForTokenClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1468,9 +1506,15 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel): class FlaxBertForQuestionAnsweringModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1539,9 +1583,15 @@ class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): class FlaxBertForCausalLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( diff --git a/tests/models/bert/test_modeling_flax_bert.py b/tests/models/bert/test_modeling_flax_bert.py index 5516c4d6fe67..5363d489736f 100644 --- a/tests/models/bert/test_modeling_flax_bert.py +++ b/tests/models/bert/test_modeling_flax_bert.py @@ -155,6 +155,28 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxBertModelTester(self) + def test_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model1 = model_class(config, gradient_checkpointing=False) + model2 = model_class(config, gradient_checkpointing=True) + + outputs1 = model1(**prepared_inputs_dict) + outputs2 = model2(**prepared_inputs_dict) + + # ensure that the dicts of outputs contain the same keys + self.assertEqual(outputs1.keys(), outputs2.keys()) + + outputs1 = outputs1.to_tuple() + outputs2 = outputs2.to_tuple() + + # ensure that the outputs remain precisely equal + for output1, output2 in zip(outputs1, outputs2): + self.assertTrue((output1 == output2).all()) + @slow def test_model_from_pretrained(self): # Only check this for base model, not necessary for all model classes. 
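Note on the API introduced by this first patch, with a minimal usage sketch (the checkpoint name below is illustrative, not part of the patch): gradient checkpointing is requested through a constructor / from_pretrained keyword argument and threaded down to FlaxBertLayerCollection, which wraps each FlaxBertLayer with flax.linen.partitioning.remat; static_argnums=(5, 6, 7) marks init_cache, deterministic and output_attentions as static, which is why the layer call sites switch from keyword to positional arguments.

    # Sketch of how a user would enable the new flag under patch 1 of this series.
    from transformers import FlaxBertForMaskedLM

    # via from_pretrained, which pops "gradient_checkpointing" from kwargs ...
    model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", gradient_checkpointing=True)

    # ... or directly through the model constructor, as the new test does:
    # remat_model = FlaxBertForMaskedLM(model.config, gradient_checkpointing=True)

Later patches in this series (03 onward) drop the constructor flag from FlaxPreTrainedModel and flip to an explicit model.enable_gradient_checkpointing() method, so the form above only applies to this initial revision.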
From 03606c12abcef8c86aca4594cb921729138bbe5b Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 23 Jun 2022 14:42:17 +0100 Subject: [PATCH 02/11] fix variable naming in test --- tests/models/bert/test_modeling_flax_bert.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/models/bert/test_modeling_flax_bert.py b/tests/models/bert/test_modeling_flax_bert.py index 5363d489736f..b8f86a5f8739 100644 --- a/tests/models/bert/test_modeling_flax_bert.py +++ b/tests/models/bert/test_modeling_flax_bert.py @@ -161,21 +161,21 @@ def test_gradient_checkpointing(self): for model_class in self.all_model_classes: # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model1 = model_class(config, gradient_checkpointing=False) - model2 = model_class(config, gradient_checkpointing=True) + model = model_class(config, gradient_checkpointing=False) + remat_model = model_class(config, gradient_checkpointing=True) - outputs1 = model1(**prepared_inputs_dict) - outputs2 = model2(**prepared_inputs_dict) + outputs = model(**prepared_inputs_dict) + remat_outputs = remat_model(**prepared_inputs_dict) # ensure that the dicts of outputs contain the same keys - self.assertEqual(outputs1.keys(), outputs2.keys()) + self.assertEqual(outputs.keys(), remat_outputs.keys()) - outputs1 = outputs1.to_tuple() - outputs2 = outputs2.to_tuple() + outputs = outputs.to_tuple() + remat_outputs = remat_outputs.to_tuple() # ensure that the outputs remain precisely equal - for output1, output2 in zip(outputs1, outputs2): - self.assertTrue((output1 == output2).all()) + for output, remat_output in zip(outputs, remat_outputs): + self.assertTrue((output == remat_output).all()) @slow def test_model_from_pretrained(self): From 9b6e16447a872cc5e6d198ab91689aab6709242c Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 24 Jun 2022 15:53:24 +0100 Subject: [PATCH 03/11] flip: checkpoint using a method --- src/transformers/modeling_flax_utils.py | 10 ++- .../models/bert/modeling_flax_bert.py | 70 ++++++++++++++++--- tests/models/bert/test_modeling_flax_bert.py | 5 +- 3 files changed, 66 insertions(+), 19 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index fd55d63c2c38..77eaa900de62 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -190,7 +190,6 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, - gradient_checkpointing: bool = False, ): if config is None: raise ValueError("config cannot be None") @@ -206,7 +205,6 @@ def __init__( self.key = PRNGKey(seed) self.dtype = dtype self.input_shape = input_shape - self.gradient_checkpointing = gradient_checkpointing # To check if the model was intialized automatically. 
self._is_initialized = _do_init @@ -237,6 +235,9 @@ def __init__( def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: raise NotImplementedError(f"init method has to be implemented for {self}") + def enable_gradient_checkpointing(self): + raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}") + @classmethod def _from_config(cls, config, **kwargs): """ @@ -596,7 +597,6 @@ def from_pretrained( from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) _do_init = kwargs.pop("_do_init", True) - gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -777,9 +777,7 @@ def from_pretrained( ) # init random models - model = cls( - config, *model_args, _do_init=_do_init, gradient_checkpointing=gradient_checkpointing, **model_kwargs - ) + model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) if from_pt: state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index b0e50a8b6362..d7a0e48ff35c 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -548,13 +548,16 @@ class FlaxBertLayerCollection(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): - FlaxBertRematLayer = ( - remat(FlaxBertLayer, static_argnums=(5, 6, 7)) if self.gradient_checkpointing else FlaxBertLayer - ) + if self.gradient_checkpointing: + FlaxBertBlockLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy) + else: + FlaxBertBlockLayer = FlaxBertLayer + self.layers = [ - FlaxBertRematLayer(self.config, name=str(i), dtype=self.dtype) + FlaxBertBlockLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] @@ -626,10 +629,14 @@ class FlaxBertEncoder(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.layer = FlaxBertLayerCollection( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) def __call__( @@ -767,11 +774,24 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + _gradient_checkpointing: bool = False, + _remat_policy: Callable[..., bool] = (None,), **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class( + config=config, + dtype=dtype, + gradient_checkpointing=_gradient_checkpointing, + remat_policy=_remat_policy, + **kwargs, + ) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + def enable_gradient_checkpointing(self, remat_policy=None): + self._module = self.module_class( + config=self.config, dtype=self.dtype, gradient_checkpointing=True, remat_policy=remat_policy + ) + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, 
params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -937,11 +957,15 @@ class FlaxBertModule(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) self.encoder = FlaxBertEncoder( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) @@ -1018,10 +1042,14 @@ class FlaxBertForPreTrainingModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) @@ -1117,6 +1145,7 @@ class FlaxBertForMaskedLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1124,6 +1153,7 @@ def setup(self): add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) @@ -1185,10 +1215,14 @@ class FlaxBertForNextSentencePredictionModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) @@ -1275,10 +1309,14 @@ class FlaxBertForSequenceClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) classifier_dropout = ( self.config.classifier_dropout @@ -1354,10 +1392,14 @@ class FlaxBertForMultipleChoiceModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.dropout = 
nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1432,6 +1474,7 @@ class FlaxBertForTokenClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1439,6 +1482,7 @@ def setup(self): dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) classifier_dropout = ( self.config.classifier_dropout @@ -1507,6 +1551,7 @@ class FlaxBertForQuestionAnsweringModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1514,6 +1559,7 @@ def setup(self): dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) @@ -1584,6 +1630,7 @@ class FlaxBertForCausalLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False + remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1591,6 +1638,7 @@ def setup(self): add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, + remat_policy=self.remat_policy, ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) diff --git a/tests/models/bert/test_modeling_flax_bert.py b/tests/models/bert/test_modeling_flax_bert.py index b8f86a5f8739..4393e18fd3f3 100644 --- a/tests/models/bert/test_modeling_flax_bert.py +++ b/tests/models/bert/test_modeling_flax_bert.py @@ -161,8 +161,9 @@ def test_gradient_checkpointing(self): for model_class in self.all_model_classes: # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config, gradient_checkpointing=False) - remat_model = model_class(config, gradient_checkpointing=True) + model = model_class(config) + remat_model = model_class(config) + remat_model.enable_gradient_checkpointing() outputs = model(**prepared_inputs_dict) remat_outputs = remat_model(**prepared_inputs_dict) From 395221f8c6c5fc31df8f1366ca3e1aa47c490a3a Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Mon, 27 Jun 2022 11:23:23 +0100 Subject: [PATCH 04/11] fix naming --- src/transformers/models/bert/modeling_flax_bert.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index d7a0e48ff35c..173f158a56ec 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -774,15 +774,15 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, - _gradient_checkpointing: bool = False, - _remat_policy: Callable[..., bool] = (None,), + gradient_checkpointing: bool = False, + remat_policy: Callable[..., bool] = (None,), **kwargs ): module = self.module_class( config=config, dtype=dtype, - gradient_checkpointing=_gradient_checkpointing, - remat_policy=_remat_policy, + gradient_checkpointing=gradient_checkpointing, + remat_policy=remat_policy, **kwargs, ) super().__init__(config, module, 
input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) From 7b165d53eb9d14fb29de41fa68b055f82e21178d Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Mon, 27 Jun 2022 11:53:49 +0100 Subject: [PATCH 05/11] fix class naming --- src/transformers/models/bert/modeling_flax_bert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 173f158a56ec..b0c5ccdac6e9 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -552,12 +552,12 @@ class FlaxBertLayerCollection(nn.Module): def setup(self): if self.gradient_checkpointing: - FlaxBertBlockLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy) + FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy) else: - FlaxBertBlockLayer = FlaxBertLayer + FlaxBertCheckpointLayer = FlaxBertLayer self.layers = [ - FlaxBertBlockLayer(self.config, name=str(i), dtype=self.dtype) + FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] From d6040e0100648242dc997a253c33098b5a8f7d88 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 28 Jun 2022 13:39:58 +0100 Subject: [PATCH 06/11] apply PVP's suggestions from code review --- .../models/bert/modeling_flax_bert.py | 44 +++++-------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index b0c5ccdac6e9..8daa866be105 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -548,18 +548,18 @@ class FlaxBertLayerCollection(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): if self.gradient_checkpointing: - FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy) + FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] else: - FlaxBertCheckpointLayer = FlaxBertLayer - - self.layers = [ - FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] + self.layers = [ + FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -629,14 +629,12 @@ class FlaxBertEncoder(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.layer = FlaxBertLayerCollection( self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) def __call__( @@ -775,21 +773,21 @@ def __init__( dtype: jnp.dtype = jnp.float32, _do_init: bool = True, gradient_checkpointing: bool = False, - remat_policy: Callable[..., bool] = (None,), **kwargs ): module = self.module_class( config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, - remat_policy=remat_policy, **kwargs, ) 
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def enable_gradient_checkpointing(self, remat_policy=None): + def enable_gradient_checkpointing(self): self._module = self.module_class( - config=self.config, dtype=self.dtype, gradient_checkpointing=True, remat_policy=remat_policy + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, ) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: @@ -957,7 +955,6 @@ class FlaxBertModule(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) @@ -965,7 +962,6 @@ def setup(self): self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) @@ -1042,14 +1038,12 @@ class FlaxBertForPreTrainingModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) @@ -1145,7 +1139,6 @@ class FlaxBertForMaskedLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1153,7 +1146,6 @@ def setup(self): add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) @@ -1215,14 +1207,12 @@ class FlaxBertForNextSentencePredictionModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) @@ -1309,14 +1299,12 @@ class FlaxBertForSequenceClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) classifier_dropout = ( self.config.classifier_dropout @@ -1392,14 +1380,12 @@ class FlaxBertForMultipleChoiceModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, 
dtype=self.dtype) @@ -1474,7 +1460,6 @@ class FlaxBertForTokenClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1482,7 +1467,6 @@ def setup(self): dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) classifier_dropout = ( self.config.classifier_dropout @@ -1551,7 +1535,6 @@ class FlaxBertForQuestionAnsweringModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1559,7 +1542,6 @@ def setup(self): dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) @@ -1630,7 +1612,6 @@ class FlaxBertForCausalLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 gradient_checkpointing: bool = False - remat_policy: Callable[..., bool] = (None,) # the gradient checkpointing policy def setup(self): self.bert = FlaxBertModule( @@ -1638,7 +1619,6 @@ def setup(self): add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, - remat_policy=self.remat_policy, ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) From 9972f381f5b45d97291635bf99159d843971d98f Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 30 Jun 2022 17:37:53 +0100 Subject: [PATCH 07/11] make fix-copies --- .../models/big_bird/modeling_flax_big_bird.py | 57 +++++++++++++---- .../models/electra/modeling_flax_electra.py | 34 +++++++--- .../models/roberta/modeling_flax_roberta.py | 64 +++++++++++++++---- 3 files changed, 119 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 7d5f64a7e38b..42623e916bd0 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -1408,12 +1408,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -1444,9 +1444,14 @@ def __call__( class FlaxBigBirdEncoder(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxBigBirdLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxBigBirdLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, @@ -1812,9 +1817,14 @@ class FlaxBigBirdModel(FlaxBigBirdPreTrainedModel): class FlaxBigBirdForPreTrainingModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = 
FlaxBigBirdModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype) def __call__( @@ -1910,9 +1920,15 @@ class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel): class FlaxBigBirdForMaskedLMModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -2067,9 +2083,14 @@ class FlaxBigBirdForSequenceClassification(FlaxBigBirdPreTrainedModel): class FlaxBigBirdForMultipleChoiceModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -2162,9 +2183,15 @@ def __init__( class FlaxBigBirdForTokenClassificationModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -2414,9 +2441,15 @@ def prepare_question_mask(q_lengths, maxlen: int): class FlaxBigBirdForCausalLMModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 3e3a7103f07e..6d8aa3e981c6 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -521,11 +521,20 @@ def __call__( class FlaxElectraLayerCollection(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxElectraLayer(self.config, name=str(i), 
dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -559,12 +568,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -595,9 +604,14 @@ def __call__( class FlaxElectraEncoder(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxElectraLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 84bf15da6d86..027dd398e75e 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -511,11 +511,20 @@ def __call__( class FlaxRobertaLayerCollection(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -549,12 +558,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -585,9 +594,14 @@ def __call__( class FlaxRobertaEncoder(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxRobertaLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, @@ -889,10 +903,15 @@ class FlaxRobertaModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxRobertaEncoder( + self.config, + dtype=self.dtype, + 
gradient_checkpointing=self.gradient_checkpointing, + ) self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) def __call__( @@ -1101,9 +1120,14 @@ class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): class FlaxRobertaForMultipleChoiceModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1181,9 +1205,15 @@ class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): class FlaxRobertaForTokenClassificationModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1255,9 +1285,15 @@ class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): class FlaxRobertaForQuestionAnsweringModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( From 0a0b6bfa2b06f67cd579bc4d16ba97242eef4f6a Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 30 Jun 2022 18:16:17 +0100 Subject: [PATCH 08/11] fix big-bird, electra, roberta --- .../models/big_bird/modeling_flax_big_bird.py | 50 +++++++++++++---- .../models/electra/modeling_flax_electra.py | 54 +++++++++++++++---- .../models/roberta/modeling_flax_roberta.py | 38 +++++++++++-- 3 files changed, 120 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 42623e916bd0..4af2d4c564be 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -54,6 +55,8 @@ _CONFIG_FOR_DOC = "BigBirdConfig" _TOKENIZER_FOR_DOC = "BigBirdTokenizer" +remat = nn_partitioning.remat + @flax.struct.dataclass class FlaxBigBirdForPreTrainingOutput(ModelOutput): @@ -1365,17 +1368,25 @@ def __call__( return outputs +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->BigBird class FlaxBigBirdLayerCollection(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False 
def setup(self): - self.layers = [ - FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxBigBirdCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxBigBirdLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird def __call__( self, hidden_states, @@ -1564,9 +1575,10 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) if config.attention_type == "block_sparse" and input_shape is None: input_shape = (1, 12 * config.block_size) elif input_shape is None: @@ -1574,6 +1586,14 @@ def __init__( super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -1740,10 +1760,13 @@ class FlaxBigBirdModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxBigBirdEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.pooler = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), @@ -2015,9 +2038,12 @@ def __call__(self, features, deterministic=True): class FlaxBigBirdForSequenceClassificationModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype) def __call__( @@ -2282,10 +2308,16 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 add_pooling_layer: bool = False + gradient_checkpointing: bool = False def setup(self): self.config.num_labels = 2 - self.bert = FlaxBigBirdModule(self.config, dtype=self.dtype, add_pooling_layer=self.add_pooling_layer) + self.bert = FlaxBigBirdModule( + self.config, + dtype=self.dtype, + add_pooling_layer=self.add_pooling_layer, + gradient_checkpointing=self.gradient_checkpointing, + ) self.qa_classifier = 
FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 6d8aa3e981c6..5f02c01a650e 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -54,6 +55,8 @@ _CONFIG_FOR_DOC = "ElectraConfig" _TOKENIZER_FOR_DOC = "ElectraTokenizer" +remat = nn_partitioning.remat + @flax.struct.dataclass class FlaxElectraForPreTrainingOutput(ModelOutput): @@ -689,11 +692,20 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -859,12 +871,15 @@ def __call__( class FlaxElectraModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) if self.config.embedding_size != self.config.hidden_size: self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxElectraEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) def __call__( self, @@ -939,9 +954,12 @@ def __call__(self, x, kernel): class FlaxElectraForMaskedLMModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) if self.config.tie_word_embeddings: self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) @@ -1003,9 +1021,12 @@ class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): class FlaxElectraForPreTrainingModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, 
gradient_checkpointing=self.gradient_checkpointing + ) self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) def __call__( @@ -1088,9 +1109,12 @@ class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel): class FlaxElectraForTokenClassificationModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1232,9 +1256,12 @@ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): class FlaxElectraForMultipleChoiceModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1311,9 +1338,12 @@ class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): class FlaxElectraForQuestionAnsweringModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1406,9 +1436,12 @@ def __call__(self, hidden_states, deterministic: bool = True): class FlaxElectraForSequenceClassificationModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype) def __call__( @@ -1471,9 +1504,12 @@ class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): class FlaxElectraForCausalLMModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) if self.config.tie_word_embeddings: self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 027dd398e75e..ddd6359b36be 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -21,6 +21,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, 
make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -47,6 +48,8 @@ _CONFIG_FOR_DOC = "RobertaConfig" _TOKENIZER_FOR_DOC = "RobertaTokenizer" +remat = nn_partitioning.remat + def create_position_ids_from_input_ids(input_ids, padding_idx): """ @@ -733,11 +736,20 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -986,9 +998,15 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel): class FlaxRobertaForMaskedLMModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1053,9 +1071,15 @@ class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel): class FlaxRobertaForSequenceClassificationModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype) def __call__( @@ -1362,9 +1386,15 @@ class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): class FlaxRobertaForCausalLMModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) def __call__( From 70b7175e12cdf76e7c0c1f3adf4bf66233b3230d Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 30 Jun 2022 18:16:27 +0100 Subject: [PATCH 09/11] cookie-cutter --- ...ax_{{cookiecutter.lowercase_modelname}}.py | 67 +++++++++++++------ 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py 
b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 451dc03f62ed..676270c131fb 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -25,6 +25,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, unfreeze, freeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen.attention import dot_product_attention_weights from jax import lax @@ -126,6 +127,8 @@ """ +remat = nn_partitioning.remat + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}} @@ -507,11 +510,19 @@ def __call__( class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7)) + self.layers = [ + Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -545,12 +556,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -581,9 +592,10 @@ def __call__( class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype) + self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) def __call__( self, @@ -725,11 +737,20 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from 
transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}} def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype) - self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype) + self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype) def __call__( @@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1030,9 +1053,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM(Flax{{cookiecutter.cam class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1092,9 +1116,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.cam class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.dropout = 
nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense( self.config.num_labels, @@ -1163,9 +1188,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification(Flax{{co class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1238,9 +1264,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice(Flax{{cookiecutt class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) @@ -1302,9 +1329,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification(Flax{{cooki class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1373,9 +1401,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(Flax{{cookiec class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( From 0665a33676d8cf2498749ef6f8980a304115aee1 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 1 Jul 2022 14:53:12 +0100 Subject: [PATCH 10/11] fix flax big-bird --- src/transformers/models/big_bird/modeling_flax_big_bird.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 4af2d4c564be..4ba109e7ab60 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -1368,7 +1368,6 @@ def __call__( return outputs -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->BigBird class FlaxBigBirdLayerCollection(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -1378,15 +1377,16 @@ def setup(self): if self.gradient_checkpointing: FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7)) self.layers = [ - FlaxBigBirdCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] else: self.layers = [ - FlaxBigBirdLayer(self.config, name=str(i), dtype=self.dtype) + FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird def __call__( self, hidden_states, From 66c4b14d061deb69e360f1105e7e48dbcb1684ee Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 1 Jul 2022 15:06:55 +0100 Subject: [PATCH 11/11] move test to common --- tests/models/bert/test_modeling_flax_bert.py | 23 ----------------- tests/test_modeling_flax_common.py | 27 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/tests/models/bert/test_modeling_flax_bert.py b/tests/models/bert/test_modeling_flax_bert.py index 4393e18fd3f3..5516c4d6fe67 100644 --- a/tests/models/bert/test_modeling_flax_bert.py +++ b/tests/models/bert/test_modeling_flax_bert.py @@ -155,29 +155,6 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxBertModelTester(self) - def test_gradient_checkpointing(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config) - remat_model = model_class(config) - remat_model.enable_gradient_checkpointing() - - outputs = model(**prepared_inputs_dict) - remat_outputs = remat_model(**prepared_inputs_dict) - - # ensure that the dicts of outputs contain the same keys - self.assertEqual(outputs.keys(), remat_outputs.keys()) - - outputs = outputs.to_tuple() - remat_outputs = remat_outputs.to_tuple() - - # ensure that the outputs remain precisely equal - for output, remat_output in zip(outputs, remat_outputs): - self.assertTrue((output == remat_output).all()) - @slow def test_model_from_pretrained(self): # Only check this for base model, not necessary for all model classes. 
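Note on the common pattern: every modeling change in this series wraps the per-layer module class with flax.linen.partitioning.remat, marks the plain-Python boolean arguments (init_cache, deterministic, output_attentions) as static_argnums, and then calls the wrapped layers with positional arguments only, which is why the keyword arguments at the layer-collection call sites are switched to positional. The sketch below illustrates that pattern outside the patch; ToyLayer and ToyEncoder are hypothetical stand-ins rather than classes from transformers, and only the flax.linen calls are real.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class ToyLayer(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states) * attention_mask
        return nn.Dropout(rate=0.1)(hidden_states, deterministic=deterministic)


class ToyEncoder(nn.Module):
    num_layers: int = 4
    gradient_checkpointing: bool = False

    def setup(self):
        # `deterministic` is argument 2 of ToyLayer.__call__ (self excluded), so it is
        # declared static; the same counting gives static_argnums=(5, 6, 7) for the
        # layers in this series (init_cache, deterministic, output_attentions).
        layer_cls = remat(ToyLayer, static_argnums=(2,)) if self.gradient_checkpointing else ToyLayer
        self.layers = [layer_cls(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        for layer in self.layers:
            # remat-wrapped modules are called positionally, hence the
            # keyword-to-positional change at the call sites in the diffs above
            hidden_states = layer(hidden_states, attention_mask, deterministic)
        return hidden_states


hidden_states = jnp.ones((2, 8, 16))
attention_mask = jnp.ones((2, 8, 1))
model = ToyEncoder(gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), hidden_states, attention_mask)
out = model.apply(params, hidden_states, attention_mask, True)

With gradient_checkpointing=True, the wrapped layers recompute their internal activations during the backward pass instead of keeping them alive from the forward pass.
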
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index ec3c1fcd0bc3..f90615efea36 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1099,6 +1099,33 @@ def test_checkpoint_sharding_local(self):
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
             self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
 
+    def test_gradient_checkpointing(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            # prepare inputs
+            prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config)
+            remat_model = model_class(config)
+
+            try:
+                remat_model.enable_gradient_checkpointing()
+            except NotImplementedError:
+                continue
+
+            outputs = model(**prepared_inputs_dict)
+            remat_outputs = remat_model(**prepared_inputs_dict)
+
+            # ensure that the dicts of outputs contain the same keys
+            self.assertEqual(outputs.keys(), remat_outputs.keys())
+
+            outputs = outputs.to_tuple()
+            remat_outputs = remat_outputs.to_tuple()
+
+            # ensure that the outputs remain precisely equal
+            for output, remat_output in zip(outputs, remat_outputs):
+                self.assertTrue((output == remat_output).all())
+
 
 @require_flax
 @is_staging_test
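As a usage sketch of the feature this series adds: the checkpoint name, batch shapes and toy loss below are illustrative only, while enable_gradient_checkpointing() is the method introduced in the diffs above. Enabling checkpointing trades extra forward compute for lower peak memory when taking gradients.

import jax
import jax.numpy as jnp
from transformers import FlaxBertForMaskedLM

model = FlaxBertForMaskedLM.from_pretrained("bert-base-uncased")
model.enable_gradient_checkpointing()  # swaps in the remat'd module; params are unchanged


def loss_fn(params, input_ids, attention_mask, labels):
    logits = model(
        input_ids, attention_mask, params=params, dropout_rng=jax.random.PRNGKey(0), train=True
    ).logits
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


input_ids = jnp.ones((8, 128), dtype="i4")
attention_mask = jnp.ones((8, 128), dtype="i4")
labels = jnp.ones((8, 128), dtype="i4")

# Layer activations are rematerialized in the backward pass rather than stored,
# so peak memory drops at the cost of recomputing each checkpointed layer.
grads = jax.grad(loss_fn)(model.params, input_ids, attention_mask, labels)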