@@ -23,6 +23,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
@@ -53,6 +54,8 @@
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
 _CONFIG_FOR_DOC = "WhisperConfig"
 
+remat = nn_partitioning.remat
+
 
 WHISPER_START_DOCSTRING = r"""
     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
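For context (not part of the diff): `flax.linen.partitioning.remat` is Flax's gradient-checkpointing transform. Wrapping a module with it makes the backward pass recompute that module's intermediate activations instead of keeping them alive for the whole forward pass, trading extra compute for lower peak memory. A minimal sketch on a hypothetical two-layer MLP (the `MLP` module and sizes are illustrative, not from this file):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning


class MLP(nn.Module):
    """Hypothetical two-layer block; stands in for a Whisper encoder layer."""

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(512)(x))
        return nn.Dense(512)(x)


# Same interface as MLP, but the backward pass recomputes its activations.
RematMLP = nn_partitioning.remat(MLP)

model = RematMLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 512)))
loss = lambda p: model.apply(p, jnp.ones((1, 512))).sum()
grads = jax.grad(loss)(params)  # intermediates are rematerialized here, not stored
```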
@@ -391,12 +394,20 @@ def __call__(
 class FlaxWhisperEncoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.encoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
+            self.layers = [
+                FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.encoder_layers)
+            ]
         self.layerdrop = self.config.encoder_layerdrop
 
     def __call__(
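A note on `static_argnums=(2, 3)`: counting the encoder layer's `__call__` arguments after `self` (0 = `hidden_states`, 1 = `attention_mask`), positions 2 and 3 are `output_attentions` and `deterministic`. Both are plain Python bools that drive control flow inside the layer, so `remat` must treat them as static values rather than trace them as JAX arrays.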
@@ -535,12 +546,20 @@ def __call__(
 class FlaxWhisperDecoderLayerCollection(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
-        self.layers = [
-            FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.decoder_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
+            self.layers = [
+                FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.decoder_layers)
+            ]
         self.layerdrop = self.config.decoder_layerdrop
 
     def __call__(
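Likewise for the decoder layer: its `__call__` takes `(hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, init_cache, output_attentions, deterministic)` after `self`, so `static_argnums=(4, 5, 6)` marks the three boolean flags (`init_cache`, `output_attentions`, `deterministic`) as static.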
@@ -605,6 +624,7 @@ def __call__(
 class FlaxWhisperEncoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.conv1 = nn.Conv(
@@ -628,6 +648,7 @@ def setup(self) -> None:
         self.layers = FlaxWhisperEncoderLayerCollection(
             self.config,
             dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
         )
         self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
@@ -689,12 +710,13 @@ def __call__(
 class FlaxWhisperDecoder(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
         self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
         self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
 
-        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
+        self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
 
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -753,10 +775,11 @@ def __call__(
 class FlaxWhisperModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
-        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
+        self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
 
     def __call__(
         self,
@@ -821,11 +844,19 @@ def __init__(
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
+        gradient_checkpointing: bool = False,
         **kwargs,
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_features = jnp.zeros(input_shape, dtype="f4")
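With these hooks in place, checkpointing can be requested at construction time via the new `gradient_checkpointing` flag, or switched on after loading pretrained weights. A usage sketch (model name as in `_CHECKPOINT_FOR_DOC`):

```python
from transformers import FlaxWhisperForConditionalGeneration

model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.enable_gradient_checkpointing()  # rebuilds the module with remat-wrapped layer stacks
```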
@@ -1137,9 +1168,10 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
 class FlaxWhisperForConditionalGenerationModule(nn.Module):
     config: WhisperConfig
     dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
 
     def setup(self) -> None:
-        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
+        self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
         self.lm_head = nn.Dense(
             self.config.vocab_size,
             use_bias=False,
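Where this pays off is the backward pass of a training step: with checkpointing on, each encoder/decoder layer's activations are recomputed under `jax.grad` rather than held in memory. A hedged sketch continuing from the example above; the `batch` dict, the loss choice, and the use of optax are assumptions, not part of this diff:

```python
import jax
import optax  # assumed loss helper; any cross-entropy over the logits works


def loss_fn(params, batch, dropout_rng):
    logits = model(
        batch["input_features"],
        batch["decoder_input_ids"],
        params=params,
        dropout_rng=dropout_rng,
        train=True,
    ).logits
    return optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()


# Peak memory during this call is what gradient checkpointing reduces.
grads = jax.grad(loss_fn)(model.params, batch, jax.random.PRNGKey(0))
```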