@@ -565,22 +565,14 @@ class RNN(Module):
565
565
>>> y.shape # (batch, time, cell_size)
566
566
(10, 50, 64)
567
567
568
- To support variable length sequences, you can pass a ``segmentation_mask`` which is an integer
569
- array of shape ``(*batch, time)``, where a 1 indicates the element is part of the sequence and a 0 indicates
570
- a padding element. Sequences must be padded to the right, i.e. all elements of a sequence must be
571
- contiguous and padded elements must be to the right of the sequence. For example::
572
-
573
- >>> # 3 sequences with max length 5
574
- >>> segmentation_mask = jnp.array([
575
- ... [1, 1, 1, 0, 0], # length 3
576
- ... [1, 1, 0, 0, 0], # length 2
577
- ... [1, 1, 1, 1, 1], # length 5
578
- ... ])
579
-
580
- We use this integer mask format because its compatible with sequence packing which might get
581
- implemented in the future. The output elements corresponding to padding elements are NOT
582
- zeroed out. If ``return_carry`` is set to ``True`` the carry will be the state of the last
583
- valid element of each sequence.
568
+ To support variable length sequences, you can pass ``seq_lengths``, which is an integer
569
+ array of shape ``(*batch)`` where each element is the length of the corresponding sequence.
570
+ For example::
571
+
572
+ >>> seq_lengths = jnp.array([3, 2, 5])
573
+
574
+ The output elements corresponding to padding elements are NOT zeroed out. If ``return_carry``
575
+ is set to ``True`` the carry will be the state of the last valid element of each sequence.
584
576
585
577
RNN also accepts some of the arguments of :func:`flax.linen.scan`, by default they are set to
586
578
work with cells like :class:`LSTMCell` and :class:`GRUCell` but they can be overriden as needed.
@@ -601,7 +593,7 @@ class RNN(Module):
601
593
else it will return a tuple of the final carry and the output sequence.
602
594
reverse: if ``reverse=False`` (default) the sequence is processed from left to right and
603
595
returned in the original order, else it will be processed from right to left, and
604
- returned in reverse order. If ``segmentation_mask `` is passed, padding will always remain
596
+ returned in reverse order. If ``seq_lengths`` is passed, padding will always remain
605
597
at the end of the sequence.
606
598
keep_order: if ``keep_order=True``, when ``reverse=True``
607
599
the output will be reversed back to the original order after processing, this is
@@ -641,7 +633,7 @@ def __call__(
641
633
* ,
642
634
initial_carry : Optional [Carry ] = None ,
643
635
init_key : Optional [random .KeyArray ] = None ,
644
- segmentation_mask : Optional [Array ] = None ,
636
+ seq_lengths : Optional [Array ] = None ,
645
637
return_carry : Optional [bool ] = None ,
646
638
time_major : Optional [bool ] = None ,
647
639
reverse : Optional [bool ] = None ,
@@ -660,15 +652,17 @@ def __call__(
660
652
init_key: a PRNG key used to initialize the carry, if not provided
661
653
``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this
662
654
argument.
663
- segmentation_mask: an integer array of shape ``(*batch, time)`` indicating
664
- which elements are part of the sequence and which are padding elements.
655
+ seq_lengths: an optional integer array of shape ``(*batch)`` indicating
656
+ the length of each sequence, elements whose index in the time dimension
657
+ is greater than or equal to the corresponding length will be considered padding and
658
+ will be ignored.
665
659
return_carry: if ``return_carry=False`` (default) only the output sequence is returned,
666
660
else it will return a tuple of the final carry and the output sequence.
667
661
time_major: if ``time_major=False`` (default) it will expect inputs with shape
668
662
``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``.
669
663
reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence is
670
664
processed from left to right and returned in the original order, else it will be processed
671
- from right to left, and returned in reverse order. If ``segmentation_mask `` is passed,
665
+ from right to left, and returned in reverse order. If ``seq_lengths`` is passed,
672
666
padding will always remain at the end of the sequence.
673
667
keep_order: overrides the ``keep_order`` attribute, if ``keep_order=True``, when ``reverse=True``
674
668
the output will be reversed back to the original order after processing, this is
@@ -701,7 +695,7 @@ def __call__(
701
695
if reverse :
702
696
inputs = jax .tree_map (
703
697
lambda x : flip_sequences (
704
- x , segmentation_mask , num_batch_dims = len (batch_dims ), time_major = time_major ), # type: ignore
698
+ x , seq_lengths , num_batch_dims = len (batch_dims ), time_major = time_major ), # type: ignore
705
699
inputs )
706
700
707
701
carry : Carry
@@ -721,15 +715,15 @@ def scan_fn(
721
715
# so that we can select the last carry for each sequence later.
722
716
# This uses more memory but is faster than using jnp.where at each
723
717
# iteration. As a small optimization do this when we really need it.
724
- if segmentation_mask is not None and return_carry :
718
+ if seq_lengths is not None and return_carry :
725
719
return carry , (carry , y )
726
720
else :
727
721
return carry , y
728
722
729
723
scan = transforms .scan (
730
724
scan_fn ,
731
725
in_axes = time_axis ,
732
- out_axes = time_axis if segmentation_mask is None else (0 , time_axis ),
726
+ out_axes = time_axis if seq_lengths is None else (0 , time_axis ),
733
727
unroll = self .unroll ,
734
728
variable_axes = self .variable_axes ,
735
729
variable_broadcast = self .variable_broadcast ,
@@ -742,43 +736,39 @@ def scan_fn(
742
736
# Next we select the final carry. If ``seq_lengths`` was provided and
743
737
# return_carry is True we slice the carry history and select the last valid
744
738
# carry for each sequence. Otherwise we just use the last carry.
745
- if segmentation_mask is not None and return_carry :
739
+ if seq_lengths is not None and return_carry :
746
740
_ , (carries , outputs ) = scan_output
747
- # segmentation_mask [None] expands the shape of the mask to match the
741
+ # _select_last_carry gathers the carry at step ``seq_lengths - 1``, regardless of the
748
742
# number of dimensions of the carry.
749
- carry = _select_last (carries , segmentation_mask [ None ], axis = 0 )
743
+ carry = _select_last_carry (carries , seq_lengths )
750
744
else :
751
745
carry , outputs = scan_output
752
746
753
747
if reverse and keep_order :
754
748
outputs = jax .tree_map (
755
749
lambda x : flip_sequences (
756
- x , segmentation_mask , num_batch_dims = len (batch_dims ), time_major = time_major ), # type: ignore
750
+ x , seq_lengths , num_batch_dims = len (batch_dims ), time_major = time_major ), # type: ignore
757
751
outputs )
758
752
759
753
if return_carry :
760
754
return carry , outputs
761
755
else :
762
756
return outputs
763
757
764
- def _select_last (sequence : A , segmentation_mask : jnp .ndarray , axis : int ) -> A :
765
- last_idx = segmentation_mask . sum ( axis = - 1 ) - 1
758
+ def _select_last_carry (sequence : A , seq_lengths : jnp .ndarray ) -> A :
759
+ last_idx = seq_lengths - 1
766
760
767
761
def _slice_array (x : jnp .ndarray ):
768
- _last_idx = _expand_dims_like (last_idx , target = x )
769
- x = jnp .take_along_axis (x , _last_idx , axis = axis )
770
- return x .squeeze (axis = axis )
762
+ return x [last_idx , jnp .arange (x .shape [1 ])]
771
763
772
764
return jax .tree_map (_slice_array , sequence )
773
765
774
766
def _expand_dims_like (x , target ):
775
767
"""Expands the shape of `x` to match `target`'s shape by adding singleton dimensions."""
776
768
return x .reshape (list (x .shape ) + [1 ] * (target .ndim - x .ndim ))
777
769
778
- # TODO: Make flip_sequences a method of RNN and generalize it to work with
779
- # multiple batch dimensions.
780
770
def flip_sequences (
781
- inputs : Array , segmentation_mask : Optional [Array ], num_batch_dims : int , time_major : bool
771
+ inputs : Array , seq_lengths : Optional [Array ], num_batch_dims : int , time_major : bool
782
772
) -> Array :
783
773
"""Flips a sequence of inputs along the time axis.
784
774
@@ -810,19 +800,20 @@ def flip_sequences(
810
800
time_axis = 0 if time_major else num_batch_dims
811
801
max_steps = inputs .shape [time_axis ]
812
802
813
- if segmentation_mask is None :
803
+ if seq_lengths is None :
814
804
# reverse inputs and return
815
805
inputs = jnp .flip (inputs , axis = time_axis )
816
806
return inputs
817
807
818
- lengths = jnp .sum (segmentation_mask , axis = time_axis , keepdims = True ) # [*batch, 1]
808
+ seq_lengths = jnp .expand_dims (seq_lengths , axis = time_axis )
809
+
819
810
# create indexes
820
811
idxs = jnp .arange (max_steps - 1 , - 1 , - 1 ) # [max_steps]
821
812
if time_major :
822
813
idxs = jnp .reshape (idxs , [max_steps ] + [1 ] * num_batch_dims )
823
814
else :
824
815
idxs = jnp .reshape (idxs , [1 ] * num_batch_dims + [max_steps ]) # [1, ..., max_steps]
825
- idxs = (idxs + lengths ) % max_steps # [*batch, max_steps]
816
+ idxs = (idxs + seq_lengths ) % max_steps # [*batch, max_steps]
826
817
idxs = _expand_dims_like (idxs , target = inputs ) # [*batch, max_steps, *features]
827
818
# Select the inputs in flipped order.
828
819
outputs = jnp .take_along_axis (inputs , idxs , axis = time_axis )
@@ -841,7 +832,7 @@ def __call__(
841
832
* ,
842
833
initial_carry : Optional [Carry ] = None ,
843
834
init_key : Optional [random .KeyArray ] = None ,
844
- segmentation_mask : Optional [Array ] = None ,
835
+ seq_lengths : Optional [Array ] = None ,
845
836
return_carry : Optional [bool ] = None ,
846
837
time_major : Optional [bool ] = None ,
847
838
reverse : Optional [bool ] = None ,
@@ -863,7 +854,7 @@ def __call__(
863
854
* ,
864
855
initial_carry : Optional [Carry ] = None ,
865
856
init_key : Optional [random .KeyArray ] = None ,
866
- segmentation_mask : Optional [Array ] = None ,
857
+ seq_lengths : Optional [Array ] = None ,
867
858
return_carry : Optional [bool ] = None ,
868
859
time_major : Optional [bool ] = None ,
869
860
reverse : Optional [bool ] = None ,
@@ -890,12 +881,12 @@ def __call__(
890
881
# Encode in the forward direction.
891
882
carry_forward , outputs_forward = self .forward_rnn (
892
883
inputs , initial_carry = initial_carry_forward , init_key = key_forward ,
893
- segmentation_mask = segmentation_mask , return_carry = True ,
884
+ seq_lengths = seq_lengths , return_carry = True ,
894
885
time_major = time_major , reverse = False )
895
886
896
887
carry_backward , outputs_backward = self .backward_rnn (
897
888
inputs , initial_carry = initial_carry_backward , init_key = key_backward ,
898
- segmentation_mask = segmentation_mask , return_carry = True ,
889
+ seq_lengths = seq_lengths , return_carry = True ,
899
890
time_major = time_major , reverse = True , keep_order = True )
900
891
901
892
carry = (carry_forward , carry_backward )
0 commit comments