Skip to content

Commit 6b80cbb

Browse files
author
Flax Authors
committed
Merge pull request #2510 from dslisleedh:fix_dropout
PiperOrigin-RevId: 482103869
2 parents c628910 + 98453e5 commit 6b80cbb

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Run the hooks on all files with
88
# 'pre-commit run --all'
99

10+
repos:
1011
- repo: https://github.com/mwouts/jupytext
1112
rev: v1.13.8
1213
hooks:

flax/linen/stochastic.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,19 @@ def __call__(self, inputs, deterministic: Optional[bool] = None):
6060
"""
6161
deterministic = merge_param(
6262
'deterministic', self.deterministic, deterministic)
63-
if self.rate == 0.:
63+
64+
if (self.rate == 0.) or deterministic:
6465
return inputs
66+
6567
# Prevent gradient NaNs in 1.0 edge-case.
6668
if self.rate == 1.0:
6769
return jnp.zeros_like(inputs)
70+
6871
keep_prob = 1. - self.rate
69-
if deterministic:
70-
return inputs
71-
else:
72-
rng = self.make_rng(self.rng_collection)
73-
broadcast_shape = list(inputs.shape)
74-
for dim in self.broadcast_dims:
75-
broadcast_shape[dim] = 1
76-
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
77-
mask = jnp.broadcast_to(mask, inputs.shape)
78-
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
72+
rng = self.make_rng(self.rng_collection)
73+
broadcast_shape = list(inputs.shape)
74+
for dim in self.broadcast_dims:
75+
broadcast_shape[dim] = 1
76+
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
77+
mask = jnp.broadcast_to(mask, inputs.shape)
78+
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

0 commit comments

Comments (0)