
Fix flax.linen.stochastic.Dropout #2510


Merged: 3 commits into google:main on Oct 19, 2022
Conversation

@dslisleedh (Contributor) commented Oct 6, 2022

What does this PR do?

The current flax.linen.stochastic.Dropout returns zeros_like(inputs) when the drop rate is 1.0, even when the deterministic argument is set to True:

import flax.linen as nn
import jax
import jax.numpy as jnp

class Test(nn.Module):
    
    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        x = nn.Dense(10)(x)
        x = nn.Dropout(rate=1.)(x, deterministic=deterministic)
        return x

test_model = Test()
rng = jax.random.PRNGKey(42)

# Initialize with dropout active (deterministic=False); rate=1.0 drops everything.
outputs, params = test_model.init_with_output(rng, jnp.ones((10, 10)), deterministic=False)

print(jnp.sum(outputs))  # 0.0, as expected while training

rng, _ = jax.random.split(rng)

# Apply deterministically: dropout should be a no-op here, yet the
# rate == 1.0 shortcut still zeroes the output.
outputs_deterministic = test_model.apply(params, jnp.ones((10, 10)), deterministic=True, rngs={'dropout': rng})

print(jnp.sum(outputs_deterministic))  # also 0.0 -- this is the bug
print(params)

0.0
0.0
FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[ 0.16847129, -0.1178944 ,  0.13583411,  0.16801184,
                          -0.24065216,  0.01868828,  0.4807749 , -0.30601937,
                           0.32568878, -0.08822519],
                         [ 0.36135912,  0.1638065 , -0.28575695, -0.031698  ,
                          -0.3977836 , -0.17761624,  0.1090662 ,  0.17428328,
                           0.12361519,  0.70224416],
                         [-0.4673121 , -0.19664605, -0.2457073 ,  0.03010363,
                           0.02082355, -0.3135972 ,  0.21568213,  0.23018655,
                          -0.00779009,  0.17412186],
                         [ 0.34069905,  0.00373967, -0.01591648,  0.15487881,
                           0.07604861, -0.25911337,  0.04083756, -0.08719958,
                           0.35543105, -0.11303028],
                         [ 0.35775012, -0.52226734, -0.45758778, -0.46874872,
                          -0.20843479,  0.378424  ,  0.27923864, -0.21878792,
                          -0.08861114, -0.21674518],
                         [ 0.5622318 , -0.18204133,  0.18515447, -0.03003849,
                          -0.24391893, -0.05861915, -0.71287733, -0.55168146,
                           0.28309464,  0.21501313],
                         [ 0.3712106 ,  0.5399647 , -0.21155106, -0.0811978 ,
                          -0.26425117, -0.09787431, -0.3940018 ,  0.27078357,
                          -0.18498448,  0.22094563],
                         [ 0.47136635, -0.27565262, -0.47951284,  0.11183885,
                           0.4013355 ,  0.11647745, -0.46050343, -0.20171496,
                           0.5458673 , -0.38525596],
                         [-0.696714  , -0.1420307 , -0.31419477, -0.37206405,
                          -0.07097312, -0.26630354,  0.40364647, -0.01666957,
                           0.19186951,  0.01891448],
                         [-0.15773533, -0.01117601,  0.0558053 ,  0.05353177,
                          -0.46538368,  0.3182615 , -0.38293365, -0.04739544,
                          -0.02096457,  0.31650487]], dtype=float32),
            bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
    },
})

But I believe a Dropout with rate=1.0 should return the output computed from the initialized weights, not zeros_like(inputs), when the model is run in deterministic mode (e.g. during validation). So I reordered the condition checks so that the deterministic branch takes precedence, as sketched below.
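For reference, here is a minimal standalone sketch of the corrected control flow. This is an illustrative stand-in, not the actual flax/linen/stochastic.py source; see the merged commit for the exact diff.

import jax
import jax.numpy as jnp

def dropout(inputs, rate, deterministic, rng):
    # Fixed ordering: deterministic mode (and the rate == 0.0 no-op)
    # short-circuits first, so evaluation always returns the inputs.
    if rate == 0.0 or deterministic:
        return inputs
    # Only once we know dropout is active does rate == 1.0 zero everything
    # (this also avoids dividing by keep_prob == 0 below).
    if rate == 1.0:
        return jnp.zeros_like(inputs)
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
    return jnp.where(mask, inputs / keep_prob, jnp.zeros_like(inputs))

x = jnp.ones((2, 3))
rng = jax.random.PRNGKey(0)
print(dropout(x, rate=1.0, deterministic=True, rng=rng))   # returns x, not zeros
print(dropout(x, rate=1.0, deterministic=False, rng=rng))  # all zeros

Previously, the rate == 1.0 check ran before the deterministic check, so the all-zero branch fired even in deterministic mode.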

Running the same test code after the change gives a different result:

0.0
-20.697302
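A quick way to exercise the fixed invariant (a sketch, not the test added in this PR): in deterministic mode the dropout rate must never affect the output.

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 8))
for rate in (0.0, 0.5, 1.0):
    # Dropout has no parameters, so an empty variables dict suffices.
    y = nn.Dropout(rate=rate).apply({}, x, deterministic=True)
    assert jnp.array_equal(y, x), f"deterministic output changed at rate={rate}"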

Below are the results of the existing test suite with the changed code.

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] IdsTest.test_hashable
[       OK ] IdsTest.test_hashable
[ RUN      ] NormalizationTest.test_batch_norm
I1006 18:52:34.324705 4309908800 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I1006 18:52:34.325006 4309908800 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I1006 18:52:34.325146 4309908800 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1006 18:52:34.325252 4309908800 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1006 18:52:34.325681 4309908800 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] NormalizationTest.test_batch_norm
[ RUN      ] NormalizationTest.test_batch_norm_complex
[       OK ] NormalizationTest.test_batch_norm_complex
[ RUN      ] NormalizationTest.test_batch_norm_multi_init
[       OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN      ] NormalizationTest.test_group_norm
[       OK ] NormalizationTest.test_group_norm
[ RUN      ] NormalizationTest.test_group_norm_raises
[       OK ] NormalizationTest.test_group_norm_raises
[ RUN      ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[       OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN      ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[       OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN      ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[       OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN      ] PoolTest.test_avg_pool0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
[ RUN      ] RecurrentTest.test_complex_input_gru
[       OK ] RecurrentTest.test_complex_input_gru
[ RUN      ] RecurrentTest.test_convlstm
[       OK ] RecurrentTest.test_convlstm
[ RUN      ] RecurrentTest.test_gru
[       OK ] RecurrentTest.test_gru
[ RUN      ] RecurrentTest.test_lstm
[       OK ] RecurrentTest.test_lstm
[ RUN      ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
  warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[       OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN      ] StochasticTest.test_dropout
[       OK ] StochasticTest.test_dropout
[ RUN      ] StochasticTest.test_dropout_rate_limits
[       OK ] StochasticTest.test_dropout_rate_limits
[ RUN      ] StochasticTest.test_dropout_rate_stats
[       OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 25 tests in 14.817s

OK

Thank you.

Fixes # (issue)

Checklist

  • This PR fixes a minor issue (e.g. a typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a GitHub issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@codecov-commenter commented Oct 6, 2022

Codecov Report

Merging #2510 (af12949) into main (61e47c2) will increase coverage by 0.42%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #2510      +/-   ##
==========================================
+ Coverage   79.04%   79.46%   +0.42%     
==========================================
  Files          49       49              
  Lines        5183     5202      +19     
==========================================
+ Hits         4097     4134      +37     
+ Misses       1086     1068      -18     
Impacted Files                 Coverage Δ
flax/linen/stochastic.py       96.29% <100.00%> (-0.26%) ⬇️
flax/linen/module.py           92.71% <0.00%>   (-0.03%) ⬇️
flax/errors.py                 87.36% <0.00%>   (+0.15%) ⬆️
flax/traverse_util.py          99.01% <0.00%>   (+0.49%) ⬆️
flax/struct.py                 78.46% <0.00%>   (+1.04%) ⬆️
flax/training/checkpoints.py   66.17% <0.00%>   (+5.24%) ⬆️


@cgarciae (Collaborator)

@dslisleedh looks good! pre-commit is failing, can you run pre-commit and push again?

pip install pre-commit
pre-commit install
pre-commit run --all-files

@dslisleedh (Contributor, Author)

@cgarciae Thank you for your help :) All checks pass now.

@copybara-service copybara-service bot merged commit 6b80cbb into google:main Oct 19, 2022