Make RNG name configurable in Dropout #2194

Closed
@cgarciae

Description

Currently 'dropout' is hard-coded as the argument to make_rng inside Dropout. However, when implementing support for "recurrent dropout" in an LSTMCell or a similar cell, you need two kinds of dropout:

  1. A regular dropout which is applied to the inputs with a different mask at each step.
  2. A "recurrent dropout" that is applied to the state with the same mask at each step.
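
The difference between the two kinds of mask can be sketched in plain JAX, independent of Flax. This is only an illustration under stated assumptions: the helper name make_masks, the shapes, and the rate are hypothetical and not part of any library.

```python
import jax
import jax.numpy as jnp


def make_masks(key, num_steps, shape, rate):
    """Build keep-masks for both dropout kinds (hypothetical helper)."""
    keep_prob = 1.0 - rate
    input_key, state_key = jax.random.split(key)

    # Regular dropout: one independent key, and so one fresh mask, per step.
    step_keys = jax.random.split(input_key, num_steps)
    input_masks = jax.vmap(
        lambda k: jax.random.bernoulli(k, keep_prob, shape)
    )(step_keys)

    # Recurrent dropout: a single mask, reused unchanged at every step.
    state_mask = jax.random.bernoulli(state_key, keep_prob, shape)
    state_masks = jnp.broadcast_to(state_mask, (num_steps, *shape))
    return input_masks, state_masks


input_masks, state_masks = make_masks(
    jax.random.PRNGKey(0), num_steps=4, shape=(256,), rate=0.5
)
```

Here input_masks has a different row for each step, while every row of state_masks is the same mask.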

To implement 2, one possibility is to set the RNG name to 'recurrent_dropout' on the Dropout layer applied to the state and guarantee that each step uses the same random state. From nn.scan's perspective this would look like:

nn.scan(..., split_rngs={'dropout': True, 'recurrent_dropout': False})

The proposal is to add an rng_name (or similar) attribute to Dropout so that we are able to support these kinds of use cases. The alternative would be to create a separate RecurrentDropout layer with the same code but a different hard-coded value.

Metadata

Labels

Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)
