Commit be34e6d

Fix generation with beacon tokens
The key issue was how `flash_attn_with_kvcache` handles the causal mask during multi-token decoding: it gets complicated around padding tokens during generation. This may still be an issue for longer multi-token generation, but it works for now with up to 2 tokens per step (the beacon/pad approach).
1 parent 225bc98 commit be34e6d
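
The "beacon/pad approach" mentioned above widens a decode step to at most two slots per sequence: every row carries its newly sampled token, rows at the beacon interval get a beacon token in the second slot, and the rest get right padding. Below is a minimal standalone sketch of that shaping (the token ids, interval, and function name are illustrative; the real logic lives in `maybe_insert_beacon_tokens` in surya/foundation/__init__.py):

import torch

PAD, BEACON = 0, 1      # illustrative ids, not the model's real pad/beacon token ids
beacon_interval = 4     # hypothetical beacon_token_interval

def widen_with_beacons(input_ids, num_predicted_tokens):
    """input_ids: (batch, 1) newly sampled tokens; num_predicted_tokens: (batch, 1)."""
    batch_size = input_ids.shape[0]
    add_beacon = (num_predicted_tokens % beacon_interval == 0).squeeze(1)
    if not add_beacon.any():
        # No row is at a beacon boundary: keep the single-token step, one valid token per row
        return input_ids, torch.ones(batch_size, dtype=torch.long)
    # Widen every row to two slots, filled with pad by default (right padding)
    new_input_ids = torch.full((batch_size, 2), PAD, dtype=input_ids.dtype)
    new_input_ids[:, 0] = input_ids.squeeze(1)
    new_input_ids[add_beacon, 1] = BEACON
    # Rows that got a beacon carry 2 valid tokens this step, the rest carry 1
    return new_input_ids, add_beacon.long() + 1

tokens = torch.tensor([[17], [23]])
predicted = torch.tensor([[4], [5]])
print(widen_with_beacons(tokens, predicted))
# (tensor([[17,  1], [23,  0]]), tensor([2, 1]))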

File tree

surya/common/surya/decoder/__init__.py
surya/common/surya/flash_attn_utils.py
surya/foundation/__init__.py
surya/foundation/cache.py

4 files changed: +25 -35 lines changed
surya/common/surya/decoder/__init__.py

Lines changed: 1 addition & 9 deletions
@@ -178,14 +178,6 @@ def forward(
             query_states, key_states, cos, sin
         )

-        is_prefill = all(
-            (
-                input_shape[1] > 1,
-                (past_key_value is None)
-                or (past_key_value.get_seq_length(self.layer_idx) == 0),
-            )
-        )
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
@@ -212,7 +204,7 @@ def forward(
                 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
         elif self.config._attn_implementation == "flash_attention_2":
-            if is_prefill:
+            if prefill:
                 attention_interface = flash_attn_prefill
             else:
                 attention_interface = flash_attn_decode
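
The deleted `is_prefill` inference relied on `input_shape[1] > 1`, which no longer distinguishes prefill from decode once a decode step can carry two tokens (token + beacon/pad). A rough, self-contained sketch of the dispatch after this change (the stub functions stand in for the real implementations in flash_attn_utils.py):

from typing import Callable

# Stand-ins so the sketch runs on its own; the real callables are
# flash_attn_prefill / flash_attn_decode from surya/common/surya/flash_attn_utils.py.
def flash_attn_prefill(*args, **kwargs): ...
def flash_attn_decode(*args, **kwargs): ...

def pick_attention_interface(attn_implementation: str, prefill: bool) -> Callable:
    # The explicit `prefill` flag from the caller is authoritative: a 2-token decode
    # step would fool any `query_length > 1` heuristic.
    if attn_implementation == "flash_attention_2":
        return flash_attn_prefill if prefill else flash_attn_decode
    raise NotImplementedError("other attention backends are elided in this sketch")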

surya/common/surya/flash_attn_utils.py

Lines changed: 2 additions & 9 deletions
@@ -111,7 +111,6 @@ def flash_attn_prefill(
     attention_mask: torch.Tensor,
     dropout: float,
     scaling: float,
-    sliding_window: Optional[int],
     query_length: int,
     batch_size: int,
     indices_k: torch.Tensor,
@@ -135,8 +134,6 @@ def flash_attn_prefill(
     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

-    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if sliding_window else {}
-
     # Returning None for attn_weights to match other attention interfaces
     flash_attn_out = _flash_attn_varlen_func(
         q_flash,
@@ -149,7 +146,6 @@ def flash_attn_prefill(
         dropout_p=dropout,
         softmax_scale=scaling,
         causal=module.is_causal,
-        **flash_kwargs
     )
     return pad_input(flash_attn_out, indices_q, batch_size, query_length), None

@@ -161,13 +157,12 @@ def flash_attn_decode(
     value_states: torch.Tensor,
     attention_mask: torch.Tensor,
     scaling: float,
-    sliding_window: bool,
     **kwargs,
 ):
     """
     Wrapper for flash attention during the decode stage

-    query_states must have shape (batch_size, num_heads, 1, head_dim), 1 is the seq length in the decoding stage
+    query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage
     key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)

     This is the opposite of what is required by flash attention, but keeps parity with the HF convention
@@ -177,14 +172,12 @@ def flash_attn_decode(
     cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1)
     cache_leftpad = cache_leftpad.to(torch.int32)

-    flash_kwargs = {'window_size': (sliding_window, sliding_window)} if sliding_window else {}
     # Returning None for attn_weights to match other attention interfaces
     return _flash_attn_with_kvcache(
         q=query_states,
         k_cache=key_states,
         v_cache=value_states,
         cache_leftpad=cache_leftpad,
-        causal=module.is_causal,
+        causal=False,
         softmax_scale=scaling,
-        **flash_kwargs
     ), None
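
For reference, the `cache_leftpad` argument passed to `_flash_attn_with_kvcache` (now together with `causal=False`, as the diff shows) is simply the number of leading pad positions per row of the attention mask. A standalone sketch with made-up mask values:

import torch

# Left-padded attention masks for a batch of 2 sequences: 0 = pad, 1 = real token.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

# cumprod over (mask == 0) stays 1 only while we are still inside the leading pad run,
# so the row-wise sum is exactly the number of leading pads flash-attn should skip.
cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32)
print(cache_leftpad)  # tensor([2, 0], dtype=torch.int32)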

surya/foundation/__init__.py

Lines changed: 16 additions & 7 deletions
@@ -230,7 +230,11 @@ def maybe_insert_beacon_tokens(

         token = input_ids.squeeze(1) # shape: [batch_size]
         add_beacon = (num_predicted_tokens % self.beacon_token_interval == 0).squeeze()
-
+
+        # Return if no beacon tokens need to be added
+        if torch.all(~add_beacon):
+            return input_ids, torch.ones((input_ids.shape[0]), dtype=torch.long, device=input_ids.device)
+
         # Output tensors
         new_input_ids = torch.full((batch_size, 2), self.device_pad_token, dtype=input_ids.dtype, device=input_ids.device)

@@ -251,10 +255,11 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
         position_ids = current_inputs.position_ids
         num_predicted_tokens = current_inputs.num_predicted_tokens
         num_valid_tokens = current_inputs.num_valid_tokens
+        batch_size = input_ids.shape[0]

         # Pre-shift the attention mask based on the cache update
         self.kv_cache.maybe_shift_attention_mask(
-            num_valid_tokens=num_valid_tokens, cache_idxs=list(range(input_ids.shape[0]))
+            num_valid_tokens=num_valid_tokens, cache_idxs=list(range(batch_size))
         )
         with settings.INFERENCE_MODE():
             outputs = self.model(
@@ -263,7 +268,8 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
                 position_ids=position_ids,
                 use_cache=True,
                 past_key_values=self.kv_cache,
-                logits_to_keep=torch.max(num_valid_tokens).item(),
+                # We may pass multiple input ids per batch element (right padded) and we need the original size to index into them
+                logits_to_keep=None,
                 prefill=False,
                 num_valid_tokens=num_valid_tokens
             )
@@ -274,9 +280,12 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
         input_ids = processed_output.input_ids
         num_predicted_tokens += 1

-        # input_ids, num_valid_tokens = self.maybe_insert_beacon_tokens(input_ids, num_predicted_tokens)
-        # TODO we should only consider position_ids upto the valid range for each batch element
-        position_ids = position_ids[:, -1:] + torch.arange(1, input_ids.shape[1] + 1, device=input_ids.device)
+        batch_indices = torch.arange(batch_size, device=position_ids.device)
+        last_token_indices = (num_valid_tokens - 1)
+        last_valid_positions = position_ids[batch_indices, last_token_indices].reshape(batch_size, 1)
+
+        input_ids, num_valid_tokens = self.maybe_insert_beacon_tokens(input_ids, num_predicted_tokens)
+        position_ids = last_valid_positions + torch.arange(1, input_ids.shape[1] + 1, device=input_ids.device)

         new_input = ContinuousBatchInput(
             input_ids=input_ids,
@@ -377,7 +386,7 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):

         # Process outputs
         # No extra tokens during prefill
-        num_valid_tokens = torch.ones((input_ids.shape[0], 1), device=self.model.device, dtype=torch.long)
+        num_valid_tokens = torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long)
         num_predicted_tokens = torch.ones((input_ids.shape[0], 1), device=self.model.device, dtype=torch.long)
         processed_outputs = self.process_outputs(outputs, num_valid_tokens=num_valid_tokens)
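
The position-id handling added above is easier to see with concrete numbers: each row continues counting from its last valid position (ignoring any right padding from the previous step), then extends by one position per slot of the next step. A standalone sketch with made-up values:

import torch

# Two sequences, previous step carried 2 slots each; row 0 had both valid,
# row 1 had only the first slot valid (its second slot was right padding).
position_ids = torch.tensor([[10, 11],
                             [ 7,  8]])
num_valid_tokens = torch.tensor([2, 1])
next_step_width = 2  # e.g. token + beacon/pad

batch_size = position_ids.shape[0]
batch_indices = torch.arange(batch_size)
last_valid_positions = position_ids[batch_indices, num_valid_tokens - 1].reshape(batch_size, 1)
# Row 0 continues from 11, row 1 continues from 7 (its padded slot is ignored)
new_position_ids = last_valid_positions + torch.arange(1, next_step_width + 1)
print(new_position_ids)  # tensor([[12, 13], [ 8,  9]])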

surya/foundation/cache.py

Lines changed: 6 additions & 10 deletions
@@ -97,13 +97,10 @@ def maybe_shift_attention_mask(
             shift = new_text_len
             self._shift_attention_mask_left(cache_idx, shift)
         else:
-            # We need to figure out how many text tokens to keep and where to place them
-            keep = self.text_sliding_window - new_text_len
-            assert keep > 0, "Cannot add more new text tokens than the sliding window"
-
             # Shift entire cache left to make room for full text sliding window
             shift_amount = self.text_sliding_window - curr_text_cache_len
-            if shift_amount > 0: # Cannot be negative, may be exactly 0
+            # If this is <=0, we are already above the sliding window, so the attention mask stays the same
+            if shift_amount > 0:
                 self._shift_attention_mask_left(cache_idx, shift_amount)

     # Mirrors the logic from _prefill_update
@@ -222,17 +219,16 @@ def _decode_update(

         curr_text_cache_len = self.text_token_counts[layer_idx][cache_idx].item()

-        k_new = key_states[batch_idx, :, :new_text_len, :] # (H, new_text_len, D)
+        k_new = key_states[batch_idx, :, :new_text_len, :]
         v_new = value_states[batch_idx, :, :new_text_len, :]

         if curr_text_cache_len + new_text_len <= self.text_sliding_window:
             # If we are under the sliding window length, shift the entire cache left
             # Since we setup the max cache length with enough buffer, this will ONLY drop
             # left padding tokens out
             shift = new_text_len
-            if curr_text_cache_len > 0:
-                k_cache[cache_idx, :, :-shift, :] = k_cache[cache_idx, :, shift:, :].clone()
-                v_cache[cache_idx, :, :-shift, :] = v_cache[cache_idx, :, shift:, :].clone()
+            k_cache[cache_idx, :, :-shift, :] = k_cache[cache_idx, :, shift:, :].clone()
+            v_cache[cache_idx, :, :-shift, :] = v_cache[cache_idx, :, shift:, :].clone()
             k_cache[cache_idx, :, -shift:, :] = k_new
             v_cache[cache_idx, :, -shift:, :] = v_new

@@ -268,4 +264,4 @@ def _decode_update(
         self.key_cache[layer_idx] = k_cache
         self.value_cache[layer_idx] = v_cache

-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
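
The `_decode_update` change drops the `curr_text_cache_len > 0` guard, so the left shift now always runs; per the comments in the diff, the cache is sized with enough buffer that the shift only ever drops left-padding slots. A toy 1-D sketch of the roll-and-append pattern (the real caches are 4-D, shaped (batch, heads, seq, head_dim)):

import torch

PAD = 0
# Toy per-sequence cache with left padding: 3 pad slots, then 5 real entries.
cache = torch.tensor([PAD, PAD, PAD, 4, 5, 6, 7, 8])
new_entries = torch.tensor([9, 10])   # e.g. token + beacon for this step
shift = new_entries.shape[0]

# Shift everything left by `shift` (dropping `shift` pad slots on the left),
# then write the new entries into the freed slots on the right.
cache[:-shift] = cache[shift:].clone()
cache[-shift:] = new_entries
print(cache)  # tensor([ 0,  4,  5,  6,  7,  8,  9, 10])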
