
Commit 7bee1cf

feat: adding out and lse parameters to run functions to allow user allocated output buffer (#854)
cc @youkaichao
1 parent 55b75d8 commit 7bee1cf

9 files changed: +216 −45 lines changed
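
This change lets callers hand run() pre-allocated out and lse buffers instead of having the wrapper allocate fresh tensors on every decode step; supplied buffers are validated and written into directly, and mismatched buffers are rejected by the new _check_shape_dtype_device check rather than silently reallocated. Below is a minimal usage sketch, assuming the usual plan()/run() workflow of BatchDecodeWithPagedKVCacheWrapper; the sizes, page-table contents, and workspace setup are illustrative only, and only the out=/lse= keywords come from this commit.

import torch
import flashinfer

# Illustrative sizes; the only new API surface exercised here is out=/lse= in run().
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_req = 4, 4
max_num_pages = batch_size * pages_per_req

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_req
kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")
wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
# NHD layout: [max_num_pages, 2, page_size, num_kv_heads, head_dim]
kv_cache = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)

# User-allocated buffers, reusable across decode steps; they must match the
# shape/dtype/device that run() would otherwise allocate internally.
out = torch.empty_like(q)
lse = torch.empty(q.size(0), q.size(1), dtype=torch.float32, device=q.device)
o, s = wrapper.run(q, kv_cache, out=out, lse=lse, return_lse=True)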

flashinfer/decode.py

+58 −18
@@ -45,6 +45,7 @@
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
+    _check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _get_range_buf,
@@ -972,6 +973,8 @@ def run(
         q_scale: Optional[float] = None,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        out: Optional[torch.Tensor] = None,
+        lse: Optional[torch.Tensor] = None,
         return_lse: Literal[False] = False,
     ) -> torch.Tensor: ...

@@ -984,6 +987,8 @@ def run(
         q_scale: Optional[float] = None,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        out: Optional[torch.Tensor] = None,
+        lse: Optional[torch.Tensor] = None,
         return_lse: Literal[True] = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...

@@ -995,6 +1000,8 @@ def run(
         q_scale: Optional[float] = None,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        out: Optional[torch.Tensor] = None,
+        lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.
@@ -1016,13 +1023,18 @@ def run(
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
             :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
-
+        *args
+            Additional arguments for the custom kernel.
         q_scale : Optional[float]
             The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
         k_scale : Optional[float]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
         v_scale : Optional[float]
             The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
+        out : Optional[torch.Tensor]
+            The output tensor, if not provided, will be allocated internally.
+        lse : Optional[torch.Tensor]
+            The log-sum-exp of attention logits, if not provided, will be allocated internally.
         return_lse : bool
             Whether to return the logsumexp of attention scores, defaults to ``False``.
 
@@ -1061,13 +1073,21 @@ def run(
         if rope_theta is None:
             rope_theta = 1e4
 
-        lse = None
         if return_lse:
-            lse = torch.empty(
-                (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
-            )
+            if lse is None:
+                lse = torch.empty(
+                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+                )
+            else:
+                _check_shape_dtype_device(
+                    lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
+                )
+
+        if out is None:
+            out = torch.empty_like(q)
+        else:
+            _check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
 
-        out = torch.empty_like(q)
         if self.use_tensor_cores:
             run_args = [
                 self._float_workspace_buffer,
@@ -1270,11 +1290,11 @@ def __init__(
             Whether to enable CUDAGraph for batch decode attention, if enabled, the
             auxiliary data structures will be stored as the provided buffers. The ``batch_size``
             cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
-
+
         use_tensor_cores : bool
             Whether to use tensor cores for the computation. Will be faster for large group
             size in grouped query attention. Defaults to ``False``.
-
+
         paged_kv_indptr_buffer : Optional[torch.Tensor]
             The user reserved buffer on GPU to store the indptr of the paged kv cache, the size
             of the buffer should be ``[batch_size + 1]``.
@@ -1488,6 +1508,8 @@ def run(
         q_scale: Optional[float] = None,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        out: Optional[torch.Tensor] = None,
+        lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.
@@ -1510,6 +1532,10 @@ def run(
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
         v_scale : Optional[float]
             The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
+        out : Optional[torch.Tensor]
+            The output tensor, if not provided, will be allocated internally.
+        lse : Optional[torch.Tensor]
+            The log-sum-exp of attention logits, if not provided, will be allocated internally.
         return_lse : bool
             Whether to return the logsumexp of attention scores, defaults to ``False``.
 
@@ -1539,14 +1565,28 @@ def run(
             rope_theta = 1e4
 
         with self.device as device:
-            o = torch.empty_like(q_nope, device=device)
-            maybe_lse = (
-                torch.empty(
-                    (q_nope.size(0), q_nope.size(1)), dtype=torch.float32, device=device
+            if out is None:
+                out = torch.empty_like(q_nope, device=device)
+            else:
+                _check_shape_dtype_device(
+                    out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
                 )
-                if return_lse
-                else None
-            )
+
+            if return_lse:
+                if lse is None:
+                    lse = torch.empty(
+                        (q_nope.size(0), q_nope.size(1)),
+                        dtype=torch.float32,
+                        device=device,
+                    )
+                else:
+                    _check_shape_dtype_device(
+                        lse,
+                        (q_nope.size(0), q_nope.size(1)),
+                        q_nope.dtype,
+                        q_nope.device,
+                        "lse",
+                    )
             self._cached_module.run(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
@@ -1558,16 +1598,16 @@ def run(
                 self._paged_kv_indptr_buf,
                 self._paged_kv_indices_buf,
                 self._paged_kv_last_page_len_buf,
-                o,
+                out,
                 sm_scale,
                 window_left,
                 logits_soft_cap,
                 rope_scale,
                 rope_theta,
-                maybe_lse,
+                lse,
                 get_cuda_stream(device),
             )
-        out = [o, maybe_lse] if return_lse else [o]
+        out = [out, lse] if return_lse else [out]
         if v_scale is not None:
             out[0] *= v_scale

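Both files validate caller-supplied buffers through _check_shape_dtype_device, which this commit imports from flashinfer's utils module but whose body is not part of this diff. The following is a hedged sketch only, inferred from the call sites above (tensor, expected shape, expected dtype, expected device, argument name); the real helper may differ in naming and error messages.

from typing import Sequence

import torch


def check_shape_dtype_device(
    x: torch.Tensor,
    shape: Sequence[int],
    dtype: torch.dtype,
    device: torch.device,
    name: str,
) -> None:
    # Fail fast if a user-provided buffer (e.g. `out` or `lse`) does not match
    # what the kernel expects, instead of writing into a mis-sized tensor.
    if tuple(x.shape) != tuple(shape):
        raise ValueError(f"Invalid shape of {name}: expected {tuple(shape)}, got {tuple(x.shape)}")
    if x.dtype != dtype:
        raise ValueError(f"Invalid dtype of {name}: expected {dtype}, got {x.dtype}")
    if x.device != torch.device(device):
        raise ValueError(f"Invalid device of {name}: expected {device}, got {x.device}")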
flashinfer/mla.py

+32 −10
@@ -21,7 +21,13 @@
 import torch
 
 from .jit import gen_batch_mla_module, get_batch_mla_uri
-from .utils import MaskMode, get_cuda_stream, register_custom_op, register_fake_op
+from .utils import (
+    MaskMode,
+    _check_shape_dtype_device,
+    get_cuda_stream,
+    register_custom_op,
+    register_fake_op,
+)
 
 _batch_mla_modules = {}

@@ -267,6 +273,8 @@ def run(
         q_pe: torch.Tensor,
         ckv_cache: torch.Tensor,
         kpe_cache: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+        lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Run the MLA attention computation.
@@ -283,6 +291,10 @@ def run(
         kpe_cache : torch.Tensor
             The rope part of the kv-cache tensor, shape: ``[num_pages, page_size, head_dim_kpe]``.
             ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models.
+        out : Optional[torch.Tensor]
+            The output tensor, if not provided, will be allocated internally.
+        lse : Optional[torch.Tensor]
+            The log-sum-exp of attention logits, if not provided, will be allocated internally.
         return_lse : bool, optional
             Whether to return the log-sum-exp value, default is False.
         """
@@ -292,12 +304,22 @@ def run(
         causal = self._causal
         mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value
         with self.device as device:
-            o = torch.empty_like(q_nope)
-            maybe_lse = (
-                torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device)
-                if return_lse
-                else None
-            )
+            if out is None:
+                out = torch.empty_like(q_nope)
+            else:
+                _check_shape_dtype_device(
+                    out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
+                )
+
+            if return_lse:
+                if lse is None:
+                    lse = torch.empty(
+                        q_nope.shape[:2], dtype=torch.float32, device=device
+                    )
+                else:
+                    _check_shape_dtype_device(
+                        lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse"
+                    )
             self._cached_module.run(
                 self._float_workspace_buffer,
self._int_workspace_buffer,
@@ -307,13 +329,13 @@ def run(
                 ckv_cache,
                 kpe_cache,
                 self._kv_indices_buf,
-                o,
-                maybe_lse,
+                out,
+                lse,
                 mask_mode,
                 num_heads,
                 page_size,
                 sm_scale,
                 get_cuda_stream(device),
             )
 
-        return (o, maybe_lse) if return_lse else o
+        return (out, lse) if return_lse else out

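The MLA wrapper in flashinfer/mla.py follows the same contract. Below is a small illustrative helper, assuming the wrapper class is flashinfer.mla.BatchMLAPagedAttentionWrapper (the class name is not visible in this excerpt) and that plan() has already been called on it; with user-allocated buffers the returned tensors are simply the objects that were passed in.

import torch
from flashinfer.mla import BatchMLAPagedAttentionWrapper


def mla_decode_into(
    wrapper: BatchMLAPagedAttentionWrapper,
    q_nope: torch.Tensor,     # [num_tokens, num_heads, head_dim_ckv]
    q_pe: torch.Tensor,       # [num_tokens, num_heads, head_dim_kpe]
    ckv_cache: torch.Tensor,  # [num_pages, page_size, head_dim_ckv]
    kpe_cache: torch.Tensor,  # [num_pages, page_size, head_dim_kpe]
):
    # Buffers match what run() would otherwise allocate: out mirrors q_nope,
    # lse is float32 with shape q_nope.shape[:2].
    out = torch.empty_like(q_nope)
    lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=q_nope.device)
    o, s = wrapper.run(q_nope, q_pe, ckv_cache, kpe_cache, out=out, lse=lse, return_lse=True)
    # The kernel writes in place, so the returned tensors are the caller's buffers.
    assert o is out and s is lse
    return o, s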