kv-cache : separate recurrent vs non-recurrent impl #12799
Conversation
What is the reasoning for using […]? On a more general note, I think it is not very usual the way […]
The public API currently works with […]. I think what we need to do in a follow-up PR is: […]
At this point, a completely new recurrent-specific implementation can be added. The existing recurrent cache implementation has to be rewritten from scratch, because it was hacked on top of the KV cache implementation by repurposing the K and V tensors for the state space requirements.
I'll try to update this. Just to make sure, you mean the current […] to become interfaces with different implementations based on the type of memory?
I don't fully understand the code, but I think […]
There will need to be some top-level type which can contain multiple types of KV caches to ease supporting hybrid models. A shared interface for recurrent and non-recurrent state caches is useful to get to that point, at least for maintainability. The hardest part will be handling errors and properly keeping coherency between the different types of caches (because they don't necessarily roll-back states in the same way). That is relevant mostly for hybrid models, though.
Yes, it will need to be rewritten at least to be able to support proper state rollback. But even if it was repurposing the K and V tensors, there are still some things which I think will remain, since Mamba and RWKV do have 2 types of recurrent states per layer.
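For illustration, a recurrent-specific cache could track those two per-layer states directly instead of repurposing the K/V slots. The types below are hypothetical, not the actual llama.cpp declarations; this is just a minimal sketch of the idea:

```cpp
// Hypothetical sketch (not the actual llama.cpp types): what a recurrent cache
// cell could store per layer, instead of reusing the K and V tensors.
#include <cstdint>
#include <vector>

struct ggml_tensor; // from ggml.h

// Mamba/RWKV-style models keep two kinds of recurrent state per layer:
// a short convolution/shift state and the main SSM/WKV state.
struct recurrent_layer_state {
    ggml_tensor * conv_state; // e.g. Mamba conv state / RWKV token-shift state
    ggml_tensor * rec_state;  // e.g. Mamba SSM state  / RWKV WKV state
};

// One cell per tracked sequence, holding the state for every layer plus the
// bookkeeping that proper state rollback would require.
struct recurrent_cell {
    int32_t seq_id = -1;
    int32_t pos    = -1; // position of the last token folded into the state
    std::vector<recurrent_layer_state> layers;
};
```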
The changes look good. While testing this, I noticed that the KV cache is always allocated on the CPU.
src/llama-kv-cache.cpp (Outdated)

```cpp
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[i]);
```
Suggested change:

```diff
-kv_cell & cell = const_cast<kv_cell &>(cells[i]);
+kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
```

Otherwise multi-user inference is broken for recurrent models. See #9126 (comment).
src/llama-kv-cache.cpp (Outdated)

```cpp
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[i]);
```
Suggested change:

```diff
-kv_cell & cell = const_cast<kv_cell &>(cells[i]);
+kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
```

Same, this should fix multi-user inference.
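To illustrate why indexing by the batch position breaks multi-user inference, here is a simplified, hypothetical sketch (not the actual find_slot() code; all names are stand-ins): with more than one active sequence, the batch row index and the cache cell that owns a sequence's state are generally different indices.

```cpp
#include <cstdint>
#include <vector>

using llama_seq_id = int32_t;
using llama_pos    = int32_t;

struct kv_cell { llama_pos pos = -1; };

// Simplified, hypothetical sketch of the indexing issue - not the real code.
void assign_positions(
        std::vector<kv_cell>            & cells,            // recurrent state cells
        const std::vector<uint32_t>     & seq_to_tail_cell, // which cell owns each sequence's state
        const std::vector<llama_seq_id> & ub_seq_id,        // sequence id of each ubatch row
        const std::vector<llama_pos>    & ub_pos) {         // position of each ubatch row
    for (size_t i = 0; i < ub_seq_id.size(); ++i) {
        const uint32_t cell_id = seq_to_tail_cell[ub_seq_id[i]];

        // cells[i] would be whatever cell happens to sit at index i; with more than
        // one active sequence it can belong to a different user, so mutating it
        // clobbers that user's recurrent state. The cell that actually stores this
        // sequence's state is cells[cell_id].
        kv_cell & cell = cells[cell_id];
        cell.pos = ub_pos[i];
    }
}
```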
We should add a small multi-user test with a recurrent model to server/tests to be able to spot such regressions.
@slaren @compilade I think this should be good to merge - any additional comments?
I think that when we introduce the […]
```cpp
// make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm
```
It's also only relevant when using get_embeddings, because the buffer in that case has to be ordered to keep the API backward compatible. When purely using get_embeddings_ith, it's not required.

Unconditionally sorting is unnecessary and is likely slower. Also, it seems like some assertions here break multi-user inference for recurrent models (the line right after this block, where n_outputs = n_outputs_all, is assumed to have run before the sorting routine, but it hasn't).
The main reason to decide to always reorder is that otherwise we have to maintain the sbatch in the state of the context. This introduces some complexity that is hard to reason about, so I decided to take the hit.

We should add a test that exercises this branch. What is a server scenario that would trigger the reordering?

I'll PR the n_outputs = n_outputs_all before the sorting fix.
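As an illustration of what "always reorder" means here (a simplified sketch with hypothetical names, not the actual decode() code): each output row remembers which batch token produced it, and the rows are permuted back into user-batch order before being exposed.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// out_ids[j] = index in the user-provided batch of the token that produced output row j
// rows[j]    = the j-th output row (logits or embeddings) in the order produced by the ubatch split
void reorder_outputs(std::vector<int32_t> & out_ids, std::vector<std::vector<float>> & rows) {
    std::vector<size_t> perm(out_ids.size());
    for (size_t j = 0; j < perm.size(); ++j) perm[j] = j;

    // sort output rows by their original position in the user-provided batch
    std::sort(perm.begin(), perm.end(),
              [&](size_t a, size_t b) { return out_ids[a] < out_ids[b]; });

    std::vector<int32_t> out_ids_sorted(out_ids.size());
    std::vector<std::vector<float>> rows_sorted(rows.size());
    for (size_t j = 0; j < perm.size(); ++j) {
        out_ids_sorted[j] = out_ids[perm[j]];
        rows_sorted[j]    = std::move(rows[perm[j]]);
    }
    out_ids = std::move(out_ids_sorted);
    rows    = std::move(rows_sorted);
}
```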
Overview

Attempting to make two separate classes for the 2 types of KV cache:

- llama_kv_cache_unified : llama_kv_cache
- llama_kv_cache_recurrent : llama_kv_cache

The main goal of this change is to simplify the logic in the primary llama_kv_cache_unified class so that we can more easily extend it with new features such as SWA, and to introduce a certain level of abstraction that would allow adding new types of KV cache implementations in the future.

Main changes

- llama_context now operates with the abstract llama_memory_i interface (see the sketch after this list).
- Add llama_memory_params and use it to implement llama_model::create_memory() for creating the model-specific cache.
- llama_kv_cache_recurrent is currently mostly a copy of llama_kv_cache_unified, but it is now completely separated, so a new recurrent-specific implementation can be done.
- Move the KV cache shift and defrag code from llama_context to llama_kv_cache_unified.
- The llama_sbatch -> llama_ubatch logic inside llama_context::decode() is now implemented by llama_kv_cache::sbatch_init() and llama_kv_cache::ubatch_next(). The thinking is that certain KV cache implementations could require different types of micro-batching (e.g. same-sequence-length ubatch, single-sequence ubatch, etc.).
- Remove llama_context::output_reorder() - it seemed to be relevant only for recurrent caches; the logic is now inlined in llama_context::decode().
- Remove llama_context::sbatch. Instead, create a new one for each decode.
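A rough sketch of the resulting shape of the abstraction (the declarations are simplified stand-ins for illustration, not copied from the actual headers):

```cpp
// Illustrative sketch only - simplified stand-ins, not the actual llama.cpp declarations.
#include <cstdint>

struct llama_batch  { /* user-provided tokens, positions, seq ids, ... */ };
struct llama_sbatch { /* batch split into per-sequence views           */ };
struct llama_ubatch { /* micro-batch that actually gets decoded        */ };

// generic memory interface that llama_context operates on
struct llama_memory_i {
    virtual ~llama_memory_i() = default;
    virtual void clear() = 0;
    // sequence-level ops (seq_rm, seq_cp, seq_keep, ...) would also live here
};

// common KV-cache interface on top of the generic memory interface
struct llama_kv_cache : public llama_memory_i {
    // each implementation decides how to split the incoming batch into micro-batches
    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
    virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
};

// standard attention KV cache: owns the K/V tensors, handles shift/defrag, future SWA
struct llama_kv_cache_unified : public llama_kv_cache { /* ... */ };

// state cache for recurrent models (Mamba, RWKV): per-sequence recurrent states
struct llama_kv_cache_recurrent : public llama_kv_cache { /* ... */ };
```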
TODO before merge

- llama_kv_cache interface
- Make llama_kv_cache_xxx more private

Next PRs

- llama_context_params.logits_all logic - unnecessary complication, can be achieved with an explicit request for logits for all tokens
- infill example - obsolete
- llama_kv_cache_unified
- Resolve […]