Skip to content

Commit 890df95

Browse files
authored
feat: ragged tensor padding kernel for blackwell kernel alignment (#1025)
Some of the Blackwell kernels require each row's length to be padded to a multiple of 128/256; this PR adds a kernel for preprocessing the data. cc @cyx-6 .
1 parent 83d1c74 commit 890df95

File tree

4 files changed

+487
-0
lines changed

4 files changed

+487
-0
lines changed

benchmarks/bench_pad_ragged_tensor.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import argparse
2+
from typing import cast
3+
4+
import torch
5+
from triton.testing import do_bench
6+
7+
from flashinfer.triton import pad_ragged_tensor_to_multiple_of
8+
9+
10+
def bench_pad_ragged_tensor_to_multiple_of(batch_size, qkv_len, d, multiple_of):
    """Benchmark ``pad_ragged_tensor_to_multiple_of`` for one problem size.

    Builds a ragged tensor with ``batch_size`` rows of uniform length
    ``qkv_len`` and feature dim ``d``, times the padding kernel, and prints
    the elapsed time plus an estimated memory bandwidth.
    """
    dev = torch.device("cuda:0")
    torch.manual_seed(42)

    # Uniform row lengths: row i covers [i * qkv_len, (i + 1) * qkv_len).
    indptr = torch.arange(0, (batch_size + 1) * qkv_len, qkv_len, device=dev)
    ragged_tensor = torch.randn((indptr[-1], d), device=dev)

    def _run():
        return pad_ragged_tensor_to_multiple_of(ragged_tensor, indptr, multiple_of)

    ms = do_bench(_run)

    # One read plus one write of the payload; bytes/ms -> GB/s is a 1e-6 factor.
    bytes_moved = 2 * ragged_tensor.numel() * ragged_tensor.element_size()
    mem_bandwidth_gb_s = bytes_moved / ms * 1e-6

    print(
        f"batch_size={batch_size}, qkv_len={qkv_len}, d={d}, multiple_of={multiple_of}, ms={ms}, mem_bandwidth={mem_bandwidth_gb_s} GB/s"
    )
27+
28+
29+
if __name__ == "__main__":
    # Sweep a small grid of problem sizes.
    batch_sizes = [11, 47, 101]
    qkv_lens = [500, 1017, 8011]
    dims = [2048, 4096, 16384]
    multiples = [128]
    for batch_size in batch_sizes:
        for qkv_len in qkv_lens:
            for d in dims:
                for multiple_of in multiples:
                    bench_pad_ragged_tensor_to_multiple_of(
                        batch_size, qkv_len, d, multiple_of
                    )

flashinfer/triton/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
from . import cascade # noqa: F401
22
from . import sm_constraint_gemm # noqa: F401
3+
from .format_conversion import pack_ragged_tensor as pack_ragged_tensor
4+
from .format_conversion import (
5+
pad_ragged_tensor_to_multiple_of as pad_ragged_tensor_to_multiple_of,
6+
)
+284
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Optional
18+
19+
import torch
20+
import triton
21+
import triton.language as tl
22+
23+
24+
@triton.jit
def _compute_padded_indptr(
    indptr_ptr, padded_indptr_ptr, n_rows, multiple_of, BLOCK_SIZE: tl.constexpr
):
    """Write each row's padded length into ``padded_indptr_ptr[1:]``.

    NOTE: this kernel does NOT produce a finished indptr. It stores the
    per-row padded *lengths* at positions 1..n_rows (and 0 at position 0);
    the host must run a cumulative sum over ``padded_indptr[1:]`` afterwards
    to turn the lengths into offsets. Each program handles BLOCK_SIZE rows
    and all stores hit disjoint slots, so no synchronization is needed.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_rows

    # Load row lengths: length of row i is indptr[i + 1] - indptr[i].
    row_start = tl.load(indptr_ptr + offsets, mask=mask, other=0)
    row_end = tl.load(indptr_ptr + offsets + 1, mask=mask, other=0)
    row_lengths = row_end - row_start

    # Compute padded lengths (round up to multiple_of)
    padded_lengths = ((row_lengths + multiple_of - 1) // multiple_of) * multiple_of

    # Seed the scan: the first indptr entry is always 0.
    if pid == 0:
        tl.store(padded_indptr_ptr + 0, 0)

    # Store the padded lengths (pre-scan values) at the correct positions.
    tl.store(padded_indptr_ptr + offsets + 1, padded_lengths, mask=mask)
48+
49+
50+
@triton.jit
def _pad_ragged_tensor(
    ragged_tensor_ptr,
    padded_tensor_ptr,
    indptr_ptr,
    padded_indptr_ptr,
    n_rows,
    dim,
    BLOCK_SIZE: tl.constexpr,
    fill_zeros: tl.constexpr,
):
    """Copy each ragged row into its padded slot, optionally zeroing the tail.

    One program handles one row: entries [indptr[pid], indptr[pid+1]) of the
    source are copied to the destination starting at padded_indptr[pid], and
    when ``fill_zeros`` is set the remaining padded positions of the row are
    filled with zeros. ``dim`` is traversed in BLOCK_SIZE chunks, masked by
    ``j_offsets < dim``, so BLOCK_SIZE need not divide ``dim``.
    """
    pid = tl.program_id(0)

    # Process one row per program; guard against an over-sized grid.
    if pid >= n_rows:
        return

    # Get original and padded row information.
    row_start = tl.load(indptr_ptr + pid)
    row_end = tl.load(indptr_ptr + pid + 1)
    row_length = row_end - row_start

    padded_row_start = tl.load(padded_indptr_ptr + pid)
    padded_row_end = tl.load(padded_indptr_ptr + pid + 1)
    padded_row_length = padded_row_end - padded_row_start

    # Copy the original data.
    for i in range(0, row_length):
        src_offset = (row_start + i) * dim
        dst_offset = (padded_row_start + i) * dim

        # Copy the entire feature vector for this position.
        for j in range(0, dim, BLOCK_SIZE):
            j_offsets = j + tl.arange(0, BLOCK_SIZE)
            j_mask = j_offsets < dim
            values = tl.load(ragged_tensor_ptr + src_offset + j_offsets, mask=j_mask)
            tl.store(padded_tensor_ptr + dst_offset + j_offsets, values, mask=j_mask)

    # Zero-pad the remaining positions (otherwise they keep whatever the
    # destination buffer already held).
    if fill_zeros:
        for i in range(row_length, padded_row_length):
            dst_offset = (padded_row_start + i) * dim

            # Zero out the entire feature vector for this position.
            for j in range(0, dim, BLOCK_SIZE):
                j_offsets = j + tl.arange(0, BLOCK_SIZE)
                j_mask = j_offsets < dim
                tl.store(padded_tensor_ptr + dst_offset + j_offsets, 0.0, mask=j_mask)
100+
101+
102+
@triton.jit
def _pack_ragged_tensor(
    padded_tensor_ptr,
    packed_tensor_ptr,
    padded_indptr_ptr,
    original_indptr_ptr,
    n_rows,
    dim,
    BLOCK_SIZE: tl.constexpr,
):
    """Inverse of ``_pad_ragged_tensor``: drop padding, write rows contiguously.

    One program handles one row: the first ``original_length`` entries of the
    padded row (starting at padded_indptr[pid]) are copied back to the packed
    layout at original_indptr[pid]; the padding tail is simply not read.
    """
    pid = tl.program_id(0)

    # Process one row per program; guard against an over-sized grid.
    if pid >= n_rows:
        return

    # Get original and padded row information.
    original_row_start = tl.load(original_indptr_ptr + pid)
    original_row_end = tl.load(original_indptr_ptr + pid + 1)
    original_row_length = original_row_end - original_row_start

    padded_row_start = tl.load(padded_indptr_ptr + pid)

    # Copy only the original data (not the padding).
    for i in range(0, original_row_length):
        src_offset = (padded_row_start + i) * dim
        dst_offset = (original_row_start + i) * dim

        # Copy the entire feature vector for this position; the mask makes
        # BLOCK_SIZE independent of dim.
        for j in range(0, dim, BLOCK_SIZE):
            j_offsets = j + tl.arange(0, BLOCK_SIZE)
            j_mask = j_offsets < dim
            values = tl.load(padded_tensor_ptr + src_offset + j_offsets, mask=j_mask)
            tl.store(packed_tensor_ptr + dst_offset + j_offsets, values, mask=j_mask)
136+
137+
138+
def max_power_of_2_leq(x: int) -> int:
    r"""Return the maximum power of 2 less than or equal to ``x``.

    Requires ``x >= 1``. Examples: 5 -> 4, 8 -> 8, 1023 -> 512.
    """
    # x.bit_length() - 1 == floor(log2(x)). The previous expression,
    # ``1 << (x - 1).bit_length()``, rounded *up* to the next power of two
    # (e.g. 5 -> 8), contradicting the function's name and docstring.
    # Callers use the result as a Triton BLOCK_SIZE with a ``< dim`` mask,
    # so both directions are functionally safe, but the floor matches the
    # documented contract.
    return 1 << (x.bit_length() - 1)
141+
142+
143+
def pad_ragged_tensor_to_multiple_of(
    ragged_tensor: torch.Tensor,
    indptr: torch.Tensor,
    multiple_of: int,
    fill_zeros: bool = False,
    output_ragged_tensor: Optional[torch.Tensor] = None,
    output_indptr: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Pad each row of ragged tensor to a multiple of ``multiple_of``.

    Suppose the ragged tensor has shape (150, 1024), and the indptr is [0, 100, 150] (which means there are 2 rows,
    the first row has 100 columns, the second row has 50 columns), and the multiple_of is 16.
    We will pad the first row to 112 columns, and the second row to 64 columns.
    The padded ragged tensor will have shape (176, 1024), and the returned indptr will be [0, 112, 176].

    Parameters
    ----------
    ragged_tensor: torch.Tensor
        The ragged tensor to pad, expected shape: (nnz, D)
    indptr: torch.Tensor
        The indptr of the ragged tensor, expected shape: (n_rows + 1,)
    multiple_of: int
        The multiple of to pad to, e.g. 256
    fill_zeros: bool
        If True, the padded positions will be filled with zeros, otherwise they will be
        uninitialized (whatever the output buffer held), default is False.
    output_ragged_tensor: Optional[torch.Tensor]
        If provided, the padded ragged tensor will be stored in this tensor,
        otherwise a new tensor will be allocated.
    output_indptr: Optional[torch.Tensor]
        If provided, the padded indptr will be stored in this tensor,
        otherwise a new tensor will be allocated.

    Returns
    -------
    padded_ragged_tensor: torch.Tensor
        The padded ragged tensor, expected shape: (padded_nnz, D)
    padded_indptr: torch.Tensor
        The padded indptr, expected shape: (n_rows + 1,)
    """
    # Get dimensions.
    n_rows = indptr.shape[0] - 1
    dim = ragged_tensor.shape[1]

    # First compute padded indptr.
    if output_indptr is None:
        padded_indptr = torch.zeros_like(indptr)
    else:
        padded_indptr = output_indptr

    # The kernel writes per-row padded *lengths* into padded_indptr[1:]
    # (and 0 into padded_indptr[0]); each program covers 128 rows.
    grid_size = triton.cdiv(n_rows, 128)
    _compute_padded_indptr[(grid_size,)](
        indptr, padded_indptr, n_rows, multiple_of, BLOCK_SIZE=128
    )

    # Scan the lengths into offsets to get the final padded_indptr.
    padded_indptr[1:] = torch.cumsum(padded_indptr[1:], dim=0)

    # Allocate padded tensor. NOTE: .item() synchronizes with the device;
    # pass output_ragged_tensor to avoid the sync.
    if output_ragged_tensor is None:
        total_padded_length = padded_indptr[-1].item()
        padded_ragged_tensor = torch.empty(
            (total_padded_length, dim),
            dtype=ragged_tensor.dtype,
            device=ragged_tensor.device,
        )
    else:
        padded_ragged_tensor = output_ragged_tensor

    # Pad the tensor: one program per row.
    _pad_ragged_tensor[(n_rows,)](
        ragged_tensor,
        padded_ragged_tensor,
        indptr,
        padded_indptr,
        n_rows,
        dim,
        BLOCK_SIZE=min(max_power_of_2_leq(dim), 16384),
        num_stages=2,
        fill_zeros=fill_zeros,
    )

    return padded_ragged_tensor, padded_indptr
227+
228+
229+
def pack_ragged_tensor(
    padded_tensor: torch.Tensor,
    padded_indptr: torch.Tensor,
    original_indptr: torch.Tensor,
    output_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""Convert a padded ragged tensor back to packed format.

    Inverse of ``pad_ragged_tensor_to_multiple_of``: every row's padding
    region is dropped and the surviving entries are written contiguously.

    Parameters
    ----------
    padded_tensor: torch.Tensor
        The padded ragged tensor, expected shape: (padded_nnz, D)
    padded_indptr: torch.Tensor
        The padded indptr, expected shape: (n_rows + 1,)
    original_indptr: torch.Tensor
        The original indptr before padding, expected shape: (n_rows + 1,)
    output_tensor: Optional[torch.Tensor]
        If provided, the packed tensor will be stored in this tensor,
        otherwise a new tensor will be allocated.

    Returns
    -------
    packed_tensor: torch.Tensor
        The packed tensor with padding removed, expected shape: (original_nnz, D)
    """
    n_rows = padded_indptr.shape[0] - 1
    dim = padded_tensor.shape[1]
    # .item() synchronizes with the device to learn the packed length.
    original_nnz = original_indptr[-1].item()

    # Reuse the caller's buffer when given, otherwise allocate a fresh one.
    packed_tensor = (
        output_tensor
        if output_tensor is not None
        else torch.empty(
            (original_nnz, dim),
            dtype=padded_tensor.dtype,
            device=padded_tensor.device,
        )
    )

    # One program per row; each copies only the row's original entries.
    _pack_ragged_tensor[(n_rows,)](
        padded_tensor,
        packed_tensor,
        padded_indptr,
        original_indptr,
        n_rows,
        dim,
        BLOCK_SIZE=min(max_power_of_2_leq(dim), 16384),
        num_stages=2,
    )

    return packed_tensor

0 commit comments

Comments
 (0)