Skip to content

Commit cd96fa8

Browse files
Add a scattered kv cache. (#2936)
* Add a scattered kv cache. * Update some comments.
1 parent 8a19bb7 commit cd96fa8

File tree

1 file changed

+320
-1
lines changed

1 file changed

+320
-1
lines changed

candle-nn/src/kv_cache.rs

Lines changed: 320 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Cache Implementations
22
//!
3-
use candle::{Device, Result, Tensor};
3+
use candle::{DType, Device, Result, Tensor};
44

55
#[derive(Debug, Clone)]
66
pub struct Cache {
@@ -399,3 +399,322 @@ impl RotatingKvCache {
399399
self.v.reset();
400400
}
401401
}
402+
403+
#[derive(Debug, Clone)]
404+
pub struct IndicesAndMask {
405+
indices: Tensor,
406+
mask: Tensor,
407+
}
408+
409+
impl IndicesAndMask {
410+
pub fn mask(&self) -> &Tensor {
411+
&self.mask
412+
}
413+
}
414+
415+
#[derive(Debug, Clone)]
416+
pub struct ScatteredKvCache {
417+
k: Tensor,
418+
v: Tensor,
419+
context: usize,
420+
}
421+
422+
impl ScatteredKvCache {
423+
pub fn append(
424+
&mut self,
425+
k: &Tensor,
426+
v: &Tensor,
427+
iam: &IndicesAndMask,
428+
) -> Result<(Tensor, Tensor)> {
429+
if self.context <= k.dim(2)? {
430+
return Ok((k.clone(), v.clone()));
431+
}
432+
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
433+
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
434+
self.k.scatter_set(&indices, k, 2)?;
435+
self.v.scatter_set(&indices, v, 2)?;
436+
Ok((self.k.clone(), self.v.clone()))
437+
}
438+
439+
pub fn k(&self) -> &Tensor {
440+
&self.k
441+
}
442+
443+
pub fn v(&self) -> &Tensor {
444+
&self.v
445+
}
446+
}
447+
448+
#[derive(Debug, Clone)]
449+
pub struct ScatteredCacheBuilder {
450+
context: usize,
451+
// The current position in the stream, this can be larger than context.
452+
positions: Vec<usize>,
453+
// The index where the next element will be stored.
454+
indices: Vec<usize>,
455+
dtype: DType,
456+
device: Device,
457+
}
458+
459+
impl ScatteredCacheBuilder {
460+
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
461+
let positions = vec![0; batch_size];
462+
let indices = vec![0; batch_size];
463+
Ok(Self {
464+
positions,
465+
indices,
466+
context,
467+
dtype,
468+
device: device.clone(),
469+
})
470+
}
471+
472+
pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
473+
let batch_size = self.batch_size();
474+
let shape = (batch_size, num_heads, self.context, head_dim);
475+
let k = Tensor::zeros(shape, self.dtype, self.device())?;
476+
let v = Tensor::zeros(shape, self.dtype, self.device())?;
477+
Ok(ScatteredKvCache {
478+
k,
479+
v,
480+
context: self.context,
481+
})
482+
}
483+
484+
pub fn positions(&self) -> &[usize] {
485+
&self.positions
486+
}
487+
488+
pub fn reset(&mut self) {
489+
self.positions.fill(0);
490+
self.indices.fill(0);
491+
}
492+
493+
pub fn batch_size(&self) -> usize {
494+
self.positions.len()
495+
}
496+
497+
pub fn reset_batch_index(&mut self, batch_index: usize) {
498+
self.positions[batch_index] = 0;
499+
self.indices[batch_index] = 0;
500+
}
501+
502+
#[allow(clippy::needless_range_loop)]
503+
pub fn indices_and_mask(
504+
&mut self,
505+
seq_len: usize,
506+
batch_mask: &[bool],
507+
) -> Result<IndicesAndMask> {
508+
// mask shape is (b, h, t, k)
509+
let context = self.context;
510+
if self.context <= seq_len {
511+
return self.indices_and_mask_abs(seq_len, batch_mask);
512+
}
513+
let mut attention_masks = Vec::with_capacity(self.batch_size());
514+
let mut cache_indices = Vec::with_capacity(self.batch_size());
515+
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
516+
if !batch_mask {
517+
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
518+
let indices = vec![self.indices[batch_i] as u32; seq_len];
519+
attention_masks.push(masks);
520+
cache_indices.push(indices);
521+
} else {
522+
let start_index = self.indices[batch_i];
523+
let start_pos = self.positions[batch_i];
524+
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
525+
let mut indices = Vec::with_capacity(seq_len);
526+
let mut all_pos = vec![usize::MAX; context];
527+
if start_pos < context {
528+
for i in 0..start_pos {
529+
all_pos[i] = i;
530+
}
531+
} else {
532+
let offset = start_pos - start_index;
533+
for i in 0..context {
534+
all_pos[i] = if i < start_index {
535+
i + offset
536+
} else {
537+
i + offset - context
538+
};
539+
}
540+
}
541+
for seq_i in 0..seq_len {
542+
let index = self.indices[batch_i];
543+
all_pos[index] = seq_i + start_pos;
544+
indices.push(index as u32);
545+
self.indices[batch_i] += 1;
546+
self.positions[batch_i] += 1;
547+
if self.indices[batch_i] >= self.context {
548+
self.indices[batch_i] = 0;
549+
}
550+
}
551+
552+
for seq_i in 0..seq_len {
553+
let my_pos = seq_i + start_pos;
554+
let mask = all_pos
555+
.iter()
556+
.map(|&pos| {
557+
if pos <= my_pos {
558+
0.0
559+
} else {
560+
f32::NEG_INFINITY
561+
}
562+
})
563+
.collect::<Vec<f32>>();
564+
masks.push(mask);
565+
}
566+
567+
attention_masks.push(masks);
568+
cache_indices.push(indices);
569+
}
570+
}
571+
// Flattening the attention mask then using Tensor::from_vec rather using Tensor::new ends
572+
// up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
573+
let attention_masks = attention_masks
574+
.into_iter()
575+
.flat_map(|m| m.into_iter().flatten())
576+
.collect::<Vec<f32>>();
577+
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
578+
.to_dtype(self.dtype)?;
579+
let indices = Tensor::new(cache_indices, self.device())?;
580+
Ok(IndicesAndMask { indices, mask })
581+
}
582+
583+
pub fn device(&self) -> &Device {
584+
&self.device
585+
}
586+
587+
#[allow(clippy::needless_range_loop)]
588+
fn indices_and_mask_abs(
589+
&mut self,
590+
seq_len: usize,
591+
batch_mask: &[bool],
592+
) -> Result<IndicesAndMask> {
593+
let mask = self.get_mask_abs(seq_len, seq_len)?;
594+
let mut cache_indices = Vec::with_capacity(self.batch_size());
595+
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
596+
if !batch_mask {
597+
let indices = vec![self.indices[batch_i] as u32; seq_len];
598+
cache_indices.push(indices);
599+
} else {
600+
let mut indices = Vec::with_capacity(seq_len);
601+
for _ in 0..seq_len {
602+
let index = self.indices[batch_i];
603+
indices.push(index as u32);
604+
self.indices[batch_i] += 1;
605+
self.positions[batch_i] += 1;
606+
if self.indices[batch_i] >= self.context {
607+
self.indices[batch_i] = 0;
608+
}
609+
}
610+
cache_indices.push(indices);
611+
}
612+
}
613+
let indices = Tensor::new(cache_indices, self.device())?;
614+
Ok(IndicesAndMask { indices, mask })
615+
}
616+
617+
fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
618+
let context = self.context;
619+
let mask: Vec<_> = (0..size1)
620+
.flat_map(|i| {
621+
(0..size2).map(move |j| {
622+
if size1 + j > size2 + i || size1 + j + context < size2 + i {
623+
f32::NEG_INFINITY
624+
} else {
625+
0.0
626+
}
627+
})
628+
})
629+
.collect();
630+
Tensor::from_slice(&mask, (size1, size2), self.device())
631+
}
632+
}
633+
634+
#[cfg(test)]
635+
mod tests {
636+
use super::*;
637+
use candle::IndexOp;
638+
639+
#[test]
640+
fn test_scattered_kv_cache() -> Result<()> {
641+
let device = Device::Cpu;
642+
let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
643+
let inf = f32::INFINITY;
644+
645+
let iam = cache.indices_and_mask(1, &[true, false])?;
646+
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
647+
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
648+
assert_eq!(
649+
mask,
650+
[[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
651+
);
652+
653+
let iam = cache.indices_and_mask(1, &[true, false])?;
654+
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
655+
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
656+
assert_eq!(
657+
mask,
658+
[[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
659+
);
660+
661+
let iam = cache.indices_and_mask(3, &[false, true])?;
662+
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
663+
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
664+
assert_eq!(
665+
mask,
666+
[
667+
[
668+
[0.0, 0.0, 0.0, 0.0, 0.0],
669+
[0.0, 0.0, 0.0, 0.0, 0.0],
670+
[0.0, 0.0, 0.0, 0.0, 0.0]
671+
],
672+
[
673+
[0.0, -inf, -inf, -inf, -inf],
674+
[0.0, 0.0, -inf, -inf, -inf],
675+
[0.0, 0.0, 0.0, -inf, -inf]
676+
]
677+
]
678+
);
679+
680+
let iam = cache.indices_and_mask(3, &[true, true])?;
681+
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
682+
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
683+
assert_eq!(
684+
mask,
685+
[
686+
[
687+
[0.0, 0.0, 0.0, -inf, -inf],
688+
[0.0, 0.0, 0.0, 0.0, -inf],
689+
[0.0, 0.0, 0.0, 0.0, 0.0]
690+
],
691+
[
692+
[-inf, 0.0, 0.0, 0.0, -inf],
693+
[-inf, 0.0, 0.0, 0.0, 0.0],
694+
[0.0, 0.0, 0.0, 0.0, 0.0]
695+
]
696+
]
697+
);
698+
699+
let iam = cache.indices_and_mask(1, &[true, false])?;
700+
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
701+
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
702+
assert_eq!(
703+
mask,
704+
[[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
705+
);
706+
707+
let iam = cache.indices_and_mask(2, &[true, false])?;
708+
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
709+
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
710+
assert_eq!(
711+
mask,
712+
[
713+
[[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
714+
[[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
715+
]
716+
);
717+
718+
Ok(())
719+
}
720+
}

0 commit comments

Comments
 (0)