
Commit 1eefc67

Adding gradient_checkpointing to Flax Whisper
It uses `flax.linen.remat` and follows up on PRs huggingface#13657 and huggingface#17994.
1 parent 6dc0a84 commit 1eefc67
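
For context, a minimal usage sketch (not part of the commit): checkpointing can be requested either at construction time via the new `gradient_checkpointing` argument or after loading via the new `enable_gradient_checkpointing()` method added below. The checkpoint name is only the doc example; whether `from_pretrained` forwards the constructor argument is an assumption here, so the explicit method call is shown as the safer path.

```python
# Hedged usage sketch: enable gradient (activation) checkpointing on Flax Whisper.
# "openai/whisper-tiny" is just an example checkpoint.
from transformers import FlaxWhisperForConditionalGeneration

model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Added by this commit: rebuilds the underlying Flax module with
# gradient_checkpointing=True; the already-loaded parameters in model.params
# are untouched.
model.enable_gradient_checkpointing()
```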


src/transformers/models/whisper/modeling_flax_whisper.py

Lines changed: 45 additions & 13 deletions
@@ -23,6 +23,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
@@ -53,6 +54,8 @@
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
 _CONFIG_FOR_DOC = "WhisperConfig"
 
+remat = nn_partitioning.remat
+
 
 WHISPER_START_DOCSTRING = r"""
     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
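
The `remat` alias introduced above is Flax's activation-checkpointing transform. As an editorial aside (not part of the diff), here is a self-contained sketch of what wrapping a module in `remat` does: the forward interface is unchanged, but intermediate activations are recomputed during the backward pass instead of being kept in memory.

```python
# Minimal, self-contained sketch of flax.linen.partitioning.remat.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class Block(nn.Module):
    features: int = 128

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        return nn.gelu(x)


# Same constructor and call signature as Block, but its activations are
# rematerialized during the backward pass.
CheckpointedBlock = remat(Block)

model = CheckpointedBlock(features=128)
x = jnp.ones((4, 128))
params = model.init(jax.random.PRNGKey(0), x)

loss_fn = lambda p: jnp.sum(model.apply(p, x) ** 2)
grads = jax.grad(loss_fn)(params)  # Block's intermediates are recomputed here
```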
@@ -391,12 +394,20 @@ def __call__(
 class FlaxWhisperEncoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.encoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
+            self.layers = [
+                FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
         self.layerdrop = self.config.encoder_layerdrop
 
     def __call__(
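
A note on `static_argnums=(2, 3)` (an aside, not part of the diff): `remat` needs non-array arguments that drive Python-level control flow to be declared static. For the encoder layer these indices presumably correspond to the two boolean flags of its `__call__` after `hidden_states` and `attention_mask`, although that signature is not shown in this hunk. A toy illustration, with names chosen for the example:

```python
# Sketch of why static_argnums matters; the layer and argument names here are
# illustrative, not copied from modeling_flax_whisper.py.
import flax.linen as nn


class ToyLayer(nn.Module):
    @nn.compact
    def __call__(self, hidden_states, attention_mask, output_attentions: bool, deterministic: bool):
        hidden_states = nn.Dense(hidden_states.shape[-1])(hidden_states)
        if output_attentions:  # Python `if` on a static argument is fine under remat
            return hidden_states, attention_mask
        return (hidden_states,)


# Positional args 2 and 3 (the two booleans) are marked static, mirroring
# remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3)) in the hunk above.
CheckpointedToyLayer = nn.remat(ToyLayer, static_argnums=(2, 3))
```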
@@ -535,12 +546,20 @@ def __call__(
 class FlaxWhisperDecoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.decoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
+            self.layers = [
+                FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
         self.layerdrop = self.config.decoder_layerdrop
 
     def __call__(
@@ -605,6 +624,7 @@ def __call__(
 class FlaxWhisperEncoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.conv1 = nn.Conv(
@@ -628,6 +648,7 @@ def setup(self) -> None:
         self.layers = FlaxWhisperEncoderLayerCollection(
             self.config,
             dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
         )
         self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
 
@@ -689,12 +710,13 @@ def __call__(
 class FlaxWhisperDecoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
         self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
 
-        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
+        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
 
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
 
@@ -753,10 +775,11 @@ def __call__(
 class FlaxWhisperModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
-        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
+        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
 
     def __call__(
         self,
@@ -821,11 +844,19 @@ def __init__(
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
+        gradient_checkpointing: bool = False,
         **kwargs,
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_features = jnp.zeros(input_shape, dtype="f4")
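
Since the point of checkpointing is memory during backpropagation, here is a hedged sketch (again not part of the commit) of where the saving materializes: taking gradients through the model. `model` and its `params` are assumed to come from the usage sketch near the top, the batch tensors are placeholders, and the exact `__call__` keywords are assumptions about the Flax Whisper API rather than something shown in this diff.

```python
# Hedged training-loss sketch: with gradient_checkpointing=True, each wrapped
# encoder/decoder layer recomputes its activations inside the backward pass of
# jax.value_and_grad instead of keeping them alive for the whole forward pass.
import jax
import jax.numpy as jnp


def loss_fn(params, input_features, decoder_input_ids, labels):
    logits = model(
        input_features,
        decoder_input_ids=decoder_input_ids,
        params=params,
        train=True,
        dropout_rng=jax.random.PRNGKey(0),  # fixed RNG only for the sketch
    ).logits
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # mean negative log-likelihood of the target tokens
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
    return nll.mean()


grad_fn = jax.value_and_grad(loss_fn)
# loss, grads = grad_fn(params, input_features, decoder_input_ids, labels)
```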
@@ -1137,9 +1168,10 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
 class FlaxWhisperForConditionalGenerationModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
+        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
         self.lm_head = nn.Dense(
             self.config.vocab_size,
             use_bias=False,
