Skip to content

Commit 900015a

Browse files
committed
Adding more documentation to Dropout around rng use
1 parent d31bd1a commit 900015a

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

flax/linen/stochastic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
class Dropout(Module):
2828
"""Create a dropout layer.
2929
30+
Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
31+
to include an RNG seed named `'dropout'`. For example::
32+
33+
model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})`
34+
3035
Attributes:
3136
rate: the dropout probability. (_not_ the keep rate!)
3237
broadcast_dims: dimensions that will share the same dropout mask

0 commit comments

Comments
 (0)