Commit ca5794d

Refactor KV cache manager (#1315)
* Refactor kv cache
* Refactor caches
* Fix some overflows
1 parent b17455c commit ca5794d

File tree

11 files changed: +318 -286 lines

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use candle_core::Tensor;
+
+use super::{Cache, NormalCache};
+
+pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;
+
+#[derive(Debug, Clone)]
+pub enum EitherCache {
+    Normal(Arc<Mutex<NormalCache>>),
+    Full(Cache),
+}
+
+impl EitherCache {
+    /// Panics otherwise!
+    pub fn full(&self) -> &Cache {
+        match self {
+            Self::Full(full) => full,
+            Self::Normal(_) => panic!("Got normal cache, expected full cache."),
+        }
+    }
+
+    /// Panics otherwise!
+    pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
+        match self {
+            Self::Normal(normal) => normal.lock().unwrap(),
+            Self::Full(_) => panic!("Got full cache, expected normal cache."),
+        }
+    }
+}
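
The new file above only defines the enum and its two panicking accessors. As a usage illustration, not part of the commit, here is a minimal self-contained sketch of the same accessor pattern; Cache and NormalCache are reduced to hypothetical stubs so the snippet compiles on its own, and only normal() is shown:

use std::sync::{Arc, Mutex, MutexGuard};

// Stub stand-ins for the real Cache / NormalCache types in mistralrs-core.
#[derive(Debug, Clone, Default)]
struct Cache;
#[derive(Debug, Clone, Default)]
struct NormalCache {
    seqlen: usize,
}

#[derive(Debug, Clone)]
enum EitherCache {
    Normal(Arc<Mutex<NormalCache>>),
    Full(Cache),
}

impl EitherCache {
    // Mirrors the accessor above: a caller that built a normal cache takes the
    // lock here; asking for the wrong variant is treated as a programmer error.
    fn normal(&self) -> MutexGuard<'_, NormalCache> {
        match self {
            Self::Normal(normal) => normal.lock().unwrap(),
            Self::Full(_) => panic!("Got full cache, expected normal cache."),
        }
    }
}

fn main() {
    let cache = EitherCache::Normal(Arc::new(Mutex::new(NormalCache::default())));
    // The guard keeps the NormalCache locked for as long as the caller uses it.
    let mut guard = cache.normal();
    guard.seqlen += 1;
    println!("current seq len: {}", guard.seqlen);
}

Returning a MutexGuard keeps the NormalCache locked while the caller works with it, and the panic branches make a variant mismatch fail loudly rather than silently.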

mistralrs-core/src/pipeline/cache_manager.rs renamed to mistralrs-core/src/kv_cache/mod.rs

Lines changed: 12 additions & 274 deletions
@@ -2,9 +2,19 @@ use std::sync::{Arc, Mutex, MutexGuard};

use candle_core::{Result, Tensor, D};

-use crate::{get_mut_arcmutex, sequence::Sequence};
+use crate::{
+    get_mut_arcmutex,
+    pipeline::{CacheManagerMixin, MetadataMixin},
+    sequence::Sequence,
+};

-use super::{CacheManagerMixin, MetadataMixin};
+mod full_cache;
+mod rotating_cache;
+mod single_cache;
+
+pub use full_cache::{EitherCache, LayerCaches};
+pub use rotating_cache::RotatingCache;
+pub use single_cache::SingleCache;

pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
    fn clone_in_cache(
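
This hunk replaces the single cache_manager.rs with a kv_cache module split across full_cache.rs, rotating_cache.rs and single_cache.rs, re-exporting the cache types from the module root. As a rough sketch of that re-export pattern (inline modules and stub structs stand in for the real files and types; this is not the actual crate layout):

// Inline modules stand in for the new submodule files; the real types are far larger.
mod kv_cache {
    mod single_cache {
        #[derive(Debug, Clone, Default)]
        pub struct SingleCache {
            pub current_seq_len: usize,
        }
    }
    mod rotating_cache {
        #[derive(Debug, Clone, Default)]
        pub struct RotatingCache {
            pub offset: usize,
        }
    }

    // Re-export so downstream code is unaffected by the file split.
    pub use rotating_cache::RotatingCache;
    pub use single_cache::SingleCache;
}

fn main() {
    // Callers see one path, not the submodule layout behind it.
    let cache = kv_cache::SingleCache::default();
    let rotating = kv_cache::RotatingCache::default();
    println!("{cache:?} {rotating:?}");
}

Because of the pub use lines, call sites keep a single import path and do not need to know which file a cache type lives in.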
@@ -23,278 +33,6 @@ pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
    );
}

-pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;
-
-#[derive(Debug, Clone)]
-pub enum EitherCache {
-    Normal(Arc<Mutex<NormalCache>>),
-    Full(Cache),
-}
-
-impl EitherCache {
-    /// Panics otherwise!
-    pub fn full(&self) -> &Cache {
-        match self {
-            Self::Full(full) => full,
-            Self::Normal(_) => panic!("Got normal cache, expected full cache."),
-        }
-    }
-    /// Panics otherwise!
-    pub fn normal(&self) -> MutexGuard<'_, NormalCache> {
-        match self {
-            Self::Normal(normal) => normal.lock().unwrap(),
-            Self::Full(_) => panic!("Got full cache, expected normal cache."),
-        }
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct SingleCache {
-    // all_data is an option on a Tensor, this makes it possible to only create the actual tensor
-    // on the first call where the batch size is easily known.
-    // Also this makes it safe to clone a KvCache that has been reset (as in it will not share
-    // its internal state with the cloned instance).
-    pub all_data: Option<Tensor>,
-    pub dim: usize,
-    pub current_seq_len: usize,
-    pub capacity_seq_len: usize,
-    pub max_seq_len: usize,
-}
-
-impl SingleCache {
-    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
-        Self {
-            all_data: None,
-            dim,
-            current_seq_len: 0,
-            max_seq_len,
-            capacity_seq_len,
-        }
-    }
-
-    pub fn dim(&self) -> usize {
-        self.dim
-    }
-
-    pub fn current_seq_len(&self) -> usize {
-        self.current_seq_len
-    }
-
-    pub fn max_seq_len(&self) -> usize {
-        self.max_seq_len
-    }
-
-    pub fn all_data(&self) -> &Option<Tensor> {
-        &self.all_data
-    }
-
-    pub fn current_data(&self) -> Result<Option<Tensor>> {
-        let data = match self.all_data.as_ref() {
-            None => None,
-            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
-        };
-        Ok(data)
-    }
-
-    pub fn reset(&mut self) {
-        self.current_seq_len = 0;
-        self.all_data = None;
-    }
-
-    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
-        self.current_seq_len = len;
-        Ok(())
-    }
-
-    pub fn append(&mut self, src: &Tensor) -> Result<()> {
-        let seq_len = src.dim(self.dim)?;
-        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
-        // self.all_data.get_or_insert_with.
-        if self.all_data.is_none() {
-            let mut shape = src.dims().to_vec();
-            shape[self.dim] = self.capacity_seq_len;
-            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
-            self.all_data = Some(ad);
-        };
-
-        // Expand kv cache
-        if self.current_seq_len + seq_len > self.capacity_seq_len {
-            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
-            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
-            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
-            if self.capacity_seq_len > self.max_seq_len {
-                candle_core::bail!(
-                    "kv-cache: requested capacity ({}) above max seq len ({})",
-                    self.capacity_seq_len,
-                    self.max_seq_len
-                )
-            }
-            let mut shape = src.dims().to_vec();
-            shape[self.dim] = self.capacity_seq_len;
-            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
-            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
-            self.all_data = Some(ad);
-        }
-
-        let ad = self.all_data.as_mut().unwrap();
-
-        ad.slice_set(src, self.dim, self.current_seq_len)?;
-        self.current_seq_len += seq_len;
-        Ok(())
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct RotatingCache {
-    pub all_data: Option<Tensor>,
-    pub dim: usize,
-    // `offset` is the current write index in the buffer
-    pub offset: usize,
-    // The total size of the sequence seen so far.
-    pub current_seq_len: usize,
-    // max_seq_len is the size of the rotating buffer, it is actually allowed for the full
-    // sequence to grow past this limit.
-    pub max_seq_len: usize,
-    pub capacity_seq_len: usize,
-}
-
-impl RotatingCache {
-    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
-        Self {
-            all_data: None,
-            dim,
-            offset: 0,
-            current_seq_len: 0,
-            max_seq_len,
-            capacity_seq_len,
-        }
-    }
-
-    pub fn offset(&self) -> usize {
-        self.offset
-    }
-
-    pub fn dim(&self) -> usize {
-        self.dim
-    }
-
-    pub fn current_seq_len(&self) -> usize {
-        self.current_seq_len
-    }
-
-    pub fn max_seq_len(&self) -> usize {
-        self.max_seq_len
-    }
-
-    pub fn all_data(&self) -> &Option<Tensor> {
-        &self.all_data
-    }
-
-    pub fn current_data(&self) -> Result<Option<Tensor>> {
-        let data = match self.all_data.as_ref() {
-            None => None,
-            Some(d) => {
-                if self.current_seq_len >= self.max_seq_len {
-                    Some(d.clone())
-                } else {
-                    Some(d.narrow(self.dim, 0, self.current_seq_len)?)
-                }
-            }
-        };
-        Ok(data)
-    }
-
-    pub fn reset(&mut self) {
-        self.offset = 0;
-        self.current_seq_len = 0;
-        self.all_data = None;
-    }
-
-    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
-        // If trying to roll it back past the boundary of max_seq_len, fail early.
-        if self.current_seq_len - len > self.max_seq_len {
-            candle_core::bail!(
-                "Rotating KV cache (usually for sliding window) tried to reset to len {len} while current is {} and max retained is {}",
-                self.current_seq_len,
-                self.max_seq_len
-            );
-        }
-        self.current_seq_len = len;
-        self.offset = len % self.max_seq_len;
-        Ok(())
-    }
-
-    pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
-        let seq_len = src.dim(self.dim)?;
-        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
-        // self.all_data.get_or_insert_with.
-        if self.all_data.is_none() {
-            let mut shape = src.dims().to_vec();
-            shape[self.dim] = self.capacity_seq_len;
-            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
-            self.all_data = Some(ad)
-        };
-
-        // Expand kv cache, this case is a little more complex.
-        if (self.current_seq_len + seq_len > self.capacity_seq_len
-            && self.current_seq_len + seq_len < self.max_seq_len)
-            || self.current_seq_len == 0
-        {
-            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
-            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
-            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
-            self.capacity_seq_len = self.capacity_seq_len.min(self.max_seq_len);
-            if self.capacity_seq_len > self.max_seq_len {
-                candle_core::bail!(
-                    "kv-cache: requested capacity ({}) above max seq len ({})",
-                    self.capacity_seq_len,
-                    self.max_seq_len
-                )
-            }
-            let mut shape = src.dims().to_vec();
-            shape[self.dim] = self.capacity_seq_len;
-            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
-            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
-            self.all_data = Some(ad);
-        }
-
-        let ad = self.all_data.as_mut().unwrap();
-
-        self.current_seq_len += seq_len;
-        if seq_len >= self.max_seq_len {
-            let to_copy = src
-                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
-                .contiguous()?;
-            ad.slice_set(&to_copy, self.dim, 0)?;
-            self.offset = 0;
-            // Here we return `src` rather than `ad` so that all the past can be used.
-            Ok(src.clone())
-        } else {
-            let rem_len = self.max_seq_len - self.offset;
-            if seq_len <= rem_len {
-                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
-                self.offset = (self.offset + seq_len) % self.max_seq_len;
-            } else {
-                // We have to make two copies here as we go over the boundary of the cache.
-                if rem_len > 0 {
-                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
-                    ad.slice_set(&src1, self.dim, self.offset)?;
-                }
-                let src2 = src
-                    .narrow(self.dim, rem_len, seq_len - rem_len)?
-                    .contiguous()?;
-                ad.slice_set(&src2, self.dim, 0)?;
-                self.offset = seq_len - rem_len;
-            }
-            if self.current_seq_len >= self.max_seq_len {
-                Ok(ad.clone())
-            } else {
-                Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
-            }
-        }
-    }
-}
-
#[derive(Debug, Clone)]
pub enum KvCache {
    Normal { k: SingleCache, v: SingleCache },
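
The bulk of the deletions above is RotatingCache, whose append implements a ring buffer over the sliding window (the code now lives in rotating_cache.rs per the module split). As a simplified, standalone sketch of that wraparound bookkeeping, with a Vec<u32> standing in for the tensor buffer, no capacity growth, and a hypothetical ring_append helper:

// Standalone sketch of the ring-buffer bookkeeping in RotatingCache::append.
fn ring_append(buf: &mut Vec<u32>, offset: &mut usize, max_seq_len: usize, src: &[u32]) {
    let seq_len = src.len();
    // Lazily size the buffer (stands in for allocating the zeroed tensor).
    buf.resize(max_seq_len, 0);
    if seq_len >= max_seq_len {
        // The source alone fills the window: keep only its last max_seq_len items.
        buf.copy_from_slice(&src[seq_len - max_seq_len..]);
        *offset = 0;
        return;
    }
    let rem_len = max_seq_len - *offset;
    if seq_len <= rem_len {
        // Fits before the end of the buffer.
        buf[*offset..*offset + seq_len].copy_from_slice(src);
        *offset = (*offset + seq_len) % max_seq_len;
    } else {
        // Wraps: two copies, one up to the end and one from the start.
        buf[*offset..].copy_from_slice(&src[..rem_len]);
        buf[..seq_len - rem_len].copy_from_slice(&src[rem_len..]);
        *offset = seq_len - rem_len;
    }
}

fn main() {
    let (mut buf, mut offset) = (Vec::new(), 0usize);
    ring_append(&mut buf, &mut offset, 4, &[1, 2, 3]);
    ring_append(&mut buf, &mut offset, 4, &[4, 5]); // wraps past the end
    assert_eq!(buf, vec![5, 2, 3, 4]);
    assert_eq!(offset, 1);
    println!("{buf:?} offset={offset}");
}

The two-copy branch corresponds to the "we go over the boundary of the cache" case in the original, which does the same thing with narrow and slice_set along self.dim.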
