Skip to content

Fix flax.linen.stochastic.Dropout #2510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Run the hooks on all files with
# 'pre-commit run --all'

repos:
- repo: https://github.com/mwouts/jupytext
rev: v1.13.8
hooks:
Expand Down
22 changes: 11 additions & 11 deletions flax/linen/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,19 @@ def __call__(self, inputs, deterministic: Optional[bool] = None):
"""
deterministic = merge_param(
'deterministic', self.deterministic, deterministic)
if self.rate == 0.:

if (self.rate == 0.) or deterministic:
return inputs

# Prevent gradient NaNs in 1.0 edge-case.
if self.rate == 1.0:
return jnp.zeros_like(inputs)

keep_prob = 1. - self.rate
if deterministic:
return inputs
else:
rng = self.make_rng(self.rng_collection)
broadcast_shape = list(inputs.shape)
for dim in self.broadcast_dims:
broadcast_shape[dim] = 1
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
rng = self.make_rng(self.rng_collection)
broadcast_shape = list(inputs.shape)
for dim in self.broadcast_dims:
broadcast_shape[dim] = 1
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))