Description
Currently `'dropout'` is hard-coded as an argument of `make_rng` inside `Dropout`. However, when implementing support for "recurrent dropout" in an `LSTMCell` or similar, you need two kinds of dropout:
- A regular dropout which is applied to the inputs with a different mask at each step.
- A "recurrent dropout" that is applied to the state with the same mask at each step.
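The difference between the two kinds can be sketched in plain NumPy (a toy scan loop, not Flax code; `scan_with_dropout` and its signature are illustrative):

```python
import numpy as np

def scan_with_dropout(xs, h, rng, rate=0.5):
    """Toy recurrence illustrating the two dropout kinds.

    - Regular dropout: a fresh mask is sampled for the input at each step.
    - Recurrent dropout: one mask is sampled once and reused for the state
      at every step.
    """
    # Recurrent dropout mask: sampled once, outside the loop.
    recurrent_mask = rng.random(h.shape) >= rate
    for x in xs:
        # Regular dropout mask: sampled anew inside the loop.
        input_mask = rng.random(x.shape) >= rate
        h = np.tanh(x * input_mask + h * recurrent_mask)
    return h
```

Because the recurrent mask lives outside the loop, its random state must not advance per step, which is exactly what the RNG-splitting question below is about.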
To implement the second kind, one possibility is to set the RNG name to `'recurrent_dropout'` on the `Dropout` layer applied to the state and guarantee that each step uses the same random state. From `nn.scan`'s perspective this would look like:

```python
nn.scan(..., split_rngs={'dropout': True, 'recurrent_dropout': False})
```
The proposal is to add an `rng_name` (or similar) attribute to `Dropout` so we are able to support these kinds of use cases. The alternative would be to create a separate `RecurrentDropout` layer with the same code but a different hard-coded value.
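A rough sketch of what the proposed attribute could look like, in plain Python with a dict of NumPy generators standing in for Flax's named RNG streams (this `Dropout`, the `rngs` dict, and the lookup logic are all illustrative assumptions, not the actual Flax implementation):

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class Dropout:
    """Sketch of the proposal: the RNG stream name becomes an attribute
    instead of the hard-coded 'dropout' string."""
    rate: float
    rng_name: str = 'dropout'  # proposed attribute; default keeps today's behavior

    def __call__(self, x, rngs):
        # Stands in for make_rng(self.rng_name): pick the named RNG stream.
        rng = rngs[self.rng_name]
        mask = rng.random(x.shape) >= self.rate
        # Inverted dropout: scale kept units so the expectation is unchanged.
        return np.where(mask, x / (1.0 - self.rate), 0.0)
```

With this, the input dropout and the state dropout are the same class configured differently, e.g. `Dropout(0.5)` for inputs and `Dropout(0.5, rng_name='recurrent_dropout')` for the state, which avoids duplicating the layer as a separate `RecurrentDropout`.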