@@ -161,6 +161,8 @@ def forward(
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
+       valid_tokens: Optional[List[int]] = None,
+       prefill: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
@@ -185,7 +187,15 @@ def forward(

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-           cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "cache_idxs": cache_idxs}
+           # cache_idxs, valid_tokens, and prefill add support for our new caching mechanism
+           cache_kwargs = {
+               "sin": sin,
+               "cos": cos,
+               "cache_position": cache_position,
+               "cache_idxs": cache_idxs,
+               "valid_tokens": valid_tokens,
+               "prefill": prefill
+           }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
@@ -279,6 +289,8 @@ def forward(
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
+       valid_tokens: Optional[List[int]] = None,
+       prefill: bool = False,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
@@ -300,7 +312,9 @@ def forward(
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
-           cache_idxs=cache_idxs
+           cache_idxs=cache_idxs,
+           valid_tokens=valid_tokens,
+           prefill=prefill,
            **kwargs,
        )
        hidden_states = residual + hidden_states
@@ -461,6 +475,8 @@ def forward(
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
+       valid_tokens: Optional[List[int]] = None,
+       prefill: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -501,6 +517,8 @@ def forward(
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            cache_idxs=cache_idxs,
+           valid_tokens=valid_tokens,
+           prefill=prefill,
            **flash_attn_kwargs,
        )
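
The hunks above only thread valid_tokens and prefill through to each attention layer's cache_kwargs; the cache class that actually consumes them is not part of this diff. Below is a minimal sketch of what such a cache might look like, assuming it subclasses DynamicCache from transformers, that valid_tokens lists the sequence positions worth keeping, and that prefill flags the initial prompt pass. The class name SelectiveCache and the pruning policy are illustrative assumptions, not the PR's implementation.

from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class SelectiveCache(DynamicCache):
    """Hypothetical cache that reads valid_tokens / prefill from cache_kwargs."""

    def update(
        self,
        key_states: torch.Tensor,  # (batch, num_heads, seq_len, head_dim)
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_kwargs = cache_kwargs or {}
        valid_tokens: Optional[List[int]] = cache_kwargs.get("valid_tokens")
        prefill: bool = cache_kwargs.get("prefill", False)

        if prefill and valid_tokens is not None:
            # During the prompt pass, keep only the positions marked as valid so
            # later decode steps attend over a pruned cache (illustrative policy).
            keep = torch.tensor(valid_tokens, device=key_states.device, dtype=torch.long)
            key_states = key_states.index_select(2, keep)
            value_states = value_states.index_select(2, keep)

        # Everything else falls back to the standard append-style update;
        # DynamicCache.update accepts and ignores extra cache_kwargs keys.
        return super().update(key_states, value_states, layer_idx, cache_kwargs)

Because the patched forward simply places the new keys into cache_kwargs, any Cache subclass passed as past_key_values can opt in by reading them inside update, while existing caches such as DynamicCache keep working since they ignore keys they do not use.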