@@ -1204,7 +1204,7 @@ def __init__(
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
-        )
+        ) // 8  # TODO use TP!

         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
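The `// 8` divisor in this hunk hard-codes the tensor-parallel degree, as the `TODO use TP!` comment admits. A minimal sketch of what the TODO might look like once addressed, assuming the KV heads are sharded evenly over an initialized `torch.distributed` tensor-parallel group (the helper name `local_num_key_value_heads` is hypothetical, not part of this diff):

```python
import torch.distributed as dist


def local_num_key_value_heads(config) -> int:
    # Per-device KV-head count derived from the distributed world size,
    # instead of the hard-coded `// 8` above (hypothetical helper).
    num_kv_heads = (
        config.num_attention_heads
        if getattr(config, "num_key_value_heads", None) is None
        else config.num_key_value_heads
    )
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    return num_kv_heads // world_size
```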
@@ -1663,84 +1663,75 @@ def __init__(
         max_batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype = torch.bfloat16,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
-            raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
-                "sliding window attention, please check if there is a `sliding_window` field in the model "
-                "config and it's not set to None."
-            )
+            self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8092)
+        else:
+            self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
         self.max_batch_size = max_batch_size
-        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
-        )

-        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        if hasattr(config.get_text_config(), "no_rope_layers"):
+            self.is_sliding = torch.tensor(config.no_rope_layers)
+        else:
+            layer_switch = getattr(config, "sliding_window_pattern", 2)
+            self.is_sliding = torch.tensor(
+                [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
+            )
+
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
+
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+
+        num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
         sliding_cache_shape = (
             self.max_batch_size,
-            self.num_key_value_heads,
-            min(config.sliding_window, max_cache_len),
+            num_key_value_heads,
+            self.sliding_window,
             self.head_dim,
         )
-        device = torch.device(device) if device is not None and isinstance(device, str) else None
-        for i in range(config.num_hidden_layers):
-            if layer_device_map is not None:
-                layer_device = layer_device_map[i]
-            else:
-                layer_device = device
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+        # breaks when updating the cache.
+        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)

     def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] > max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        to_shift = cache_position >= max_cache_len - 1
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
+        cumulative_length = self.cumulative_length[layer_idx]
+        is_full = cumulative_length >= max_cache_len
+        if is_full:
+            full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
+        elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
+            full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
+        else:
+            self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
+            self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
+            self.cumulative_length[layer_idx] += key_states.shape[-2]
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
+        self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
+        self.cumulative_length[layer_idx] += key_states.shape[-2]
+        # we should return the whole states instead of k_out, v_out to take the whole prompt
+        # into consideration when building kv cache instead of just throwing away tokens outside of the window
+        return full_key_states, full_value_states

     def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
         k_out[:, :, cache_position] = key_states
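The rewritten `_sliding_update` above replaces the roll-and-rewrite scheme with a per-layer `cumulative_length` counter and three cases: in-place writes while the window is still filling, a crop when a prefill overflows it, and a one-slot shift once it is full. Below is a toy 1-D illustration of that update rule; `sliding_update`, `cache`, `seen`, and `window` are illustrative names, not the real API, which operates on `(batch, num_heads, seq_len, head_dim)` tensors:

```python
import torch


def sliding_update(cache: torch.Tensor, seen: int, new: torch.Tensor, window: int):
    """Toy 1-D version of the three branches in `_sliding_update` above."""
    if seen >= window:
        # Window already full: drop the oldest slot, append the new tokens.
        full = torch.cat((cache[1:], new))
    elif seen + new.shape[0] > window:
        # Prefill overflowing the window: keep what was written so far, crop below.
        full = torch.cat((cache[:seen], new))
    else:
        # Still filling up: plain in-place write.
        cache[seen : seen + new.shape[0]] = new
        return cache, seen + new.shape[0]
    cache.copy_(full[-window:])
    # Like the real method, return the *full* states so attention can still see
    # the whole prompt, even though the cache only keeps the last `window` slots.
    return full, seen + new.shape[0]


cache = torch.zeros(4)
out, seen = sliding_update(cache, 0, torch.tensor([1.0, 2.0, 3.0]), window=4)  # cache: [1, 2, 3, 0]
out, seen = sliding_update(cache, seen, torch.tensor([4.0, 5.0]), window=4)    # cache: [2, 3, 4, 5], out: [1..5]
```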
@@ -1760,7 +1751,7 @@ def update(
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        self.initialise_cache_layer(layer_idx, key_states)

         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1774,7 +1765,7 @@ def update(
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)

-        if sliding_window:
+        if self.is_sliding[layer_idx]:
             update_fn = self._sliding_update
         else:
             update_fn = self._static_update
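With `initialise_cache_layer` now called from `update`, the per-layer cache tensors are allocated lazily on the first forward pass instead of in `__init__`, and each layer is dispatched to the sliding or static path via `self.is_sliding[layer_idx]`. A rough usage sketch of the resulting flow, assuming a checkpoint whose config defines `sliding_window`; the model id and generation arguments below are illustrative, not taken from this diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-2b"  # example checkpoint with sliding-window attention
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# No tensors are allocated here; each layer's cache is created on its first `update()`.
past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=4096,
    device=model.device,
    dtype=model.dtype,
)

inputs = tokenizer("The KV cache is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```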
@@ -1801,6 +1792,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
                 "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                 "Using the `layer_idx` argument is not supported."
             )
+        if len(self.key_cache) == 0:
+            return 0
         return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def reset(self):
@@ -1809,6 +1802,7 @@ def reset(self):
             # In-place ops prevent breaking the static address
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
+        self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]


 class MambaCache: