# RNN Flip

- Start Date: 2022-08-18
- FLIP PR: [#2604](https://github.com/google/flax/pull/2604)
- FLIP Issue: [#2396](https://github.com/google/flax/issues/2396)
- Authors: Jasmijn Bastings (@bastings) and Cristian Garcia (@cgarciae)

## Summary
This FLIP adds support for higher-level recurrent layers (RNN, GRU, LSTM) that can help users process input sequences using the recurrent cells already available in Flax.

## Motivation
Implementing well-known recurrent architectures is tricky and prone to user errors. Even a simple LSTM layer involves manually creating and handling the carry/memory and correctly setting up `nn.scan`:

```python
@nn.compact
def __call__(self, x):
  LSTM = nn.scan(
      nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False}
  )
  carry = LSTM.initialize_carry(
      jax.random.PRNGKey(0), batch_dims=x.shape[:1], size=self.hidden_size
  )
  carry, x = LSTM()(carry, x)
  return x
```
Slightly more complicated cases involving padding, like in the [seq2seq](https://github.com/google/flax/blob/main/examples/seq2seq/models.py) example, require even more work but could potentially be simplified to a couple of lines with the right abstractions. We propose providing users with clean, correct, and efficient abstractions for using recurrent cells.

## Requirements

* **Masking**: We need to support a batch of sequences that contain padding at the end of each sequence.
* We do not intend to support non-contiguous padding, i.e. padding that is not at the end of a sequence, for performance reasons, except in the case of packing (see below).
* **Bidirectionality**: The ability to process a sequence in both the forward and reverse directions, respecting padding (i.e., the reverse direction should start with the actual inputs, not with padding values).
* **Performance**: The proposed classes should be benchmarked to provide the best performance in terms of step time and/or memory use.
* **Recurrent Dropout**: Support for recurrent dropout in cells (e.g. dropout on the state of the cell).

## Implementation
### High-level structure

We propose to have these 3 levels of abstraction:

* **Cells (unchanged)**: all `RNNCellBase` subclasses such as `LSTMCell` and `GRUCell`; these implement the stepwise logic and already exist in Flax today.
* **Layers (new)**: a class (RNN) that takes a cell and scans over a sequence respecting possible padding values and optionally also allows packed sequences.
* **Bidirectional (new)**: a single class that takes a forward and a backward RNN instance and correctly processes the input sequence in both directions and merges the results.

### Example of proposed API
We start with a code example of what you could do with the proposed API, and then we discuss the API in detail below.

```python
cell = nn.LSTMCell()
# Encodes a batch of input sequences.
carry, outputs = nn.RNN(cell, cell_size)(inputs, seq_lengths)
```

A `Bidirectional` layer with LSTM RNNs for the forward and backward directions would look like this:

```python
forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)
```

Next we will discuss `RNN`, `Bidirectional`, and proposed changes to `RNNCellBase`.

### RNNBase
The `RNNBase` class serves as a base class for the `RNN` class. It specifies
the API that all RNN layers should implement to be compatible with `Bidirectional`.
`RNNBase` contains the `__call__` and `flip_sequences` methods:

```python
class RNNBase(Protocol):
  def __call__(
      self,
      inputs: jax.Array,
      *,
      initial_carry: Optional[Carry] = None,
      init_key: Optional[random.KeyArray] = None,
      seq_lengths: Optional[Array] = None,
      return_carry: Optional[bool] = None,
      time_major: Optional[bool] = None,
      reverse: Optional[bool] = None,
      keep_order: Optional[bool] = None,
  ) -> Union[Output, Tuple[Carry, Output]]:
    ...
```
Where:

* `inputs`: the input sequence.
* `initial_carry`: the initial carry, if not provided it will be initialized
using the cell's :meth:`RNNCellBase.initialize_carry` method.
* `init_key`: a PRNG key used to initialize the carry, if not provided
``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this
argument.
* `seq_lengths`: an optional integer array of shape ``(*batch)`` indicating
  the length of each sequence. Elements whose index in the time dimension
  is greater than or equal to the corresponding length are considered
  padding and are ignored.
* `return_carry`: if ``return_carry=False`` (default) only the output sequence is returned,
else it will return a tuple of the final carry and the output sequence.
* `time_major`: if ``time_major=False`` (default) it will expect inputs with shape
``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``.
* `reverse`: if ``reverse=False`` (default) the sequence is
processed from left to right and returned in the original order, else it will be processed
from right to left, and returned in reverse order. If ``seq_lengths`` is passed,
padding will always remain at the end of the sequence.
* `keep_order`: if ``keep_order=True``, when ``reverse=True``
  the output will be reversed back to the original order after processing. This is
  useful to align sequences in bidirectional RNNs. If ``keep_order=False`` (default),
  the output will remain in the order specified by ``reverse``.
* `Returns`: if ``return_carry=False`` (default) only the output sequence is returned,
else it will return a tuple of the final carry and the output sequence.
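
To make these options concrete, here is a usage sketch in the same style as the examples above (assumed to run inside a parent module's `__call__`, with `jnp` short for `jax.numpy`; shapes and sizes are made up for illustration):

```python
# inputs: (batch, time, features), since time_major=False by default.
inputs = jnp.ones((4, 16, 8))
seq_lengths = jnp.array([16, 11, 7, 3])  # valid (non-padding) length per example

rnn = nn.RNN(nn.LSTMCell(), cell_size=32)

# Default: only the output sequence is returned.
outputs = rnn(inputs, seq_lengths=seq_lengths)

# Also return the final carry, processing right-to-left but restoring the original order.
carry, outputs = rnn(
    inputs, seq_lengths=seq_lengths,
    return_carry=True, reverse=True, keep_order=True)
```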

### RNN
The `RNN` module inherits from `RNNBase`. Its main function is to apply an `RNNCellBase` instance over a batch of input sequences, and it can be used with any type of cell (e.g., `GRUCell`, `LSTMCell`, etc.). It accepts the following parameters:

```python
class RNN(RNNBase):
  cell: RNNCellBase
  cell_size: int | Tuple[int, ...]
  time_axis: int = -2
  variable_axes = FrozenDict()
  variable_broadcast: CollectionFilter = 'params'
  variable_carry: CollectionFilter = False
  split_rngs = FrozenDict({'params': False})

  # implements the RNNBase API
  ...
```

Attributes like `variable_axes`, `variable_broadcast`, `variable_carry`, and `split_rngs` are passed directly to `nn.scan`; their default values are set such that common cells like `LSTMCell` and `GRUCell` work out of the box.
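
As a rough sketch of the intended wiring (not a final implementation, and omitting masking and carry-initialization details), `RNN.__call__` could forward these attributes to `nn.scan` along these lines:

```python
# Inside RNN.__call__ (sketch): scan a step function over the time axis,
# forwarding the scan-related attributes declared above.
def scan_fn(cell, carry, x):
  return cell(carry, x)

scan = nn.scan(
    scan_fn,
    variable_axes=self.variable_axes,
    variable_broadcast=self.variable_broadcast,
    variable_carry=self.variable_carry,
    split_rngs=self.split_rngs,
    in_axes=self.time_axis,
    out_axes=self.time_axis,
)
carry, outputs = scan(self.cell, initial_carry, inputs)
```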

### Masking
`seq_lengths` is defined as an integer array of shape `(*batch,)` indicating the length of each sequence.

<details><summary>Discussion</summary>

There are various masking formats found in other frameworks, here are some of the most popular ones:

* **Binary masking**: specifies per sample and timestep whether that data point should be included in the computation; it can be non-contiguous (e.g., [1, 1, 0, 1]). This is used by Keras.
* **Sequence length masking**: specifies per sample the number of non-padding timesteps contained in the sequence; any padding should be stacked at the end of the sequence. This is used by FlaxFormer.
* **Segmentation mask**: specifies per row and timestep which sample the data point belongs to; this format allows more than one sample per row, which potentially reduces the total amount of padding needed (e.g., [1, 1, 1, 2, 2, 0, 0]). PyTorch uses this representation (see [pack_padded_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html)).

While sequence packing (see the [LM1B example](https://github.com/google/flax/blob/main/examples/lm1b/input_pipeline.py#L90-L92)) is more powerful, its implementation is more complex and it is not clear whether it is worth the effort. The simplest format is sequence length masking, which is the one we propose to use.

</details>
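
For illustration, an implementation can always recover a binary mask from `seq_lengths` when one is needed; a minimal sketch (the helper name is hypothetical, not part of the proposed API):

```python
import jax.numpy as jnp

def length_to_mask(seq_lengths, max_length):
  # seq_lengths: int array of shape (*batch,); returns a bool mask of shape
  # (*batch, max_length) that is True for valid timesteps and False for padding.
  return jnp.arange(max_length) < seq_lengths[..., None]

length_to_mask(jnp.array([3, 1]), max_length=4)
# [[ True,  True,  True, False],
#  [ True, False, False, False]]
```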

### Bidirectional
Bidirectional processing can be achieved via a Module that accepts a `forward_rnn` Module and a `backward_rnn` Module, both of which should be `RNN` instances, and processes the input sequence in both directions. Here is some pseudocode of the implementation:

```python
def __call__(self, inputs, seq_lengths):
  # Encode in the forward direction.
  carry_forward, outputs_forward = self.forward_rnn(
      inputs, seq_lengths=seq_lengths,
      return_carry=True, reverse=False,
  )
  # Encode in the reverse order.
  carry_backward, outputs_backward = self.backward_rnn(
      inputs, seq_lengths=seq_lengths,
      return_carry=True, reverse=True,  # process in reverse order
      keep_order=True,  # but return the sequence in the original order
  )
  # Merge both sequences.
  outputs = jax.tree_map(self.merge_fn, outputs_forward, outputs_backward)

  return (carry_forward, carry_backward), outputs
```

Here `merge_fn` is a function that takes both outputs and fuses them (`concat` by default). As showcased at the beginning of this document, usage would look like this:

```python
forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)
```
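
For reference, the default `concat` merge could simply concatenate the forward and backward outputs along the feature axis; a minimal sketch (assuming `merge_fn` is exposed as a constructor argument):

```python
import jax.numpy as jnp

def concat_merge(a, b):
  # Concatenate forward and backward outputs along the feature axis.
  return jnp.concatenate([a, b], axis=-1)

bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn, merge_fn=concat_merge)
```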

### Recurrent Dropout
There are two main uses of dropout in RNNs:
1. Input dropout: regular dropout applied to the inputs, different for every step.
2. Recurrent dropout: applies dropout to a recurrent input/output, same for every step.

Flax's `nn.scan` can easily express both types of dropout via `split_rngs`: input dropout splits the rngs across steps while recurrent dropout does not. [#2540](https://github.com/google/flax/pull/2540) was introduced so that the rng collection name used by `nn.Dropout` can be defined by the user; this way cells can define both types of dropout, e.g.:

```python
self.dropout = nn.Dropout(...) # input dropout
self.recurrent_dropout = nn.Dropout(..., rng_collection='recurrent_dropout')
```
Based on this, `nn.scan` / `nn.RNN` can now specify `split_rngs` accordingly, e.g.:
```python
nn.scan(scan_fn, ..., split_rngs={'dropout': True, 'recurrent_dropout': False})
```
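
To illustrate how a cell could combine both kinds of dropout, here is a hypothetical sketch (names and rates are made up, and `deterministic` handling is omitted for brevity):

```python
import flax.linen as nn

class DropoutLSTMCell(nn.Module):
  """Hypothetical cell: an LSTM step with input and recurrent dropout."""
  dropout_rate: float = 0.1
  recurrent_dropout_rate: float = 0.25

  @nn.compact
  def __call__(self, carry, x):
    # Input dropout: uses the 'dropout' rng, which nn.scan splits per step.
    x = nn.Dropout(rate=self.dropout_rate, deterministic=False)(x)
    c, h = carry
    # Recurrent dropout: uses 'recurrent_dropout', which nn.scan broadcasts,
    # so the same mask is applied to the hidden state at every step.
    h = nn.Dropout(rate=self.recurrent_dropout_rate, deterministic=False,
                   rng_collection='recurrent_dropout')(h)
    return nn.LSTMCell()((c, h), x)
```

When applying such a model, both rng streams would be provided, e.g. `rngs={'dropout': key1, 'recurrent_dropout': key2}`.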

# Future ideas

<details><summary>show</summary>

### Sequence Packing
Allow packing multiple sequences to make efficient use of space/memory. This might result in a trade-off where step time is higher (because at each step we need to check whether we are starting a new sequence and reset the carry/initial state), but less padding is used, increasing efficiency overall.
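
As a rough sketch of the extra per-step work (all names hypothetical), the scan body would reset the carry wherever a new packed segment begins:

```python
import jax
import jax.numpy as jnp

def packed_step(cell, initial_carry, carry, x_t, new_segment_t):
  # new_segment_t: bool array of shape (*batch,) marking rows where a new
  # packed sequence starts at this timestep; those rows get a fresh carry.
  carry = jax.tree_map(
      lambda c0, c: jnp.where(new_segment_t[..., None], c0, c),
      initial_carry, carry)
  return cell(carry, x_t)
```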

### RNNCell redesign

#### Make initialize_carry an instance method
The first alternative is to make `initialize_carry` an instance method. With this change hyperparameters can be passed directly to the cell; its signature would look like this:

```python
def initialize_carry(self, sample_input) -> Carry:
  ...
```

Usage would look like this:

```python
LSTM = nn.scan(
    nn.LSTMCell, variable_broadcast='params',
    split_rngs={'dropout': True})
lstm = LSTM(features=32)
carry = lstm.initialize_carry(x[:, 0])
carry, y = lstm(carry, x)
```

#### Remove initialize_carry

An alternative is to remove `initialize_carry` entirely and have the carry state be handled as a carry collection. This would simplify usage quite a bit:

```python
LSTM = nn.scan(
    nn.LSTMCell, variable_broadcast='params',
    split_rngs={'dropout': True})
y = LSTM(features=32)(x)
```

However, this would require `nn.scan` to support initialization of carry collections, which is currently not possible. Also, users would have to specify that the collection is mutable, e.g. `mutable=['carry']`, even if they are not interested in the output carry state.
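
If this were supported, usage might look roughly like this (purely hypothetical, since `nn.scan` cannot initialize carry collections today):

```python
variables = LSTM(features=32).init(jax.random.PRNGKey(0), x)
# The carry lives in a mutable 'carry' collection instead of being threaded by hand.
y, mutated = LSTM(features=32).apply(variables, x, mutable=['carry'])
final_carry = mutated['carry']
```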

</details>