Refactor KV cache manager #1315

Merged · 3 commits · May 6, 2025
31 changes: 31 additions & 0 deletions mistralrs-core/src/kv_cache/full_cache.rs
@@ -0,0 +1,31 @@
use std::sync::{Arc, Mutex, MutexGuard};

use candle_core::Tensor;

use super::{Cache, NormalCache};

pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;

#[derive(Debug, Clone)]
pub enum EitherCache {
Normal(Arc<Mutex<NormalCache>>),
Full(Cache),
}

impl EitherCache {
/// Returns the full cache. Panics if this is the `Normal` variant.
pub fn full(&self) -> &Cache {
match self {
Self::Full(full) => full,
Self::Normal(_) => panic!("Got normal cache, expected full cache."),
}
}

/// Locks and returns the normal cache. Panics if this is the `Full` variant.
pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
match self {
Self::Normal(normal) => normal.lock().unwrap(),
Self::Full(_) => panic!("Got full cache, expected normal cache."),
}
}
}
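
A note on the design: because the `Normal` variant wraps its cache in `Arc<Mutex<…>>`, cloning an `EitherCache` clones a handle to the same underlying `NormalCache`; that is why `normal()` returns a `MutexGuard` rather than a plain reference. A minimal standalone sketch of that sharing behavior, with a stand-in `NormalCache` (a `Vec` here, purely for illustration):

use std::sync::{Arc, Mutex};

// Stand-in for the real NormalCache, just to show the sharing semantics.
#[derive(Debug, Default)]
struct NormalCache(Vec<u32>);

fn main() {
    let a = Arc::new(Mutex::new(NormalCache::default()));
    // Cloning the Arc clones the handle, not the cache...
    let b = Arc::clone(&a);
    a.lock().unwrap().0.push(7);
    // ...so a write through one handle is visible through the other.
    assert_eq!(b.lock().unwrap().0.len(), 1);
}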
@@ -2,9 +2,19 @@ use std::sync::{Arc, Mutex, MutexGuard};

use candle_core::{Result, Tensor, D};

use crate::{get_mut_arcmutex, sequence::Sequence};
use crate::{
get_mut_arcmutex,
pipeline::{CacheManagerMixin, MetadataMixin},
sequence::Sequence,
};

use super::{CacheManagerMixin, MetadataMixin};
mod full_cache;
mod rotating_cache;
mod single_cache;

pub use full_cache::{EitherCache, LayerCaches};
pub use rotating_cache::RotatingCache;
pub use single_cache::SingleCache;
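
The `pub use` re-exports keep this a pure file split: downstream code continues to import `EitherCache`, `RotatingCache`, and `SingleCache` from the parent module and never sees the new submodule layout. A toy sketch of the same pattern (module names are illustrative, not the crate's real paths):

mod kv_cache {
    mod full_cache {
        // Stand-in for the real enum, moved into its own file.
        #[derive(Debug)]
        pub struct EitherCache;
    }
    // Re-export at the old location so call sites are unaffected by the split.
    pub use full_cache::EitherCache;
}

fn main() {
    // Call sites keep the pre-refactor path.
    let cache = kv_cache::EitherCache;
    println!("{cache:?}");
}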

pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
fn clone_in_cache(
@@ -23,278 +33,6 @@ pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
);
}

pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;

#[derive(Debug, Clone)]
pub enum EitherCache {
Normal(Arc<Mutex<NormalCache>>),
Full(Cache),
}

impl EitherCache {
/// Panics otherwise!
pub fn full(&self) -> &Cache {
match self {
Self::Full(full) => full,
Self::Normal(_) => panic!("Got normal cache, expected full cache."),
}
}
/// Panics otherwise!
pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
match self {
Self::Normal(normal) => normal.lock().unwrap(),
Self::Full(_) => panic!("Got full cache, expected normal cache."),
}
}
}

#[derive(Debug, Clone)]
pub struct SingleCache {
// `all_data` is an Option<Tensor> so that the backing buffer is only allocated on the
// first call to `append`, once the batch size is known.
// This also makes it safe to clone a KvCache that has been reset: the clone will not
// share internal state with the original instance.
pub all_data: Option<Tensor>,
pub dim: usize,
pub current_seq_len: usize,
pub capacity_seq_len: usize,
pub max_seq_len: usize,
}

impl SingleCache {
pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
max_seq_len,
capacity_seq_len,
}
}

pub fn dim(&self) -> usize {
self.dim
}

pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}

pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}

pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}

pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
};
Ok(data)
}

pub fn reset(&mut self) {
self.current_seq_len = 0;
self.all_data = None;
}

pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
self.current_seq_len = len;
Ok(())
}

pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.capacity_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad);
};

// Expand kv cache
if self.current_seq_len + seq_len > self.capacity_seq_len {
let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
if self.capacity_seq_len > self.max_seq_len {
candle_core::bail!(
"kv-cache: requested capacity ({}) above max seq len ({})",
self.capacity_seq_len,
self.max_seq_len
)
}
let mut shape = src.dims().to_vec();
shape[self.dim] = self.capacity_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
self.all_data = Some(ad);
}

let ad = self.all_data.as_mut().unwrap();

ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Ok(())
}
}
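
The growth path in `append` rounds the shortfall up to whole `CACHE_GROW_SIZE` blocks before reallocating. A standalone check of just that arithmetic, detached from tensors (the real constant is `NormalCache::CACHE_GROW_SIZE`, defined elsewhere in this module; 512 is an assumed value for illustration):

// Assumed block size for illustration only.
const CACHE_GROW_SIZE: usize = 512;

/// Mirrors the capacity computation in SingleCache::append.
fn grown_capacity(current_seq_len: usize, seq_len: usize, capacity_seq_len: usize) -> usize {
    if current_seq_len + seq_len <= capacity_seq_len {
        return capacity_seq_len; // still fits: no reallocation
    }
    let diff = current_seq_len + seq_len - capacity_seq_len;
    let n_blocks_needed = diff.div_ceil(CACHE_GROW_SIZE);
    capacity_seq_len + n_blocks_needed * CACHE_GROW_SIZE
}

fn main() {
    // 1000 cached + 30 incoming against capacity 1024: 6 tokens short,
    // so exactly one 512-token block is added.
    assert_eq!(grown_capacity(1000, 30, 1024), 1536);
    // An exact fit leaves the capacity untouched.
    assert_eq!(grown_capacity(1000, 24, 1024), 1024);
}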

#[derive(Debug, Clone)]
pub struct RotatingCache {
pub all_data: Option<Tensor>,
pub dim: usize,
// `offset` is the current write index in the buffer
pub offset: usize,
// The total size of the sequence seen so far.
pub current_seq_len: usize,
// max_seq_len is the size of the rotating buffer, it is actually allowed for the full
// sequence to grow past this limit.
pub max_seq_len: usize,
pub capacity_seq_len: usize,
}

impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
capacity_seq_len,
}
}

pub fn offset(&self) -> usize {
self.offset
}

pub fn dim(&self) -> usize {
self.dim
}

pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}

pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}

pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}

pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => {
if self.current_seq_len >= self.max_seq_len {
Some(d.clone())
} else {
Some(d.narrow(self.dim, 0, self.current_seq_len)?)
}
}
};
Ok(data)
}

pub fn reset(&mut self) {
self.offset = 0;
self.current_seq_len = 0;
self.all_data = None;
}

pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
// If trying to roll it back past the boundary of max_seq_len, fail early.
if self.current_seq_len - len > self.max_seq_len {
candle_core::bail!(
"Rotating KV cache (usually for sliding window) tried to reset to len {len} while current is {} and max retained is {}",
self.current_seq_len,
self.max_seq_len
);
}
self.current_seq_len = len;
self.offset = len % self.max_seq_len;
Ok(())
}

pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.capacity_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};

// Expand kv cache, this case is a little more complex.
if (self.current_seq_len + seq_len > self.capacity_seq_len
&& self.current_seq_len + seq_len < self.max_seq_len)
|| self.current_seq_len == 0
{
let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
self.capacity_seq_len = self.capacity_seq_len.min(self.max_seq_len);
if self.capacity_seq_len > self.max_seq_len {
candle_core::bail!(
"kv-cache: requested capacity ({}) above max seq len ({})",
self.capacity_seq_len,
self.max_seq_len
)
}
let mut shape = src.dims().to_vec();
shape[self.dim] = self.capacity_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
self.all_data = Some(ad);
}

let ad = self.all_data.as_mut().unwrap();

self.current_seq_len += seq_len;
if seq_len >= self.max_seq_len {
let to_copy = src
.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
.contiguous()?;
ad.slice_set(&to_copy, self.dim, 0)?;
self.offset = 0;
// Here we return `src` rather than `ad` so that all the past can be used.
Ok(src.clone())
} else {
let rem_len = self.max_seq_len - self.offset;
if seq_len <= rem_len {
ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
self.offset = (self.offset + seq_len) % self.max_seq_len;
} else {
// We have to make two copies here as we go over the boundary of the cache.
if rem_len > 0 {
let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
ad.slice_set(&src1, self.dim, self.offset)?;
}
let src2 = src
.narrow(self.dim, rem_len, seq_len - rem_len)?
.contiguous()?;
ad.slice_set(&src2, self.dim, 0)?;
self.offset = seq_len - rem_len;
}
if self.current_seq_len >= self.max_seq_len {
Ok(ad.clone())
} else {
Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
}
}
}
}
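
Stripped of the tensor copies, the wrap-around in `RotatingCache::append` is ordinary ring-buffer index arithmetic. A standalone sketch of just the offset bookkeeping (buffer length stands in for `max_seq_len`):

/// Mirrors the offset update in RotatingCache::append for a single write of
/// `seq_len` items into a ring buffer of length `max_seq_len`.
fn advance_offset(offset: usize, seq_len: usize, max_seq_len: usize) -> usize {
    if seq_len >= max_seq_len {
        // The write overwrites the entire buffer; only its tail survives.
        0
    } else {
        let rem_len = max_seq_len - offset;
        if seq_len <= rem_len {
            // One contiguous copy, possibly landing exactly on the end.
            (offset + seq_len) % max_seq_len
        } else {
            // Two copies: fill to the end, then wrap to the front.
            seq_len - rem_len
        }
    }
}

fn main() {
    let max = 8;
    assert_eq!(advance_offset(0, 5, max), 5); // plain append
    assert_eq!(advance_offset(5, 3, max), 0); // lands exactly on the boundary
    assert_eq!(advance_offset(5, 5, max), 2); // wraps: 3 at the end, 2 at the front
    assert_eq!(advance_offset(3, 9, max), 0); // whole-buffer overwrite
}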

#[derive(Debug, Clone)]
pub enum KvCache {
Normal { k: SingleCache, v: SingleCache },