Recommended way to use jax.jit for generation from transformers #6242
Unanswered · davisyoshida asked this question in Q&A · 3 comments · 18 replies
-
For now I've decided to convert my JAX weights back to TF after training, as that seems to be the easiest way to get quick generation.
0 replies
-
Perhaps this https://github.com/google/flax/blob/master/examples/lm1b/temperature_sampler.py#L27?
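The trick in that file, roughly, is to preallocate the output buffer at a fixed max length so every iteration of the sampling loop has identical shapes. A minimal sketch of that pattern (`apply_model` here is a hypothetical single-step model call; the real sampler also threads a cache and handles EOS):

```python
import jax
import jax.numpy as jnp
from jax import lax

def sample(apply_model, prompt_token, max_len, rng):
    # Preallocate the full output buffer up front: its shape is (max_len,)
    # on every iteration, so nothing in the loop carry ever changes shape.
    tokens = jnp.zeros(max_len, dtype=jnp.int32).at[0].set(prompt_token)

    def cond(state):
        i, _, _ = state
        return i < max_len - 1

    def body(state):
        i, tokens, rng = state
        rng, sample_rng = jax.random.split(rng)
        # apply_model is assumed to return next-token logits given the
        # buffer and the current position.
        logits = apply_model(tokens, i)
        next_token = jax.random.categorical(sample_rng, logits)
        tokens = tokens.at[i + 1].set(next_token.astype(jnp.int32))
        return i + 1, tokens, rng

    _, tokens, _ = lax.while_loop(cond, body, (0, tokens, rng))
    return tokens
```

Because nothing in the carry ever changes shape, `lax.while_loop` traces the body once and the whole loop compiles a single time.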
13 replies
-
@davisyoshida Based on our discussion, I have some more ideas:
5 replies
-
I've implemented GPT-2 in JAX, but unfortunately generation is currently prohibitively slow: `jax.jit` recompiles at every token during the first run, since a growing set of cached hidden states is passed in each time. Subsequent generations are faster, of course, but on my machine this warmup amounts to almost 2 hours when running one of these models (~6.6 sec/token * 1024 tokens). I was also unable to avoid this using any of the loop constructs or `jax.lax.scan`, since the carry changes shape.
The best solution I've thought of is to select a max length a priori, pad all the cached hidden states up to that length outside of the JIT-ed part of the code, then mask the junk computations that result in the network; a rough sketch of this follows. I'd prefer to avoid doing so, as it both wastes computation and is much less clean.
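For a single attention read, that padding-and-masking workaround looks roughly like this (shapes and names are illustrative, not the actual model code):

```python
import jax
import jax.numpy as jnp

MAX_LEN = 1024  # chosen a priori

@jax.jit
def attend(query, cache, cache_len):
    # query: (d,); cache: (MAX_LEN, d), where only the first cache_len
    # rows hold real hidden states and the rest are padding.
    scores = cache @ query                      # (MAX_LEN,)
    valid = jnp.arange(MAX_LEN) < cache_len     # True for real positions
    scores = jnp.where(valid, scores, -jnp.inf) # mask out the padding
    weights = jax.nn.softmax(scores)            # padding gets zero weight
    return weights @ cache                      # (d,)
```

The input shapes are fixed at (MAX_LEN, d) no matter how many tokens have actually been generated, so this compiles exactly once, at the cost of always computing over all MAX_LEN positions and discarding the junk via the mask.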
In order to make clear exactly what I'm talking about, here's a small example which has a similar issue:
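A minimal sketch of that failure mode (the `step` function here is a hypothetical stand-in for one decoding step, whose cache grows by one element per call):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(cache):
    # One fake "decoding step": append a new entry to the cache, so the
    # output is one element longer than the input.
    new_entry = jnp.sum(cache, keepdims=True)   # shape (1,)
    return jnp.concatenate([cache, new_entry])

cache = jnp.zeros(1)
for i in range(10):
    start = time.time()
    cache = step(cache)
    cache.block_until_ready()
    # The input shape differs on every call, so jax.jit traces and
    # compiles a fresh program each time; timings stay high throughout.
    print(f"step {i}: {time.time() - start:.3f}s")
```

Each call sees a new input shape, so `jax.jit` retraces and recompiles on every iteration, which is exactly what happens with the growing hidden-state cache.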