@@ -2,9 +2,19 @@ use std::sync::{Arc, Mutex, MutexGuard};
2
2
3
3
use candle_core:: { Result , Tensor , D } ;
4
4
5
- use crate :: { get_mut_arcmutex, sequence:: Sequence } ;
5
+ use crate :: {
6
+ get_mut_arcmutex,
7
+ pipeline:: { CacheManagerMixin , MetadataMixin } ,
8
+ sequence:: Sequence ,
9
+ } ;
6
10
7
- use super :: { CacheManagerMixin , MetadataMixin } ;
11
+ mod full_cache;
12
+ mod rotating_cache;
13
+ mod single_cache;
14
+
15
+ pub use full_cache:: { EitherCache , LayerCaches } ;
16
+ pub use rotating_cache:: RotatingCache ;
17
+ pub use single_cache:: SingleCache ;
8
18
9
19
pub trait CacheManager < T : CacheManagerMixin + MetadataMixin + ?Sized > {
10
20
fn clone_in_cache (
@@ -23,278 +33,6 @@ pub trait CacheManager<T: CacheManagerMixin + MetadataMixin + ?Sized> {
23
33
) ;
24
34
}
25
35
26
/// Per-layer KV cache storage: one optional `(key, value)` tensor pair per layer.
pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;
27
-
28
/// A KV cache that is either the growable "normal" cache (shared behind a mutex
/// so clones see the same state) or a full `Cache`.
// NOTE(review): `Cache` and `NormalCache` are defined elsewhere in this module —
// their exact semantics are not visible here.
#[derive(Debug, Clone)]
pub enum EitherCache {
    /// Growable per-layer cache, shared via `Arc<Mutex<..>>`.
    Normal(Arc<Mutex<NormalCache>>),
    /// Full (non-rotating) cache.
    Full(Cache),
}
33
-
34
- impl EitherCache {
35
- /// Panics otherwise!
36
- pub fn full ( & self ) -> & Cache {
37
- match self {
38
- Self :: Full ( full) => full,
39
- Self :: Normal ( _) => panic ! ( "Got normal cache, expected full cache." ) ,
40
- }
41
- }
42
- /// Panics otherwise!
43
- pub fn normal ( & self ) -> MutexGuard < ' _ , NormalCache > {
44
- match self {
45
- Self :: Normal ( normal) => normal. lock ( ) . unwrap ( ) ,
46
- Self :: Full ( _) => panic ! ( "Got full cache, expected normal cache." ) ,
47
- }
48
- }
49
- }
50
-
51
/// A linearly growing KV cache for a single layer (no rotation).
#[derive(Debug, Clone)]
pub struct SingleCache {
    // all_data is an option on a Tensor, this makes it possible to only create the actual tensor
    // on the first call where the batch size is easily known.
    // Also this makes it safe to clone a KvCache that has been reset (as in it will not share
    // its internal state with the cloned instance).
    pub all_data: Option<Tensor>,
    // Tensor dimension along which sequence positions are appended.
    pub dim: usize,
    // Number of sequence positions currently filled.
    pub current_seq_len: usize,
    // Allocated length along `dim`; grown on demand, bounded by `max_seq_len`.
    pub capacity_seq_len: usize,
    // Hard upper bound on the cache length.
    pub max_seq_len: usize,
}
63
-
64
impl SingleCache {
    /// Create an empty cache. The backing tensor is allocated lazily on the
    /// first `append`, starting at `capacity_seq_len` positions.
    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            current_seq_len: 0,
            max_seq_len,
            capacity_seq_len,
        }
    }

    /// Dimension along which sequence positions are appended.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of sequence positions currently filled.
    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    /// Hard upper bound on the cache length.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Borrow the backing tensor, if it has been allocated yet.
    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    /// The filled prefix of the cache along `dim`, or `None` if nothing has
    /// been appended yet.
    pub fn current_data(&self) -> Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
        };
        Ok(data)
    }

    /// Drop all cached data and reset the length to zero.
    pub fn reset(&mut self) {
        self.current_seq_len = 0;
        self.all_data = None;
    }

    /// Set the logical length directly.
    // NOTE(review): `len` is not validated against `capacity_seq_len` or
    // `max_seq_len` here — presumably callers guarantee it; verify.
    pub fn set_len(&mut self, len: usize) -> candle_core::Result<()> {
        self.current_seq_len = len;
        Ok(())
    }

    /// Append `src` along `dim`, lazily allocating and growing the backing
    /// tensor in `NormalCache::CACHE_GROW_SIZE` blocks as needed.
    ///
    /// # Errors
    /// Fails if growth would exceed `max_seq_len`, or if a tensor op fails.
    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
        // self.all_data.get_or_insert_with.
        if self.all_data.is_none() {
            // First append: allocate at the configured capacity, matching the
            // shape of `src` in every other dimension.
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad);
        };

        // Expand kv cache: grow capacity in whole CACHE_GROW_SIZE blocks, then
        // copy the old contents into a fresh, larger zero tensor.
        if self.current_seq_len + seq_len > self.capacity_seq_len {
            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            if self.capacity_seq_len > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    self.capacity_seq_len,
                    self.max_seq_len
                )
            }
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
            self.all_data = Some(ad);
        }

        let ad = self.all_data.as_mut().unwrap();

        // Write the new chunk at the current end, then advance the length.
        ad.slice_set(src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }
}
146
-
147
/// A rotating (ring-buffer) KV cache for a single layer: once `max_seq_len`
/// positions are filled, new entries wrap around and overwrite the oldest.
#[derive(Debug, Clone)]
pub struct RotatingCache {
    // Backing tensor; allocated lazily on the first append so that a reset or
    // cloned cache does not share state.
    pub all_data: Option<Tensor>,
    // Tensor dimension along which sequence positions are appended.
    pub dim: usize,
    // `offset` is the current write index in the buffer
    pub offset: usize,
    // The total size of the sequence seen so far.
    pub current_seq_len: usize,
    // max_seq_len is the size of the rotating buffer, it is actually allowed for the full
    // sequence to grow past this limit.
    pub max_seq_len: usize,
    // Allocated length along `dim`; grown on demand, capped at `max_seq_len`.
    pub capacity_seq_len: usize,
}
160
-
161
- impl RotatingCache {
162
- pub fn new ( dim : usize , max_seq_len : usize , capacity_seq_len : usize ) -> Self {
163
- Self {
164
- all_data : None ,
165
- dim,
166
- offset : 0 ,
167
- current_seq_len : 0 ,
168
- max_seq_len,
169
- capacity_seq_len,
170
- }
171
- }
172
-
173
- pub fn offset ( & self ) -> usize {
174
- self . offset
175
- }
176
-
177
- pub fn dim ( & self ) -> usize {
178
- self . dim
179
- }
180
-
181
- pub fn current_seq_len ( & self ) -> usize {
182
- self . current_seq_len
183
- }
184
-
185
- pub fn max_seq_len ( & self ) -> usize {
186
- self . max_seq_len
187
- }
188
-
189
- pub fn all_data ( & self ) -> & Option < Tensor > {
190
- & self . all_data
191
- }
192
-
193
- pub fn current_data ( & self ) -> Result < Option < Tensor > > {
194
- let data = match self . all_data . as_ref ( ) {
195
- None => None ,
196
- Some ( d) => {
197
- if self . current_seq_len >= self . max_seq_len {
198
- Some ( d. clone ( ) )
199
- } else {
200
- Some ( d. narrow ( self . dim , 0 , self . current_seq_len ) ?)
201
- }
202
- }
203
- } ;
204
- Ok ( data)
205
- }
206
-
207
- pub fn reset ( & mut self ) {
208
- self . offset = 0 ;
209
- self . current_seq_len = 0 ;
210
- self . all_data = None ;
211
- }
212
-
213
- pub fn set_len ( & mut self , len : usize ) -> candle_core:: Result < ( ) > {
214
- // If trying to roll it back past the boundary of max_seq_len, fail early.
215
- if self . current_seq_len - len > self . max_seq_len {
216
- candle_core:: bail!(
217
- "Rotating KV cache (usually for sliding window) tried to reset to len {len} while current is {} and max retained is {}" ,
218
- self . current_seq_len,
219
- self . max_seq_len
220
- ) ;
221
- }
222
- self . current_seq_len = len;
223
- self . offset = len % self . max_seq_len ;
224
- Ok ( ( ) )
225
- }
226
-
227
- pub fn append ( & mut self , src : & Tensor ) -> Result < Tensor > {
228
- let seq_len = src. dim ( self . dim ) ?;
229
- // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
230
- // self.all_data.get_or_insert_with.
231
- if self . all_data . is_none ( ) {
232
- let mut shape = src. dims ( ) . to_vec ( ) ;
233
- shape[ self . dim ] = self . capacity_seq_len ;
234
- let ad = Tensor :: zeros ( shape, src. dtype ( ) , src. device ( ) ) ?;
235
- self . all_data = Some ( ad)
236
- } ;
237
-
238
- // Expand kv cache, this case is a little more complex.
239
- if ( self . current_seq_len + seq_len > self . capacity_seq_len
240
- && self . current_seq_len + seq_len < self . max_seq_len )
241
- || self . current_seq_len == 0
242
- {
243
- let diff = self . current_seq_len + seq_len - self . capacity_seq_len ;
244
- let n_blocks_needed = diff. div_ceil ( NormalCache :: CACHE_GROW_SIZE ) ;
245
- self . capacity_seq_len += n_blocks_needed * NormalCache :: CACHE_GROW_SIZE ;
246
- self . capacity_seq_len = self . capacity_seq_len . min ( self . max_seq_len ) ;
247
- if self . capacity_seq_len > self . max_seq_len {
248
- candle_core:: bail!(
249
- "kv-cache: requested capacity ({}) above max seq len ({})" ,
250
- self . capacity_seq_len,
251
- self . max_seq_len
252
- )
253
- }
254
- let mut shape = src. dims ( ) . to_vec ( ) ;
255
- shape[ self . dim ] = self . capacity_seq_len ;
256
- let ad = Tensor :: zeros ( shape, src. dtype ( ) , src. device ( ) ) ?;
257
- ad. slice_set ( self . all_data . as_ref ( ) . unwrap ( ) , self . dim , 0 ) ?;
258
- self . all_data = Some ( ad) ;
259
- }
260
-
261
- let ad = self . all_data . as_mut ( ) . unwrap ( ) ;
262
-
263
- self . current_seq_len += seq_len;
264
- if seq_len >= self . max_seq_len {
265
- let to_copy = src
266
- . narrow ( self . dim , seq_len - self . max_seq_len , self . max_seq_len ) ?
267
- . contiguous ( ) ?;
268
- ad. slice_set ( & to_copy, self . dim , 0 ) ?;
269
- self . offset = 0 ;
270
- // Here we return `src` rather than `ad` so that all the past can be used.
271
- Ok ( src. clone ( ) )
272
- } else {
273
- let rem_len = self . max_seq_len - self . offset ;
274
- if seq_len <= rem_len {
275
- ad. slice_set ( & src. contiguous ( ) ?, self . dim , self . offset ) ?;
276
- self . offset = ( self . offset + seq_len) % self . max_seq_len ;
277
- } else {
278
- // We have to make two copies here as we go over the boundary of the cache.
279
- if rem_len > 0 {
280
- let src1 = src. narrow ( self . dim , 0 , rem_len) ?. contiguous ( ) ?;
281
- ad. slice_set ( & src1, self . dim , self . offset ) ?;
282
- }
283
- let src2 = src
284
- . narrow ( self . dim , rem_len, seq_len - rem_len) ?
285
- . contiguous ( ) ?;
286
- ad. slice_set ( & src2, self . dim , 0 ) ?;
287
- self . offset = seq_len - rem_len;
288
- }
289
- if self . current_seq_len >= self . max_seq_len {
290
- Ok ( ad. clone ( ) )
291
- } else {
292
- Ok ( ad. narrow ( self . dim , 0 , self . current_seq_len ) ?)
293
- }
294
- }
295
- }
296
- }
297
-
298
36
#[ derive( Debug , Clone ) ]
299
37
pub enum KvCache {
300
38
Normal { k : SingleCache , v : SingleCache } ,
0 commit comments