RNG count update during Manual Training #4665
Unanswered
DiagRisker asked this question in Show and tell
Hello,
If you want to know how to update the RNG state count during manual training, here is an example (based on the MLP from the documentation).
We can see that no count change occurred during the gradient pass. This is a problem for mini-batch training: if dropout does not advance its RNG count on each batch, every batch derives the same key from the same count and therefore gets the same dropout mask. How can we force the count to update? jax.tree.map is an option, but it does not work directly with nnx filters.
Based on https://flax.readthedocs.io/en/latest/guides/randomness.html, you can write a shortcut:
This is a simple and useful way to control how the RNG state advances during training.
If the RNG state uses a different update rule (something other than count += 1), this will not solve the problem.
I leave that case for contributors to complete.