Skip to content

Commit d6fb012

Browse files
author
Flax Authors
committed
Merge pull request #2540 from cgarciae:parametrize-dropout-rng-collection
PiperOrigin-RevId: 481925867
2 parents 376804f + d7b3c12 commit d6fb012

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

flax/linen/stochastic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@ class Dropout(Module):
3838
deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
3939
masked, whereas if true, no mask is applied and the inputs are returned
4040
as is.
41+
rng_collection: the rng collection name to use when requesting an rng key.
4142
"""
4243
rate: float
4344
broadcast_dims: Sequence[int] = ()
4445
deterministic: Optional[bool] = None
46+
rng_collection: str = 'dropout'
4547

4648
@compact
4749
def __call__(self, inputs, deterministic: Optional[bool] = None):
@@ -67,7 +69,7 @@ def __call__(self, inputs, deterministic: Optional[bool] = None):
6769
if deterministic:
6870
return inputs
6971
else:
70-
rng = self.make_rng('dropout')
72+
rng = self.make_rng(self.rng_collection)
7173
broadcast_shape = list(inputs.shape)
7274
for dim in self.broadcast_dims:
7375
broadcast_shape[dim] = 1

0 commit comments

Comments
 (0)