45
45
_check_cached_qkv_data_type ,
46
46
_check_kv_layout ,
47
47
_check_pos_encoding_mode ,
48
+ _check_shape_dtype_device ,
48
49
_get_cache_alibi_slopes_buf ,
49
50
_get_cache_buf ,
50
51
_get_range_buf ,
@@ -972,6 +973,8 @@ def run(
972
973
q_scale : Optional [float ] = None ,
973
974
k_scale : Optional [float ] = None ,
974
975
v_scale : Optional [float ] = None ,
976
+ out : Optional [torch .Tensor ] = None ,
977
+ lse : Optional [torch .Tensor ] = None ,
975
978
return_lse : Literal [False ] = False ,
976
979
) -> torch .Tensor : ...
977
980
@@ -984,6 +987,8 @@ def run(
984
987
q_scale : Optional [float ] = None ,
985
988
k_scale : Optional [float ] = None ,
986
989
v_scale : Optional [float ] = None ,
990
+ out : Optional [torch .Tensor ] = None ,
991
+ lse : Optional [torch .Tensor ] = None ,
987
992
return_lse : Literal [True ] = True ,
988
993
) -> Tuple [torch .Tensor , torch .Tensor ]: ...
989
994
@@ -995,6 +1000,8 @@ def run(
995
1000
q_scale : Optional [float ] = None ,
996
1001
k_scale : Optional [float ] = None ,
997
1002
v_scale : Optional [float ] = None ,
1003
+ out : Optional [torch .Tensor ] = None ,
1004
+ lse : Optional [torch .Tensor ] = None ,
998
1005
return_lse : bool = False ,
999
1006
) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
1000
1007
r"""Compute batch decode attention between query and paged kv cache.
@@ -1016,13 +1023,18 @@ def run(
1016
1023
``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
1017
1024
:attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
1018
1025
``paged_kv_cache[:, 1]`` is the value-cache.
1019
-
1026
+ *args
1027
+ Additional arguments for the custom kernel.
1020
1028
q_scale : Optional[float]
1021
1029
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
1022
1030
k_scale : Optional[float]
1023
1031
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
1024
1032
v_scale : Optional[float]
1025
1033
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
1034
+ out : Optional[torch.Tensor]
1035
+ The output tensor, if not provided, will be allocated internally.
1036
+ lse : Optional[torch.Tensor]
1037
+ The log-sum-exp of attention logits, if not provided, will be allocated internally.
1026
1038
return_lse : bool
1027
1039
Whether to return the logsumexp of attention scores, defaults to ``False``.
1028
1040
@@ -1061,13 +1073,21 @@ def run(
1061
1073
if rope_theta is None :
1062
1074
rope_theta = 1e4
1063
1075
1064
- lse = None
1065
1076
if return_lse :
1066
- lse = torch .empty (
1067
- (q .size (0 ), q .size (1 )), dtype = torch .float32 , device = q .device
1068
- )
1077
+ if lse is None :
1078
+ lse = torch .empty (
1079
+ (q .size (0 ), q .size (1 )), dtype = torch .float32 , device = q .device
1080
+ )
1081
+ else :
1082
+ _check_shape_dtype_device (
1083
+ lse , (q .size (0 ), q .size (1 )), torch .float32 , q .device , "lse"
1084
+ )
1085
+
1086
+ if out is None :
1087
+ out = torch .empty_like (q )
1088
+ else :
1089
+ _check_shape_dtype_device (out , q .shape , q .dtype , q .device , "out" )
1069
1090
1070
- out = torch .empty_like (q )
1071
1091
if self .use_tensor_cores :
1072
1092
run_args = [
1073
1093
self ._float_workspace_buffer ,
@@ -1270,11 +1290,11 @@ def __init__(
1270
1290
Whether to enable CUDAGraph for batch decode attention, if enabled, the
1271
1291
auxiliary data structures will be stored as the provided buffers. The ``batch_size``
1272
1292
cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
1273
-
1293
+
1274
1294
use_tensor_cores : bool
1275
1295
Whether to use tensor cores for the computation. Will be faster for large group
1276
1296
size in grouped query attention. Defaults to ``False``.
1277
-
1297
+
1278
1298
paged_kv_indptr_buffer : Optional[torch.Tensor]
1279
1299
The user reserved buffer on GPU to store the indptr of the paged kv cache, the size
1280
1300
of the buffer should be ``[batch_size + 1]``.
@@ -1488,6 +1508,8 @@ def run(
1488
1508
q_scale : Optional [float ] = None ,
1489
1509
k_scale : Optional [float ] = None ,
1490
1510
v_scale : Optional [float ] = None ,
1511
+ out : Optional [torch .Tensor ] = None ,
1512
+ lse : Optional [torch .Tensor ] = None ,
1491
1513
return_lse : bool = False ,
1492
1514
) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
1493
1515
r"""Compute batch decode attention between query and paged kv cache.
@@ -1510,6 +1532,10 @@ def run(
1510
1532
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
1511
1533
v_scale : Optional[float]
1512
1534
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
1535
+ out : Optional[torch.Tensor]
1536
+ The output tensor, if not provided, will be allocated internally.
1537
+ lse : Optional[torch.Tensor]
1538
+ The log-sum-exp of attention logits, if not provided, will be allocated internally.
1513
1539
return_lse : bool
1514
1540
Whether to return the logsumexp of attention scores, defaults to ``False``.
1515
1541
@@ -1539,14 +1565,28 @@ def run(
1539
1565
rope_theta = 1e4
1540
1566
1541
1567
with self .device as device :
1542
- o = torch .empty_like (q_nope , device = device )
1543
- maybe_lse = (
1544
- torch .empty (
1545
- (q_nope .size (0 ), q_nope .size (1 )), dtype = torch .float32 , device = device
1568
+ if out is None :
1569
+ out = torch .empty_like (q_nope , device = device )
1570
+ else :
1571
+ _check_shape_dtype_device (
1572
+ out , q_nope .shape , q_nope .dtype , q_nope .device , "out"
1546
1573
)
1547
- if return_lse
1548
- else None
1549
- )
1574
+
1575
+ if return_lse :
1576
+ if lse is None :
1577
+ lse = torch .empty (
1578
+ (q_nope .size (0 ), q_nope .size (1 )),
1579
+ dtype = torch .float32 ,
1580
+ device = device ,
1581
+ )
1582
+ else :
1583
+ _check_shape_dtype_device (
1584
+ lse ,
1585
+ (q_nope .size (0 ), q_nope .size (1 )),
1586
+ q_nope .dtype ,
1587
+ q_nope .device ,
1588
+ "lse" ,
1589
+ )
1550
1590
self ._cached_module .run (
1551
1591
self ._float_workspace_buffer ,
1552
1592
self ._int_workspace_buffer ,
@@ -1558,16 +1598,16 @@ def run(
1558
1598
self ._paged_kv_indptr_buf ,
1559
1599
self ._paged_kv_indices_buf ,
1560
1600
self ._paged_kv_last_page_len_buf ,
1561
- o ,
1601
+ out ,
1562
1602
sm_scale ,
1563
1603
window_left ,
1564
1604
logits_soft_cap ,
1565
1605
rope_scale ,
1566
1606
rope_theta ,
1567
- maybe_lse ,
1607
+ lse ,
1568
1608
get_cuda_stream (device ),
1569
1609
)
1570
- out = [o , maybe_lse ] if return_lse else [o ]
1610
+ out = [out , lse ] if return_lse else [out ]
1571
1611
if v_scale is not None :
1572
1612
out [0 ] *= v_scale
1573
1613
0 commit comments