Large Model output (bsz x seq_len x vocab_size) repeated 3 times in memory with jax.value_and_grad(). #4097
Unanswered · saschafrey asked this question in Q&A · 0 replies
I'm training a model with a custom vocabulary of a few thousand tokens, which naturally brings certain memory constraints. What I am struggling to understand is why the model's output (see title) is represented three times in GPU memory during my training step. Is there a way to work around this? It significantly limits the model/batch size I can fit on a given GPU. TensorBoard profiler traces and memory views are attached; the XLA ops that produce the three large tensors are highlighted in red. Happy to provide more details.
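For context, here is a minimal sketch of the kind of training step I mean (the shapes, names, and toy model below are illustrative placeholders, not my actual code): the full (bsz, seq_len, vocab_size) logits tensor is materialised inside the loss, and jax.value_and_grad() keeps the intermediates needed for the backward pass alive.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; the real model and vocabulary differ.
BSZ, SEQ, HIDDEN, VOCAB = 8, 1024, 512, 4096

def loss_fn(params, hidden_states, labels):
    # The large tensor: logits of shape (bsz, seq_len, vocab_size).
    logits = hidden_states @ params["decoder_w"]
    logp = jax.nn.log_softmax(logits, axis=-1)  # same shape again
    nll = -jnp.take_along_axis(logp, labels[..., None], axis=-1)
    return nll.mean()

params = {"decoder_w": jnp.zeros((HIDDEN, VOCAB))}
hidden_states = jnp.zeros((BSZ, SEQ, HIDDEN))
labels = jnp.zeros((BSZ, SEQ), dtype=jnp.int32)

# value_and_grad saves the intermediates needed for the backward pass,
# which is presumably where the repeated (bsz, seq_len, vocab_size)
# buffers in the memory profile come from.
loss, grads = jax.jit(jax.value_and_grad(loss_fn))(params, hidden_states, labels)
```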
I have already tried applying the jax.checkpoint() decorator to modules other than the final output layer (the "decoder"), with no major effect.
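Roughly like this (a sketch; `encoder_block` is a placeholder for one of my intermediate modules, not the vocab projection):

```python
import jax

@jax.checkpoint  # a.k.a. jax.remat: recompute activations in the backward pass instead of storing them
def encoder_block(params, x):
    # ...intermediate layers only; the final "decoder" projection to the vocabulary is left unwrapped...
    return x
```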