-
Notifications
You must be signed in to change notification settings - Fork 716
Add Stateful Modules guide #2689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report
@@ Coverage Diff @@
## main #2689 +/- ##
=======================================
Coverage 81.15% 81.15%
=======================================
Files 51 51
Lines 5492 5493 +1
=======================================
+ Hits 4457 4458 +1
Misses 1035 1035
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
docs/guides/stateful_modules.rst
Outdated
* Show how to use ``bind`` to manually run a ``Sequential`` module layer | ||
by layer. | ||
|
||
#. Edge cases of using ``bind``: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can add this to the "sharp bits" as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: @8bitmp3
docs/guides/stateful_modules.rst
Outdated
|
||
Contents | ||
********* | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would first explain that Flax Modules normally are stateless since Flax is using a functional API, and explain why this is the case. Then say that in some cases it can actually be useful to attach variables to the modules, and that the bind/unbind pattern allows this.
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Currently import flax.linen as nn
import jax
import jax.numpy as jnp
module = nn.Sequential([
nn.Sequential([
nn.Dense(8),
nn.relu,
nn.Dense(4),
]),
nn.relu,
nn.Dense(4),
])
x = jnp.ones((1, 5))
variables = module.init(jax.random.PRNGKey(0), x)
bound_module = module.bind(variables)
print(bound_module.layers[0].layers[0].scope) # None <== !!! @levskaya will try to come up with a solution to bind submodules, this guide will have to wait for that to be resolved. Previous attempt to solve this: #2028 |
73cf1e9
to
687cbf6
Compare
What does this PR do?
WIP
Live Preview: https://flax--2689.org.readthedocs.build/en/2689/guides/stateful_modules.html