Skip to content

Commit e9edd2a

Browse files
author
Flax Authors
committed
Merge pull request #3034 from google:use-seq-lenghts
PiperOrigin-RevId: 529462679
2 parents 7457963 + aa3774c commit e9edd2a

File tree

3 files changed

+59
-79
lines changed

3 files changed

+59
-79
lines changed

examples/seq2seq/models.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,24 @@
2727

2828
Array = jax.Array
2929
PRNGKey = jax.random.KeyArray
30+
LSTMCarry = Tuple[Array, Array]
3031

3132

3233
class DecoderLSTMCell(nn.RNNCellBase):
3334
"""DecoderLSTM Module wrapped in a lifted scan transform.
35+
3436
Attributes:
3537
teacher_force: See docstring on Seq2seq module.
3638
vocab_size: Size of the vocabulary.
3739
"""
40+
3841
teacher_force: bool
3942
vocab_size: int
4043

4144
@nn.compact
42-
def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
45+
def __call__(
46+
self, carry: Tuple[LSTMCarry, Array], x: Array
47+
) -> Tuple[Tuple[LSTMCarry, Array], Tuple[Array, Array]]:
4348
"""Applies the DecoderLSTM model."""
4449
lstm_state, last_prediction = carry
4550
if not self.teacher_force:
@@ -100,21 +105,20 @@ def __call__(self, encoder_inputs: Array,
100105
decoder = nn.RNN(DecoderLSTMCell(self.teacher_force, self.vocab_size), decoder_inputs.shape[-1],
101106
split_rngs={'params': False, 'lstm': True}, name='decoder')
102107

103-
segmentation_mask = self.get_segmentation_mask(encoder_inputs)
108+
seq_lengths = self.get_seq_lengths(encoder_inputs)
104109

105-
encoder_state, _ = encoder(encoder_inputs, segmentation_mask=segmentation_mask)
110+
encoder_state, _ = encoder(encoder_inputs, seq_lengths=seq_lengths)
106111
logits, predictions = decoder(decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0]))
107112

108113
return logits, predictions
109114

110-
def get_segmentation_mask(self, inputs: Array) -> Array:
115+
def get_seq_lengths(self, inputs: Array) -> Array:
111116
"""Get segmentation mask for inputs."""
112117
# undo one-hot encoding
113118
inputs = jnp.argmax(inputs, axis=-1)
114-
# calculate eos index
115-
eos_idx = jnp.argmax(inputs == self.eos_id, axis=-1, keepdims=True)
116-
# create index array
117-
indexes = jnp.arange(inputs.shape[1])
118-
indexes = jnp.broadcast_to(indexes, inputs.shape[:2])
119-
# return mask
120-
return indexes < eos_idx
119+
# calculate sequence lengths
120+
seq_lengths = jnp.argmax(inputs == self.eos_id, axis=-1)
121+
122+
return seq_lengths
123+
124+

