@@ -1,6 +1,6 @@
 //! Cache Implementations
 //!
-use candle::{Device, Result, Tensor};
+use candle::{DType, Device, Result, Tensor};
 
 #[derive(Debug, Clone)]
 pub struct Cache {
@@ -399,3 +399,322 @@ impl RotatingKvCache {
         self.v.reset();
     }
 }
+
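+/// Scatter indices and attention mask produced by
+/// [`ScatteredCacheBuilder::indices_and_mask`]; the indices drive
+/// [`ScatteredKvCache::append`] and the mask is added to the attention scores.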
+#[derive(Debug, Clone)]
+pub struct IndicesAndMask {
+    indices: Tensor,
+    mask: Tensor,
+}
+
+impl IndicesAndMask {
+    pub fn mask(&self) -> &Tensor {
+        &self.mask
+    }
+}
+
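+/// A fixed-capacity key-value cache where each sequence owns `context` slots.
+/// New entries are scattered into the slots selected by an [`IndicesAndMask`],
+/// overwriting the oldest entries once the window is full.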
+#[derive(Debug, Clone)]
+pub struct ScatteredKvCache {
+    k: Tensor,
+    v: Tensor,
+    context: usize,
+}
+
+impl ScatteredKvCache {
+    pub fn append(
+        &mut self,
+        k: &Tensor,
+        v: &Tensor,
+        iam: &IndicesAndMask,
+    ) -> Result<(Tensor, Tensor)> {
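+        // If the new chunk is at least as long as the context window the whole
+        // cache would be overwritten, so bypass it and return the inputs as-is.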
+        if self.context <= k.dim(2)? {
+            return Ok((k.clone(), v.clone()));
+        }
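+        // Expand the (b, t) slot indices to the (b, h, t, d) shape expected by
+        // scatter_set when writing along the time dimension (dim 2).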
+        let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
+        let indices = indices.broadcast_as(k.shape())?.contiguous()?;
+        self.k.scatter_set(&indices, k, 2)?;
+        self.v.scatter_set(&indices, v, 2)?;
+        Ok((self.k.clone(), self.v.clone()))
+    }
+
+    pub fn k(&self) -> &Tensor {
+        &self.k
+    }
+
+    pub fn v(&self) -> &Tensor {
+        &self.v
+    }
+}
+
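+/// Tracks the per-sequence stream position and next write slot, and builds the
+/// indices and causal masks shared by every [`ScatteredKvCache`] it creates.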
+#[derive(Debug, Clone)]
+pub struct ScatteredCacheBuilder {
+    context: usize,
+    // The current position in the stream; this can be larger than `context`.
+    positions: Vec<usize>,
+    // The index where the next element will be stored.
+    indices: Vec<usize>,
+    dtype: DType,
+    device: Device,
+}
+
+impl ScatteredCacheBuilder {
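+    /// Creates a builder for `batch_size` independent streams, each caching at
+    /// most `context` positions.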
+    pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let positions = vec![0; batch_size];
+        let indices = vec![0; batch_size];
+        Ok(Self {
+            positions,
+            indices,
+            context,
+            dtype,
+            device: device.clone(),
+        })
+    }
+
+    pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
+        let batch_size = self.batch_size();
+        let shape = (batch_size, num_heads, self.context, head_dim);
+        let k = Tensor::zeros(shape, self.dtype, self.device())?;
+        let v = Tensor::zeros(shape, self.dtype, self.device())?;
+        Ok(ScatteredKvCache {
+            k,
+            v,
+            context: self.context,
+        })
+    }
+
+    pub fn positions(&self) -> &[usize] {
+        &self.positions
+    }
+
+    pub fn reset(&mut self) {
+        self.positions.fill(0);
+        self.indices.fill(0);
+    }
+
+    pub fn batch_size(&self) -> usize {
+        self.positions.len()
+    }
+
+    pub fn reset_batch_index(&mut self, batch_index: usize) {
+        self.positions[batch_index] = 0;
+        self.indices[batch_index] = 0;
+    }
+
+    #[allow(clippy::needless_range_loop)]
+    pub fn indices_and_mask(
+        &mut self,
+        seq_len: usize,
+        batch_mask: &[bool],
+    ) -> Result<IndicesAndMask> {
+        // Mask shape is (b, 1, t, k); it broadcasts over the h attention heads.
+        let context = self.context;
+        if self.context <= seq_len {
+            return self.indices_and_mask_abs(seq_len, batch_mask);
+        }
+        let mut attention_masks = Vec::with_capacity(self.batch_size());
+        let mut cache_indices = Vec::with_capacity(self.batch_size());
+        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
+            if !batch_mask {
+                let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
+                let indices = vec![self.indices[batch_i] as u32; seq_len];
+                attention_masks.push(masks);
+                cache_indices.push(indices);
+            } else {
+                let start_index = self.indices[batch_i];
+                let start_pos = self.positions[batch_i];
+                let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
+                let mut indices = Vec::with_capacity(seq_len);
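+                // Reconstruct the absolute stream position stored in each cache
+                // slot (usize::MAX marks still-empty slots) so that causal
+                // masking remains correct across the ring-buffer wrap-around.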
+                let mut all_pos = vec![usize::MAX; context];
+                if start_pos < context {
+                    for i in 0..start_pos {
+                        all_pos[i] = i;
+                    }
+                } else {
+                    let offset = start_pos - start_index;
+                    for i in 0..context {
+                        all_pos[i] = if i < start_index {
+                            i + offset
+                        } else {
+                            i + offset - context
+                        };
+                    }
+                }
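+                // Claim one slot per new token, wrapping the write index once
+                // it reaches the end of the context window.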
+                for seq_i in 0..seq_len {
+                    let index = self.indices[batch_i];
+                    all_pos[index] = seq_i + start_pos;
+                    indices.push(index as u32);
+                    self.indices[batch_i] += 1;
+                    self.positions[batch_i] += 1;
+                    if self.indices[batch_i] >= self.context {
+                        self.indices[batch_i] = 0;
+                    }
+                }
+
+                for seq_i in 0..seq_len {
+                    let my_pos = seq_i + start_pos;
+                    let mask = all_pos
+                        .iter()
+                        .map(|&pos| {
+                            if pos <= my_pos {
+                                0.0
+                            } else {
+                                f32::NEG_INFINITY
+                            }
+                        })
+                        .collect::<Vec<f32>>();
+                    masks.push(mask);
+                }
+
+                attention_masks.push(masks);
+                cache_indices.push(indices);
+            }
+        }
+        // Flattening the attention mask and then using Tensor::from_vec rather
+        // than Tensor::new ends up being almost 10x faster with candle 0.9.0.
+        // This has been fixed in candle 0.9.1.
+        let attention_masks = attention_masks
+            .into_iter()
+            .flat_map(|m| m.into_iter().flatten())
+            .collect::<Vec<f32>>();
+        let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
+            .to_dtype(self.dtype)?;
+        let indices = Tensor::new(cache_indices, self.device())?;
+        Ok(IndicesAndMask { indices, mask })
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
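+    // Fallback for prompts that are at least as long as the context window:
+    // the mask is a plain banded causal mask over absolute positions.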
+    #[allow(clippy::needless_range_loop)]
+    fn indices_and_mask_abs(
+        &mut self,
+        seq_len: usize,
+        batch_mask: &[bool],
+    ) -> Result<IndicesAndMask> {
+        let mask = self.get_mask_abs(seq_len, seq_len)?;
+        let mut cache_indices = Vec::with_capacity(self.batch_size());
+        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
+            if !batch_mask {
+                let indices = vec![self.indices[batch_i] as u32; seq_len];
+                cache_indices.push(indices);
+            } else {
+                let mut indices = Vec::with_capacity(seq_len);
+                for _ in 0..seq_len {
+                    let index = self.indices[batch_i];
+                    indices.push(index as u32);
+                    self.indices[batch_i] += 1;
+                    self.positions[batch_i] += 1;
+                    if self.indices[batch_i] >= self.context {
+                        self.indices[batch_i] = 0;
+                    }
+                }
+                cache_indices.push(indices);
+            }
+        }
+        let indices = Tensor::new(cache_indices, self.device())?;
+        Ok(IndicesAndMask { indices, mask })
+    }
+
+    fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
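+        // Key j is visible to query i iff it is not in the future and at most
+        // `context` positions in the past, with queries and keys right-aligned.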
+        let context = self.context;
+        let mask: Vec<_> = (0..size1)
+            .flat_map(|i| {
+                (0..size2).map(move |j| {
+                    if size1 + j > size2 + i || size1 + j + context < size2 + i {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.0
+                    }
+                })
+            })
+            .collect();
+        Tensor::from_slice(&mask, (size1, size2), self.device())
+    }
+}
+
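+// Minimal usage sketch (variable names are illustrative, not part of the API):
+// one builder drives the whole model, with one cache per attention layer.
+//
+//     let mut builder = ScatteredCacheBuilder::new(batch_size, context, dtype, &device)?;
+//     let mut cache = builder.make_cache(num_heads, head_dim)?;
+//     // Each decoding step: compute slot indices and mask, scatter new k/v in,
+//     // and add iam.mask() to the attention scores before the softmax.
+//     let iam = builder.indices_and_mask(seq_len, &active)?;
+//     let (k_all, v_all) = cache.append(&k_new, &v_new, &iam)?;
+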
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use candle::IndexOp;
+
+    #[test]
+    fn test_scattered_kv_cache() -> Result<()> {
+        let device = Device::Cpu;
+        let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
+        let inf = f32::INFINITY;
+
+        let iam = cache.indices_and_mask(1, &[true, false])?;
+        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
+        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
+        assert_eq!(
+            mask,
+            [[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
+        );
+
+        let iam = cache.indices_and_mask(1, &[true, false])?;
+        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
+        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
+        assert_eq!(
+            mask,
+            [[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
+        );
+
+        let iam = cache.indices_and_mask(3, &[false, true])?;
+        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
+        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
+        assert_eq!(
+            mask,
+            [
+                [
+                    [0.0, 0.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 0.0, 0.0, 0.0]
+                ],
+                [
+                    [0.0, -inf, -inf, -inf, -inf],
+                    [0.0, 0.0, -inf, -inf, -inf],
+                    [0.0, 0.0, 0.0, -inf, -inf]
+                ]
+            ]
+        );
+
+        let iam = cache.indices_and_mask(3, &[true, true])?;
+        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
+        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
+        assert_eq!(
+            mask,
+            [
+                [
+                    [0.0, 0.0, 0.0, -inf, -inf],
+                    [0.0, 0.0, 0.0, 0.0, -inf],
+                    [0.0, 0.0, 0.0, 0.0, 0.0]
+                ],
+                [
+                    [-inf, 0.0, 0.0, 0.0, -inf],
+                    [-inf, 0.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 0.0, 0.0, 0.0]
+                ]
+            ]
+        );
+
+        let iam = cache.indices_and_mask(1, &[true, false])?;
+        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
+        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
+        assert_eq!(
+            mask,
+            [[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
+        );
+
+        let iam = cache.indices_and_mask(2, &[true, false])?;
+        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
+        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
+        assert_eq!(
+            mask,
+            [
+                [[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
+                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
+            ]
+        );
+
+        Ok(())
+    }
+}