Commit a515579

Merge pull request #14 from huggingface/norope
Add support for no rope
2 parents aa8daba + 04b302a · commit a515579

15 files changed · +470 −593 lines

src/transformers/cache_utils.py

+59 −65
@@ -1204,7 +1204,7 @@ def __init__(
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
-        )
+        ) // 8  # TODO use TP!

         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
@@ -1663,84 +1663,75 @@ def __init__(
         max_batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype = torch.bfloat16,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
-            raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
-                "sliding window attention, please check if there is a `sliding_window` field in the model "
-                "config and it's not set to None."
-            )
+            self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8092)
+        else:
+            self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
         self.max_batch_size = max_batch_size
-        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
-        )

-        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        if hasattr(config.get_text_config(), "no_rope_layers"):
+            self.is_sliding = torch.tensor(config.no_rope_layers)
+        else:
+            layer_switch = getattr(config, "sliding_window_pattern", 2)
+            self.is_sliding = torch.tensor(
+                [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
+            )
+
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
+
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+
+        num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
         sliding_cache_shape = (
             self.max_batch_size,
-            self.num_key_value_heads,
-            min(config.sliding_window, max_cache_len),
+            num_key_value_heads,
+            self.sliding_window,
             self.head_dim,
         )
-        device = torch.device(device) if device is not None and isinstance(device, str) else None
-        for i in range(config.num_hidden_layers):
-            if layer_device_map is not None:
-                layer_device = layer_device_map[i]
-            else:
-                layer_device = device
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+        # breaks when updating the cache.
+        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)

     def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] > max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        to_shift = cache_position >= max_cache_len - 1
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
+        cumulative_length = self.cumulative_length[layer_idx]
+        is_full = cumulative_length >= max_cache_len
+        if is_full:
+            full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
+        elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
+            full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
+        else:
+            self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
+            self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
+            self.cumulative_length[layer_idx] += key_states.shape[-2]
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
+        self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
+        self.cumulative_length[layer_idx] += key_states.shape[-2]
+        # we should return the whole states instead of k_out, v_out to take the whole prompt
+        # into consideration when building kv cache instead of just throwing away tokens outside of the window
+        return full_key_states, full_value_states

     def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
         k_out[:, :, cache_position] = key_states
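
A small standalone sketch of the two `is_sliding` branches introduced above; the values below are made up for illustration and do not come from any real config:

import torch

num_hidden_layers = 8
layer_switch = 4  # stands in for `sliding_window_pattern`

# Fallback branch: every `layer_switch`-th layer is global (False), the others use the sliding/chunked cache (True)
is_sliding = torch.tensor(
    [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)], dtype=torch.bool
)
print(is_sliding)  # tensor([ True,  True,  True, False,  True,  True,  True, False])

# `no_rope_layers` branch: the per-layer flags are taken directly from the config list
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # hypothetical values
print(torch.tensor(no_rope_layers))  # nonzero entries select the sliding/chunked cache shape for that layer
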
@@ -1760,7 +1751,7 @@ def update(
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        self.initialise_cache_layer(layer_idx, key_states)

         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1774,7 +1765,7 @@ def update(
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)

-        if sliding_window:
+        if self.is_sliding[layer_idx]:
             update_fn = self._sliding_update
         else:
             update_fn = self._static_update
@@ -1801,6 +1792,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
                 "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                 "Using the `layer_idx` argument is not supported."
             )
+        if len(self.key_cache) == 0:
+            return 0
         return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def reset(self):
@@ -1809,6 +1802,7 @@ def reset(self):
             # In-place ops prevent breaking the static address
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
+        self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]


 class MambaCache:
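
A simplified, standalone sketch of what the new `_sliding_update` above does for one layer once its window is full: the oldest cached position is dropped and the new step is appended, with `cumulative_length` tracking how many tokens the layer has seen. Shapes are toy values and no `HybridCache` object is involved:

import torch

max_cache_len = 4
k_cache = torch.arange(8.0).reshape(1, 1, 4, 2)   # pretend cached keys: batch=1, heads=1, 4 positions, head_dim=2
new_key = torch.full((1, 1, 1, 2), 99.0)          # one new decoding step
cumulative_length = 5                             # more tokens seen than the window holds

if cumulative_length >= max_cache_len:
    # window is full: drop the oldest position, append the new one, write back in place
    full_keys = torch.cat((k_cache[:, :, 1:, :], new_key), dim=-2)
    k_cache.copy_(full_keys[:, :, -max_cache_len:, :])

print(k_cache[0, 0])
# tensor([[ 2.,  3.],
#         [ 4.,  5.],
#         [ 6.,  7.],
#         [99., 99.]])
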

src/transformers/generation/configuration_utils.py

+1
@@ -416,6 +416,7 @@ def __init__(self, **kwargs):
         if isinstance(self.cache_config, dict):
             self.cache_config = cache_config_class.from_dict(self.cache_config)
         self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
+        self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)

         # Parameters for manipulation of the model output logits
         self.temperature = kwargs.pop("temperature", 1.0)
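
Since the new option is popped from the constructor kwargs like the other generation parameters, it can be set directly on a `GenerationConfig`. A minimal sketch, assuming a transformers build that includes this PR; the chunk size of 1024 is arbitrary:

from transformers import GenerationConfig

# `prefill_chunk_size` defaults to None (no chunked prefill); a positive int enables it
generation_config = GenerationConfig(prefill_chunk_size=1024, max_new_tokens=32)
print(generation_config.prefill_chunk_size)  # 1024
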

src/transformers/generation/utils.py

+42 −1
@@ -3318,7 +3318,12 @@ def _sample(
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)

-        is_prefill = True
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
47684773
else:
47694774
return input_ids
47704775

4776+
def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
4777+
chunk_size = generation_config.prefill_chunk_size
4778+
# Only chunk up the token just before last, so that decoding is completely performed outside this function
4779+
# (here we simply prefill the cache)
4780+
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
4781+
4782+
if "past_key_values" not in model_kwargs:
4783+
raise ValueError("Cannot use prefill chunkink without a cache")
4784+
4785+
model_forward = self.get_compiled_call(generation_config.compile_config)
4786+
attention_mask = model_kwargs.pop("attention_mask", None)
4787+
4788+
past_length = 0
4789+
for input_chunk in input_chunks:
4790+
current_length = past_length + input_chunk.shape[-1]
4791+
# Prepare inputs
4792+
if attention_mask is not None:
4793+
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
4794+
model_kwargs["cache_position"] = torch.arange(
4795+
past_length, current_length, dtype=torch.long, device=input_chunk.device
4796+
)
4797+
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
4798+
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
4799+
4800+
# outputs = model_forward(**model_inputs, return_dict=True) TODO REACTIVATE THIS!!!
4801+
outputs = self(**model_inputs, return_dict=True)
4802+
4803+
model_kwargs["past_key_values"] = outputs.past_key_values
4804+
past_length = current_length
4805+
4806+
model_kwargs["attention_mask"] = attention_mask
4807+
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
4808+
_ = model_kwargs.pop("position_ids", None)
4809+
4810+
return model_kwargs
4811+
47714812

47724813
def _speculative_sampling(
47734814
candidate_input_ids,
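
A standalone sketch of the chunking arithmetic in `_prefill_chunking` above, without a model: everything except the final prompt token is split into `chunk_size` pieces, and `cache_position` advances chunk by chunk. Tensor contents and sizes are invented:

import torch

chunk_size = 4
input_ids = torch.arange(11).unsqueeze(0)  # batch of 1, 11 prompt tokens

# The last token is held back so the regular decoding loop performs the first real generation step
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

past_length = 0
for input_chunk in input_chunks:
    current_length = past_length + input_chunk.shape[-1]
    cache_position = torch.arange(past_length, current_length, dtype=torch.long)
    print(input_chunk.shape[-1], cache_position.tolist())
    past_length = current_length
# 4 [0, 1, 2, 3]
# 4 [4, 5, 6, 7]
# 2 [8, 9]
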

src/transformers/integrations/compressed_tensors.py

+2 −2
@@ -1,4 +1,3 @@
-
 from transformers.utils import is_torch_available


@@ -10,7 +9,8 @@


 def skip(*args, **kwargs):
-    pass
+    pass
+

 class CompressedExpertsLinear(nn.Module):
     """

src/transformers/integrations/flex_attention.py

+44 −18
@@ -34,10 +34,7 @@


 if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import (
-        BlockMask,
-        flex_attention,
-    )
+    from torch.nn.attention.flex_attention import BlockMask, flex_attention
     from torch.nn.attention.flex_attention import (
         create_block_mask as create_block_causal_mask_flex,
     )
@@ -64,14 +61,23 @@ def __init__(self):
         Initialize or update the singleton instance.
         """
         if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
+            self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
             self._is_flex_compiled = True

     def __call__(self):
         return self._compiled_flex_attention


-def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
+Offset = Union[torch.Tensor, int]
+
+
+def make_flex_block_causal_mask(
+    attention_mask_2d: torch.Tensor,
+    attention_chunk_size: Optional[int] = None,
+    query_length=None,
+    key_length=None,
+    offsets: Optional[Tuple[Offset, Offset]] = None,
+) -> "BlockMask":
     """
     Create a block causal document mask for a batch of sequences, both packed and unpacked.
     Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
@@ -94,10 +100,13 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
     Returns:
         BlockMask
     """
+    attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
     device = attention_mask_2d.device
+    document_ids = attention_mask_2d.clone()

-    document_ids = attention_mask_2d
-    batch_size, total_seq_len = document_ids.shape
+    if attention_chunk_size is not None:
+        # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
+        document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)

     # Instead of passing a tensor mask, flex attention requires a mask_mod function
     # that determines which elements of QK^T should be included in the attention
@@ -112,18 +121,30 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
         for an illustration.
         """
-        causal_mask = q_idx >= kv_idx
+        causal_mask = q_idx >= kv_idx  # not valid when decoding
         document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
-        padding_mask = document_ids[batch_idx, q_idx] > 0
-        return causal_mask & document_mask & padding_mask
-
+        padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
+        final_mask = causal_mask & padding_mask & document_mask
+        return final_mask
+
+    if offsets is not None:
+        q_offset = offsets[0]
+        kv_offset = offsets[1]
+
+        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            offset_q = q_idx + q_offset
+            offset_kv = kv_idx + kv_offset
+            return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+    else:
+        mask_mod = causal_mask_mod
     return create_block_causal_mask_flex(
-        mask_mod=causal_mask_mod,
-        B=batch_size,
+        mask_mod=mask_mod,
+        B=1,
         H=None,  # attention head
-        Q_LEN=total_seq_len,
-        KV_LEN=total_seq_len,
+        Q_LEN=query_length,
+        KV_LEN=key_length,
         device=device,
+        _compile=True,
     )


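A standalone sketch of the mask logic above (chunked `document_ids` plus the `offsets` re-indexing), evaluated directly instead of through `create_block_mask`; the sizes and offsets are invented for illustration:

import torch

attention_chunk_size = 3
attention_mask_2d = torch.ones(1, 12, dtype=torch.long)  # batch of 1, 12 key positions, no padding

# Same arithmetic as above: position i belongs to chunk i // attention_chunk_size
document_ids = (attention_mask_2d.clone().fill_(1).cumsum(-1) - 1) // attention_chunk_size
print(document_ids[0].tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]

def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
    padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
    return causal_mask & padding_mask & document_mask

# With offsets, local indices are shifted into global coordinates before the checks are applied,
# e.g. a single decoding query whose global position is 7 while the keys start at 0
q_offset, kv_offset = 7, 0

def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return causal_mask_mod(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset)

allowed = [bool(mask_mod(0, 0, torch.tensor(0), torch.tensor(kv))) for kv in range(12)]
print(allowed)  # only key positions 6 and 7 (same chunk as position 7, and not in the future) are kept
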
@@ -155,6 +176,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

+
 def flex_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
@@ -169,7 +191,7 @@ def flex_attention_forward(
     block_mask = None
     causal_mask = None
     if isinstance(attention_mask, BlockMask):
-        block_mask = attention_mask
+        block_mask = attention_mask  # ._adjust(query.shape[2], key.shape[2])
     else:
         causal_mask = attention_mask

@@ -187,11 +209,14 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):

     enable_gqa = True
     num_local_query_heads = query.shape[1]
-    if not((num_local_query_heads & (num_local_query_heads)) == 0):
+
+    # When running TP this helps:
+    if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
         key = repeat_kv(key, num_local_query_heads)
         value = repeat_kv(value, num_local_query_heads)
         enable_gqa = False

+    kernel_options = kwargs.get("kernel_options", None)
     attn_output, attention_weights = compile_friendly_flex_attention(
         query,
         key,
@@ -200,6 +225,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
         block_mask=block_mask,
         enable_gqa=enable_gqa,
         scale=scaling,
+        kernel_options=kernel_options,
         # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
         # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
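
One small fix worth calling out in the `flex_attention_forward` hunks above: the old guard used `num_local_query_heads & (num_local_query_heads)`, which is zero only when the head count itself is zero, so the `repeat_kv` fallback ran for every nonzero head count; with the `- 1` it now runs only when the local head count is not a power of two, which the added comment ties to tensor parallelism. A quick standalone check of the bit trick:

def is_power_of_two(n: int) -> bool:
    # n & (n - 1) clears the lowest set bit, so the result is 0 exactly for powers of two
    return n > 0 and (n & (n - 1)) == 0

print([h for h in range(1, 17) if is_power_of_two(h)])  # [1, 2, 4, 8, 16]
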