flax/linen/recurrent.py

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -565,22 +565,14 @@ class RNN(Module):
565565
>>> y.shape # (batch, time, cell_size)
566566
(10, 50, 64)
567567
568-
To support variable length sequences, you can pass a ``segmentation_mask`` which is an integer
569-
array of shape ``(*batch, time)``, where a 1 indicates the element is part of the sequence and a 0 indicates
570-
a padding element. Sequences must be padded to the right, i.e. all elements of a sequence must be
571-
contiguous and padded elements must be to the right of the sequence. For example::
572-
573-
>>> # 3 sequences with max length 5
574-
>>> segmentation_mask = jnp.array([
575-
... [1, 1, 1, 0, 0], # length 3
576-
... [1, 1, 0, 0, 0], # length 2
577-
... [1, 1, 1, 1, 1], # length 5
578-
... ])
579-
580-
We use this integer mask format because its compatible with sequence packing which might get
581-
implemented in the future. The output elements corresponding to padding elements are NOT
582-
zeroed out. If ``return_carry`` is set to ``True`` the carry will be the state of the last
583-
valid element of each sequence.
568+
To support variable length sequences, you can pass a ``seq_lengths`` which is an integer
569+
array of shape ``(*batch)`` where each element is the length of the sequence in the batch.
570+
For example::
571+
572+
>>> seq_lengths = jnp.array([3, 2, 5])
573+
574+
The output elements corresponding to padding elements are NOT zeroed out. If ``return_carry``
575+
is set to ``True`` the carry will be the state of the last valid element of each sequence.
584576
585577
RNN also accepts some of the arguments of :func:`flax.linen.scan`, by default they are set to
586578
work with cells like :class:`LSTMCell` and :class:`GRUCell` but they can be overriden as needed.
@@ -601,7 +593,7 @@ class RNN(Module):
601593
else it will return a tuple of the final carry and the output sequence.
602594
reverse: if ``reverse=False`` (default) the sequence is processed from left to right and
603595
returned in the original order, else it will be processed from right to left, and
604-
returned in reverse order. If ``segmentation_mask`` is passed, padding will always remain
596+
returned in reverse order. If ``seq_lengths`` is passed, padding will always remain
605597
at the end of the sequence.
606598
keep_order: if ``keep_order=True``, when ``reverse=True``
607599
the output will be reversed back to the original order after processing, this is
@@ -641,7 +633,7 @@ def __call__(
641633
*,
642634
initial_carry: Optional[Carry] = None,
643635
init_key: Optional[random.KeyArray] = None,
644-
segmentation_mask: Optional[Array] = None,
636+
seq_lengths: Optional[Array] = None,
645637
return_carry: Optional[bool] = None,
646638
time_major: Optional[bool] = None,
647639
reverse: Optional[bool] = None,
@@ -660,15 +652,17 @@ def __call__(
660652
init_key: a PRNG key used to initialize the carry, if not provided
661653
``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this
662654
argument.
663-
segmentation_mask: an integer array of shape ``(*batch, time)`` indicating
664-
which elements are part of the sequence and which are padding elements.
655+
seq_lengths: an optional integer array of shape ``(*batch)`` indicating
656+
the length of each sequence, elements whose index in the time dimension
657+
is greater than the corresponding length will be considered padding and
658+
will be ignored.
665659
return_carry: if ``return_carry=False`` (default) only the output sequence is returned,
666660
else it will return a tuple of the final carry and the output sequence.
667661
time_major: if ``time_major=False`` (default) it will expect inputs with shape
668662
``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``.
669663
reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence is
670664
processed from left to right and returned in the original order, else it will be processed
671-
from right to left, and returned in reverse order. If ``segmentation_mask`` is passed,
665+
from right to left, and returned in reverse order. If ``seq_lengths`` is passed,
672666
padding will always remain at the end of the sequence.
673667
keep_order: overrides the ``keep_order`` attribute, if ``keep_order=True``, when ``reverse=True``
674668
the output will be reversed back to the original order after processing, this is
@@ -701,7 +695,7 @@ def __call__(
701695
if reverse:
702696
inputs = jax.tree_map(
703697
lambda x: flip_sequences(
704-
x, segmentation_mask, num_batch_dims=len(batch_dims), time_major=time_major), # type: ignore
698+
x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major), # type: ignore
705699
inputs)
706700

707701
carry: Carry
@@ -721,15 +715,15 @@ def scan_fn(
721715
# so that we can select the last carry for each sequence later.
722716
# This uses more memory but is faster than using jnp.where at each
723717
# iteration. As a small optimization do this when we really need it.
724-
if segmentation_mask is not None and return_carry:
718+
if seq_lengths is not None and return_carry:
725719
return carry, (carry, y)
726720
else:
727721
return carry, y
728722

729723
scan = transforms.scan(
730724
scan_fn,
731725
in_axes=time_axis,
732-
out_axes=time_axis if segmentation_mask is None else (0, time_axis),
726+
out_axes=time_axis if seq_lengths is None else (0, time_axis),
733727
unroll=self.unroll,
734728
variable_axes=self.variable_axes,
735729
variable_broadcast=self.variable_broadcast,
@@ -742,43 +736,39 @@ def scan_fn(
742736
# Next we select the final carry. If a segmentation mask was provided and
743737
# return_carry is True we slice the carry history and select the last valid
744738
# carry for each sequence. Otherwise we just use the last carry.
745-
if segmentation_mask is not None and return_carry:
739+
if seq_lengths is not None and return_carry:
746740
_, (carries, outputs) = scan_output
747-
# segmentation_mask[None] expands the shape of the mask to match the
741+
# seq_lengths[None] expands the shape of the mask to match the
748742
# number of dimensions of the carry.
749-
carry = _select_last(carries, segmentation_mask[None], axis=0)
743+
carry = _select_last_carry(carries, seq_lengths)
750744
else:
751745
carry, outputs = scan_output
752746

753747
if reverse and keep_order:
754748
outputs = jax.tree_map(
755749
lambda x: flip_sequences(
756-
x, segmentation_mask, num_batch_dims=len(batch_dims), time_major=time_major), # type: ignore
750+
x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major), # type: ignore
757751
outputs)
758752

759753
if return_carry:
760754
return carry, outputs
761755
else:
762756
return outputs
763757

764-
def _select_last(sequence: A, segmentation_mask: jnp.ndarray, axis: int) -> A:
765-
last_idx = segmentation_mask.sum(axis=-1) - 1
758+
def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A:
759+
last_idx = seq_lengths - 1
766760

767761
def _slice_array(x: jnp.ndarray):
768-
_last_idx = _expand_dims_like(last_idx, target=x)
769-
x = jnp.take_along_axis(x, _last_idx, axis=axis)
770-
return x.squeeze(axis=axis)
762+
return x[last_idx, jnp.arange(x.shape[1])]
771763

772764
return jax.tree_map(_slice_array, sequence)
773765

774766
def _expand_dims_like(x, target):
775767
"""Expands the shape of `x` to match `target`'s shape by adding singleton dimensions."""
776768
return x.reshape(list(x.shape) + [1] * (target.ndim - x.ndim))
777769

778-
# TODO: Make flip_sequences a method of RNN and generalize it to work with
779-
# multiple batch dimensions.
780770
def flip_sequences(
781-
inputs: Array, segmentation_mask: Optional[Array], num_batch_dims: int, time_major: bool
771+
inputs: Array, seq_lengths: Optional[Array], num_batch_dims: int, time_major: bool
782772
) -> Array:
783773
"""Flips a sequence of inputs along the time axis.
784774
@@ -810,19 +800,20 @@ def flip_sequences(
810800
time_axis = 0 if time_major else num_batch_dims
811801
max_steps = inputs.shape[time_axis]
812802

813-
if segmentation_mask is None:
803+
if seq_lengths is None:
814804
# reverse inputs and return
815805
inputs = jnp.flip(inputs, axis=time_axis)
816806
return inputs
817807

818-
lengths = jnp.sum(segmentation_mask, axis=time_axis, keepdims=True) # [*batch, 1]
808+
seq_lengths = jnp.expand_dims(seq_lengths, axis=time_axis)
809+
819810
# create indexes
820811
idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps]
821812
if time_major:
822813
idxs = jnp.reshape(idxs, [max_steps] + [1] * num_batch_dims)
823814
else:
824815
idxs = jnp.reshape(idxs, [1] * num_batch_dims + [max_steps]) # [1, ..., max_steps]
825-
idxs = (idxs + lengths) % max_steps # [*batch, max_steps]
816+
idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps]
826817
idxs = _expand_dims_like(idxs, target=inputs) # [*batch, max_steps, *features]
827818
# Select the inputs in flipped order.
828819
outputs = jnp.take_along_axis(inputs, idxs, axis=time_axis)
@@ -841,7 +832,7 @@ def __call__(
841832
*,
842833
initial_carry: Optional[Carry] = None,
843834
init_key: Optional[random.KeyArray] = None,
844-
segmentation_mask: Optional[Array] = None,
835+
seq_lengths: Optional[Array] = None,
845836
return_carry: Optional[bool] = None,
846837
time_major: Optional[bool] = None,
847838
reverse: Optional[bool] = None,
@@ -863,7 +854,7 @@ def __call__(
863854
*,
864855
initial_carry: Optional[Carry] = None,
865856
init_key: Optional[random.KeyArray] = None,
866-
segmentation_mask: Optional[Array] = None,
857+
seq_lengths: Optional[Array] = None,
867858
return_carry: Optional[bool] = None,
868859
time_major: Optional[bool] = None,
869860
reverse: Optional[bool] = None,
@@ -890,12 +881,12 @@ def __call__(
890881
# Encode in the forward direction.
891882
carry_forward, outputs_forward = self.forward_rnn(
892883
inputs, initial_carry=initial_carry_forward, init_key=key_forward,
893-
segmentation_mask=segmentation_mask, return_carry=True,
884+
seq_lengths=seq_lengths, return_carry=True,
894885
time_major=time_major, reverse=False)
895886

896887
carry_backward, outputs_backward = self.backward_rnn(
897888
inputs, initial_carry=initial_carry_backward, init_key=key_backward,
898-
segmentation_mask=segmentation_mask, return_carry=True,
889+
seq_lengths=seq_lengths, return_carry=True,
899890
time_major=time_major, reverse=True, keep_order=True)
900891

901892
carry = (carry_forward, carry_backward)

tests/linen/linen_recurrent_test.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,12 @@ def test_numerical_equivalence_with_mask(self):
182182

183183
key = jax.random.PRNGKey(0)
184184
seq_lengths = jax.random.randint(key, (batch_size,), minval=1, maxval=seq_len + 1)
185-
segmentation_mask = einops.repeat(
186-
jnp.arange(seq_len), 'time -> batch time', batch=batch_size)
187-
segmentation_mask = (segmentation_mask < seq_lengths[:, None]).astype(jnp.int32)
188185

189186
rnn = nn.RNN(nn.LSTMCell(), channels_out, return_carry=True)
190187

191188
xs = jnp.ones((batch_size, seq_len, channels_in))
192189
ys: jnp.ndarray
193-
(carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs, segmentation_mask=segmentation_mask)
190+
(carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs, seq_lengths=seq_lengths)
194191

195192
cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), channels_out)
196193
cell_params = variables['params']['cell']
@@ -335,51 +332,39 @@ def test_reverse_but_keep_order(self):
335332

336333
def test_flip_sequence(self):
337334
x = jnp.arange(2 * 5).reshape((2, 5))
338-
segmentation_mask = jnp.array([[1, 1, 1, 1, 0], [1, 1, 0, 0, 0]])
335+
seq_lengths = jnp.array([4, 2])
339336

340-
flipped = flip_sequences(x, segmentation_mask, num_batch_dims=1, time_major=False)
337+
flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=False)
341338

342339
self.assertEqual(flipped.shape, (2, 5))
343340
np.testing.assert_allclose(flipped[0, :4], [3, 2, 1, 0])
344341
np.testing.assert_allclose(flipped[1, :2], [6, 5])
345342

346343
def test_flip_sequence_more_feature_dims(self):
347344
x = jnp.arange(2 * 5 * 3).reshape((2, 5, 3))
348-
segmentation_mask = jnp.array([[1, 1, 1, 1, 0], [1, 1, 0, 0, 0]])
345+
seq_lengths = jnp.array([4, 2])
349346

350-
flipped = flip_sequences(x, segmentation_mask, num_batch_dims=1, time_major=False)
347+
flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=False)
351348

352349
self.assertEqual(flipped.shape, (2, 5, 3))
353350
np.testing.assert_allclose(flipped[0, :4], x[0, :4][::-1])
354351
np.testing.assert_allclose(flipped[1, :2], x[1, :2][::-1])
355352

356353
def test_flip_sequence_time_major(self):
357354
x = jnp.arange(2 * 5).reshape((5, 2))
358-
segmentation_mask = jnp.array([
359-
[1, 1],
360-
[1, 1],
361-
[1, 0],
362-
[1, 0],
363-
[0, 0],
364-
])
355+
seq_lengths = jnp.array([4, 2])
365356

366-
flipped = flip_sequences(x, segmentation_mask, num_batch_dims=1, time_major=True)
357+
flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=True)
367358

368359
self.assertEqual(flipped.shape, (5, 2))
369360
np.testing.assert_allclose(flipped[:4, 0], x[:4, 0][::-1])
370361
np.testing.assert_allclose(flipped[:2, 1], x[:2, 1][::-1])
371362

372363
def test_flip_sequence_time_major_more_feature_dims(self):
373364
x = jnp.arange(2 * 5 * 3).reshape((5, 2, 3))
374-
segmentation_mask = jnp.array([
375-
[1, 1],
376-
[1, 1],
377-
[1, 0],
378-
[1, 0],
379-
[0, 0],
380-
])
381-
382-
flipped = flip_sequences(x, segmentation_mask, num_batch_dims=1, time_major=True)
365+
seq_lengths = jnp.array([4, 2])
366+
367+
flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=True)
383368

384369
self.assertEqual(flipped.shape, (5, 2, 3))
385370
np.testing.assert_allclose(flipped[:4, 0], x[:4, 0][::-1])

0 commit comments

Comments
 (0)