
kv-cache : separate recurrent vs non-recurrent impl #12799


Merged
29 commits merged into master from gg/llama-kv-cache-v6 on May 2, 2025

Conversation


ggerganov (Member) commented on Apr 7, 2025

Overview

Attempting to make two separate classes for the 2 types of KV cache:

  • llama_kv_cache_unified : llama_kv_cache
  • llama_kv_cache_recurrent : llama_kv_cache
```mermaid
graph TD;
llama_memory_i --> llama_kv_cache
llama_kv_cache --> llama_kv_cache_unified
llama_kv_cache --> llama_kv_cache_recurrent
```
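
A rough C++ sketch of the resulting hierarchy (the method names here are illustrative assumptions, not the exact interfaces in the PR):

```cpp
#include <cstdint>

// basic types, as in llama.h
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// abstract memory interface that llama_context operates with
struct llama_memory_i {
    virtual ~llama_memory_i() = default;

    virtual void clear() = 0;
    virtual bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
    // ... other sequence-level operations (seq_cp, seq_keep, seq_add, ...)
};

// common KV-cache interface on top of the memory interface
struct llama_kv_cache : public llama_memory_i {
    // e.g. batch preparation, state commit/restore, ...
};

// standard attention KV cache (per-layer K/V tensors)
class llama_kv_cache_unified : public llama_kv_cache { /* ... */ };

// cache for recurrent states (Mamba, RWKV, ...), currently still close to the unified code
class llama_kv_cache_recurrent : public llama_kv_cache { /* ... */ };
```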

The main goal of this change is to simplify the logic in the primary llama_kv_cache_unified class so that we can more easily extend it with new features such as SWA. It also introduces a level of abstraction that would allow adding new types of KV cache implementations in the future.

Main changes

  • The llama_context now operates with the abstract llama_memory_i interface.

  • Add llama_memory_params and use it to implement llama_model::create_memory() for creating the model-specific cache

  • llama_kv_cache_recurrent is currently mostly a copy of llama_kv_cache_unified, but it is now completely separate, so a new recurrent-specific implementation can be developed

  • Move KV cache shift and defrag code from llama_context to llama_kv_cache_unified

  • The llama_sbatch -> llama_ubatch logic inside llama_context::decode() is now implemented by (see the sketch after this list):

    • llama_kv_cache::sbatch_init()
    • llama_kv_cache::ubatch_next()

    The thinking is that certain KV cache implementations could require different types of micro-batching (e.g. same-sequence-length ubatch, single-sequence ubatch, etc.)

  • Remove llama_context::output_reorder() - it seemed to be relevant only for recurrent caches. The logic is now inlined in llama_context::decode()

  • Remove llama_context::sbatch. Instead, create a new one for each decode
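
As referenced above, a hedged sketch of the new sbatch/ubatch flow during decode. The function names sbatch_init() and ubatch_next() come from this PR description, but the surrounding code, parameter names, and signatures are assumptions; the actual llama_context::decode() implementation differs.

```cpp
// sketch only: llama_batch, llama_sbatch, llama_ubatch and llama_kv_cache are
// llama.cpp internal types assumed to be available here; the extra parameters
// (logits_all, embd_pooled) are guesses, not the exact signatures
static void decode_sketch(llama_kv_cache & kv, const llama_batch & batch, uint32_t n_ubatch) {
    // a fresh sbatch is created for every decode call instead of being stored in llama_context
    llama_sbatch sbatch = kv.sbatch_init(batch, /*logits_all=*/false);

    while (sbatch.n_tokens > 0) {
        // the cache decides how to slice the sbatch into ubatches, so that e.g. a
        // recurrent cache can require single-sequence or equal-length ubatches
        llama_ubatch ubatch = kv.ubatch_next(sbatch, n_ubatch, /*embd_pooled=*/false);

        // find a slot in the cache, build and evaluate the compute graph for this ubatch
        // ...
        (void) ubatch;
    }
}
```
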

TODO before merge

  • Clean-up llama_kv_cache interface
  • Make llama_kv_cache_xxx more private
  • Add comments

Next PRs

  • Support cache-less context (needed for embedding-only models such as BERT) (context : allow cache-less context for embeddings #13108)
  • Remove llama_context_params.logits_all logic - an unnecessary complication; the same result can be achieved by explicitly requesting logits for all tokens
  • Remove infill example - obsolete
  • Add proper SWA support to llama_kv_cache_unified



slaren (Member) commented on Apr 29, 2025

What is the reasoning for using llama_kv_cache as the base class for llama_kv_cache_recurrent? Is there enough code shared between these types to justify this? It seems that there is a lot of complexity in llama_kv_cache_recurrent, and it would be good if that could be simplified a bit.

On a more general note, I think the way std::function callbacks are mixed with inheritance is not very usual. I think the more typical way to do this would be to create virtual functions that can be overridden in a child class. I wonder if I am missing something here that would prevent implementing it this way.
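
For illustration, a minimal sketch of the two styles being contrasted here (hypothetical names; not the actual llama.cpp code, although get_rope_factors is one of the callbacks discussed below):

```cpp
#include <cstdint>
#include <functional>

struct ggml_tensor; // opaque for the purpose of this sketch

// current style (roughly): behaviour injected through std::function callbacks
struct kv_cache_callbacks_sketch {
    std::function<ggml_tensor * (uint32_t il)> get_rope_factors;
};

// suggested style: behaviour provided by overriding virtual functions
struct kv_cache_base_sketch {
    virtual ~kv_cache_base_sketch() = default;
    virtual ggml_tensor * get_rope_factors(uint32_t il) = 0;
};

struct kv_cache_custom_sketch : kv_cache_base_sketch {
    ggml_tensor * get_rope_factors(uint32_t il) override {
        // model-specific behaviour lives in the subclass instead of a captured lambda
        (void) il;
        return nullptr;
    }
};
```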

ggerganov (Member, Author) replied:

What is the reasoning for using llama_kv_cache as the base class for llama_kv_cache_recurrent? Is there enough code shared between these types to justify this? It seems that there is a lot of complexity in llama_kv_cache_recurrent, and it would be good if that could be simplified a bit.

The public API currently works with struct llama_kv_cache *, so both the recurrent and non-recurrent implementation have to implement it for now.

I think what we need to do in a follow-up PR is:

  • Deprecate the public API llama_kv_cache_
  • Add llama_memory_ API that works with struct llama_memory

At this point, a completely new recurrent-specific implementation can be added: class llama_memory_recurrent : public llama_memory_i that would replace the current llama_kv_cache_recurrent.

The existing recurrent cache implementation has to be rewritten from scratch, because it was hacked on top of the KV cache implementation by repurposing the K and V tensors for the state-space requirements.
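
Purely for illustration, a sketch of what a llama_memory_-prefixed public API could look like; these names and signatures are assumptions about a possible future API, not something that exists in this PR:

```cpp
#include <cstdint>

// basic types, as in llama.h
typedef int32_t llama_pos;
typedef int32_t llama_seq_id;

// hypothetical opaque memory handle obtained from the context
struct llama_memory;

// hypothetical counterparts of today's llama_kv_cache_seq_* functions
void llama_memory_clear (struct llama_memory * mem);
bool llama_memory_seq_rm(struct llama_memory * mem, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
void llama_memory_seq_cp(struct llama_memory * mem, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
```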


On a more general note, I think the way std::function callbacks are mixed with inheritance is not very usual. I think the more typical way to do this would be to create virtual functions that can be overridden in a child class. I wonder if I am missing something here that would prevent implementing it this way.

I'll try to update this. Just to make sure, you mean the current:

  • struct llama_kv_cache::callbacks
  • struct llama_kv_cache::graph_params

to become interfaces with different implementations based on the type of memory?


slaren (Member) commented on Apr 29, 2025

I don't fully understand the code, but I think get_rope_factors and get_buft could be virtual/abstract functions of llama_kv_cache, and a new class could be created if these need different implementations. But I am not sure that is the case; maybe they can just be regular functions, and llama_kv_cache could simply hold a reference to the llama_model. graph_params looks like it should be an interface, but it feels out of place.


compilade (Collaborator) commented on Apr 29, 2025

The public API currently works with struct llama_kv_cache *, so both the recurrent and non-recurrent implementation have to implement it for now.

There will need to be some top-level type which can contain multiple types of KV caches to ease supporting hybrid models. A shared interface for recurrent and non-recurrent state caches is useful to get to that point, at least for maintainability.

The hardest part will be handling errors and properly keeping coherence between the different types of caches (because they don't necessarily roll back states in the same way). That is relevant mostly for hybrid models, though.
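
As a purely hypothetical illustration of such a top-level type (names are made up; real hybrid support would need much more, especially around rollback):

```cpp
// hypothetical container exposing a single memory interface over two caches,
// reusing the llama_memory_i / llama_kv_cache_* names sketched earlier in the PR description
class llama_memory_hybrid_sketch : public llama_memory_i {
public:
    llama_memory_hybrid_sketch(llama_kv_cache_unified & attn, llama_kv_cache_recurrent & rec)
        : kv_attn(attn), kv_recurrent(rec) {}

    void clear() override {
        kv_attn     .clear();
        kv_recurrent.clear();
    }

    bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) override {
        // the hard part: both caches must stay coherent, and they do not
        // necessarily roll back their states in the same way on failure
        const bool ok_attn = kv_attn     .seq_rm(seq_id, p0, p1);
        const bool ok_rec  = kv_recurrent.seq_rm(seq_id, p0, p1);
        return ok_attn && ok_rec;
    }

private:
    llama_kv_cache_unified   & kv_attn;      // for the attention layers
    llama_kv_cache_recurrent & kv_recurrent; // for the recurrent-state layers
};
```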

The existing recurrent cache implementation has to be rewritten from scratch, because is was hacked on top of the KV cache implementation by repurposing the K and V tensors for the state space requirements.

Yes, it will need to be rewritten, at least to be able to support proper state rollback.

But even if it was repurposing the K and V tensors, there are still some things which I think will remain, since Mamba and RWKV do have two types of recurrent state per layer.

ggerganov force-pushed the gg/llama-kv-cache-v6 branch 2 times, most recently from e37f112 to 7e4b545 on April 30, 2025 at 07:22
ggerganov (Member, Author) commented:

@slaren In 7e4b545 I replaced the struct callbacks by keeping a reference to the llama_model in the llama_kv_cache implementation. And in 73df685 I replaced the struct graph_params by passing a reference to the llama_context.

PTAL if you think these changes are good.
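
Roughly, the shape of the change (a sketch with assumed member and function names; see the commits for the actual code):

```cpp
// sketch: llama_model, llama_context, ggml_context, ggml_cgraph and ggml_tensor
// are existing llama.cpp/ggml types, assumed to be declared here
class llama_kv_cache_unified_sketch {
public:
    explicit llama_kv_cache_unified_sketch(const llama_model & model) : model(model) {}

    // previously routed through struct llama_kv_cache::callbacks
    ggml_tensor * get_rope_factors(uint32_t il) const;

    // previously routed through struct llama_kv_cache::graph_params;
    // the context is now passed in explicitly when building the shift/defrag graphs
    void build_shift(llama_context & lctx, ggml_context * ctx, ggml_cgraph * gf);

private:
    const llama_model & model; // replaces the callback indirection
};
```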

ggerganov force-pushed the gg/llama-kv-cache-v6 branch from 73df685 to eb623f2 on April 30, 2025 at 08:30
slaren (Member) left a review:

The changes look good. While testing this, I noticed that the KV cache is always allocated on the CPU.


Review comment (Collaborator) on:

```cpp
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[i]);
```

Suggested change:

```diff
-kv_cell & cell = const_cast<kv_cell &>(cells[i]);
+kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
```

Otherwise multi-user inference is broken for recurrent models. See #9126 (comment).


Review comment (Collaborator) on:

```cpp
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[i]);
```

Suggested change:

```diff
-kv_cell & cell = const_cast<kv_cell &>(cells[i]);
+kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
```

Same, this should fix multi-user inference.

ggerganov (Member, Author) replied:

We should add a small multi-user test with a recurrent model to server/tests to be able to spot such regressions.

ggerganov force-pushed the gg/llama-kv-cache-v6 branch from 780d6fb to 58115a2 on May 2, 2025 at 10:28
ggerganov force-pushed the gg/llama-kv-cache-v6 branch from 58115a2 to 7e79a42 on May 2, 2025 at 13:02
ggerganov (Member, Author) commented:

@slaren @compilade I think this should be good to merge - any additional comments?

ggerganov (Member, Author) commented:

There will need to be some top-level type which can contain multiple types of KV caches to ease supporting hybrid models. A shared interface for recurrent and non-recurrent state caches is useful to get to that point, at least for maintainability.

The hardest part will be handling errors and properly keeping coherence between the different types of caches (because they don't necessarily roll back states in the same way). That is relevant mostly for hybrid models, though.

I think that when we introduce the llama_memory_ API (see #12799 (comment)) we can redesign how the caches are used. The existing llama_kv_cache_seq_ API is not great in general (error-prone and a bit hacky to use), so it would be a good opportunity to think about ways to simplify and improve it.

ggerganov merged commit c642bc0 into master on May 2, 2025
1 check passed
ggerganov deleted the gg/llama-kv-cache-v6 branch on May 2, 2025 at 14:48
Comment on lines +1069 to +1070:

```cpp
// make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm
```

Review comment (Collaborator):

It's also only relevant when using get_embeddings, because the buffer in that case has to be ordered to keep the API backward compatible. When purely using get_embeddings_ith, it's not required.

Unconditionally sorting is unnecessary and likely slower. Also, it seems that some assertions here break multi-user inference for recurrent models (the line right after this block, where n_outputs = n_outputs_all, is assumed to have run before the sorting routine, but it has not).

ggerganov (Member, Author) replied:

The main reason for deciding to always reorder is that otherwise we have to maintain the sbatch in the state of the context. This introduces complexity that is hard to reason about, so I decided to take the hit.

We should add a test that exercises this branch. What is a server scenario that would trigger the reordering?

I'll PR the fix to set n_outputs = n_outputs_all before the sorting.
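
For context, a hedged sketch of the reordering being discussed (data layout and names are assumptions; the real code in llama_context also handles embeddings and works differently):

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// out_ids[i] = index in the user-provided batch of the output currently stored
// in row i of the logits buffer; after this call, row j holds the j-th output
// in user-batch order
static void reorder_outputs_sketch(const std::vector<int32_t> & out_ids,
                                   std::vector<float>         & logits,
                                   size_t                       n_vocab) {
    std::vector<float> sorted(logits.size());
    for (size_t i = 0; i < out_ids.size(); ++i) {
        std::copy(logits.begin() + i*n_vocab,
                  logits.begin() + (i + 1)*n_vocab,
                  sorted.begin() + out_ids[i]*n_vocab);
    }
    logits = std::move(sorted);
}
```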
