|
| 1 | +""" |
| 2 | +Copyright (c) 2025 by FlashInfer team. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +from typing import Optional |
| 18 | + |
| 19 | +import torch |
| 20 | +import triton |
| 21 | +import triton.language as tl |
| 22 | + |
| 23 | + |
| 24 | +@triton.jit |
| 25 | +def _compute_padded_indptr( |
| 26 | + indptr_ptr, padded_indptr_ptr, n_rows, multiple_of, BLOCK_SIZE: tl.constexpr |
| 27 | +): |
| 28 | + pid = tl.program_id(0) |
| 29 | + block_start = pid * BLOCK_SIZE |
| 30 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 31 | + mask = offsets < n_rows |
| 32 | + |
| 33 | + # Load row lengths |
| 34 | + row_start = tl.load(indptr_ptr + offsets, mask=mask, other=0) |
| 35 | + row_end = tl.load(indptr_ptr + offsets + 1, mask=mask, other=0) |
| 36 | + row_lengths = row_end - row_start |
| 37 | + |
| 38 | + # Compute padded lengths (round up to multiple_of) |
| 39 | + padded_lengths = ((row_lengths + multiple_of - 1) // multiple_of) * multiple_of |
| 40 | + |
| 41 | + # Compute cumulative sum for padded indptr |
| 42 | + if pid == 0: |
| 43 | + # First element is always 0 |
| 44 | + tl.store(padded_indptr_ptr + 0, 0) |
| 45 | + |
| 46 | + # Store the padded lengths at the correct positions |
| 47 | + tl.store(padded_indptr_ptr + offsets + 1, padded_lengths, mask=mask) |
| 48 | + |
| 49 | + |
| 50 | +@triton.jit |
| 51 | +def _pad_ragged_tensor( |
| 52 | + ragged_tensor_ptr, |
| 53 | + padded_tensor_ptr, |
| 54 | + indptr_ptr, |
| 55 | + padded_indptr_ptr, |
| 56 | + n_rows, |
| 57 | + dim, |
| 58 | + BLOCK_SIZE: tl.constexpr, |
| 59 | + fill_zeros: tl.constexpr, |
| 60 | +): |
| 61 | + pid = tl.program_id(0) |
| 62 | + |
| 63 | + # Process one row per program |
| 64 | + if pid >= n_rows: |
| 65 | + return |
| 66 | + |
| 67 | + # Get original and padded row information |
| 68 | + row_start = tl.load(indptr_ptr + pid) |
| 69 | + row_end = tl.load(indptr_ptr + pid + 1) |
| 70 | + row_length = row_end - row_start |
| 71 | + |
| 72 | + padded_row_start = tl.load(padded_indptr_ptr + pid) |
| 73 | + padded_row_end = tl.load(padded_indptr_ptr + pid + 1) |
| 74 | + padded_row_length = padded_row_end - padded_row_start |
| 75 | + |
| 76 | + # Copy the original data |
| 77 | + for i in range(0, row_length): |
| 78 | + col_idx = i |
| 79 | + src_offset = (row_start + i) * dim |
| 80 | + dst_offset = (padded_row_start + i) * dim |
| 81 | + |
| 82 | + # Copy the entire feature vector for this position |
| 83 | + for j in range(0, dim, BLOCK_SIZE): |
| 84 | + j_offsets = j + tl.arange(0, BLOCK_SIZE) |
| 85 | + j_mask = j_offsets < dim |
| 86 | + values = tl.load(ragged_tensor_ptr + src_offset + j_offsets, mask=j_mask) |
| 87 | + tl.store(padded_tensor_ptr + dst_offset + j_offsets, values, mask=j_mask) |
| 88 | + |
| 89 | + # Zero-pad the remaining positions |
| 90 | + if fill_zeros: |
| 91 | + for i in range(row_length, padded_row_length): |
| 92 | + col_idx = i |
| 93 | + dst_offset = (padded_row_start + i) * dim |
| 94 | + |
| 95 | + # Zero out the entire feature vector for this position |
| 96 | + for j in range(0, dim, BLOCK_SIZE): |
| 97 | + j_offsets = j + tl.arange(0, BLOCK_SIZE) |
| 98 | + j_mask = j_offsets < dim |
| 99 | + tl.store(padded_tensor_ptr + dst_offset + j_offsets, 0.0, mask=j_mask) |
| 100 | + |
| 101 | + |
| 102 | +@triton.jit |
| 103 | +def _pack_ragged_tensor( |
| 104 | + padded_tensor_ptr, |
| 105 | + packed_tensor_ptr, |
| 106 | + padded_indptr_ptr, |
| 107 | + original_indptr_ptr, |
| 108 | + n_rows, |
| 109 | + dim, |
| 110 | + BLOCK_SIZE: tl.constexpr, |
| 111 | +): |
| 112 | + pid = tl.program_id(0) |
| 113 | + |
| 114 | + # Process one row per program |
| 115 | + if pid >= n_rows: |
| 116 | + return |
| 117 | + |
| 118 | + # Get original and padded row information |
| 119 | + original_row_start = tl.load(original_indptr_ptr + pid) |
| 120 | + original_row_end = tl.load(original_indptr_ptr + pid + 1) |
| 121 | + original_row_length = original_row_end - original_row_start |
| 122 | + |
| 123 | + padded_row_start = tl.load(padded_indptr_ptr + pid) |
| 124 | + |
| 125 | + # Copy only the original data (not the padding) |
| 126 | + for i in range(0, original_row_length): |
| 127 | + src_offset = (padded_row_start + i) * dim |
| 128 | + dst_offset = (original_row_start + i) * dim |
| 129 | + |
| 130 | + # Copy the entire feature vector for this position |
| 131 | + for j in range(0, dim, BLOCK_SIZE): |
| 132 | + j_offsets = j + tl.arange(0, BLOCK_SIZE) |
| 133 | + j_mask = j_offsets < dim |
| 134 | + values = tl.load(padded_tensor_ptr + src_offset + j_offsets, mask=j_mask) |
| 135 | + tl.store(packed_tensor_ptr + dst_offset + j_offsets, values, mask=j_mask) |
| 136 | + |
| 137 | + |
| 138 | +def max_power_of_2_leq(x: int) -> int: |
| 139 | + r"""Return the maximum power of 2 less than or equal to x.""" |
| 140 | + return 1 << (x - 1).bit_length() |
| 141 | + |
| 142 | + |
| 143 | +def pad_ragged_tensor_to_multiple_of( |
| 144 | + ragged_tensor: torch.Tensor, |
| 145 | + indptr: torch.Tensor, |
| 146 | + multiple_of: int, |
| 147 | + fill_zeros: bool = False, |
| 148 | + output_ragged_tensor: Optional[torch.Tensor] = None, |
| 149 | + output_indptr: Optional[torch.Tensor] = None, |
| 150 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 151 | + r"""Pad each row of ragged tensor to a multiple of ``multiple_of``. |
| 152 | +
|
| 153 | + Suppose the ragged tensor has shape (150, 1024), and the indptr is [0, 100, 150] (which means there are 2 rows, |
| 154 | + the first row has 100 columns, the second row has 50 columns), and the multiple_of is 16. |
| 155 | + We will pad the first row to 112 columns, and the second row to 64 columns. |
| 156 | + The padded ragged tensor will have shape (176, 1024), and the returned indptr will be [0, 112, 176]. |
| 157 | +
|
| 158 | + Parameters |
| 159 | + ---------- |
| 160 | + ragged_tensor: torch.Tensor |
| 161 | + The ragged tensor to pad, expected shape: (nnz, D) |
| 162 | + indptr: torch.Tensor |
| 163 | + The indptr of the ragged tensor, expected shape: (n_rows + 1,) |
| 164 | + multiple_of: int |
| 165 | + The multiple of to pad to, e.g. 256 |
| 166 | + fill_zeros: bool |
| 167 | + If True, the padded positions will be filled with zeros, otherwise they will be random values, |
| 168 | + default is False. |
| 169 | + output_ragged_tensor: Optional[torch.Tensor] |
| 170 | + If provided, the padded ragged tensor will be stored in this tensor, |
| 171 | + otherwise a new tensor will be allocated. |
| 172 | + output_indptr: Optional[torch.Tensor] |
| 173 | + If provided, the padded indptr will be stored in this tensor, |
| 174 | + otherwise a new tensor will be allocated. |
| 175 | +
|
| 176 | + Returns |
| 177 | + ------- |
| 178 | + padded_ragged_tensor: torch.Tensor |
| 179 | + The padded ragged tensor, expected shape: (n_rows, padded_nnz, D) |
| 180 | + padded_indptr: torch.Tensor |
| 181 | + The padded indptr, expected shape: (n_rows + 1,) |
| 182 | + """ |
| 183 | + # Get dimensions |
| 184 | + n_rows = indptr.shape[0] - 1 |
| 185 | + nnz = ragged_tensor.shape[0] |
| 186 | + dim = ragged_tensor.shape[1] |
| 187 | + |
| 188 | + # First compute padded indptr |
| 189 | + if output_indptr is None: |
| 190 | + padded_indptr = torch.zeros_like(indptr) |
| 191 | + else: |
| 192 | + padded_indptr = output_indptr |
| 193 | + |
| 194 | + grid_size = triton.cdiv(n_rows, 128) |
| 195 | + _compute_padded_indptr[(grid_size,)]( |
| 196 | + indptr, padded_indptr, n_rows, multiple_of, BLOCK_SIZE=128 |
| 197 | + ) |
| 198 | + |
| 199 | + # Perform exclusive scan to get final padded_indptr |
| 200 | + padded_indptr[1:] = torch.cumsum(padded_indptr[1:], dim=0) |
| 201 | + |
| 202 | + # Allocate padded tensor |
| 203 | + if output_ragged_tensor is None: |
| 204 | + total_padded_length = padded_indptr[-1].item() |
| 205 | + padded_ragged_tensor = torch.empty( |
| 206 | + (total_padded_length, dim), |
| 207 | + dtype=ragged_tensor.dtype, |
| 208 | + device=ragged_tensor.device, |
| 209 | + ) |
| 210 | + else: |
| 211 | + padded_ragged_tensor = output_ragged_tensor |
| 212 | + |
| 213 | + # Pad the tensor |
| 214 | + _pad_ragged_tensor[(n_rows,)]( |
| 215 | + ragged_tensor, |
| 216 | + padded_ragged_tensor, |
| 217 | + indptr, |
| 218 | + padded_indptr, |
| 219 | + n_rows, |
| 220 | + dim, |
| 221 | + BLOCK_SIZE=min(max_power_of_2_leq(dim), 16384), |
| 222 | + num_stages=2, |
| 223 | + fill_zeros=fill_zeros, |
| 224 | + ) |
| 225 | + |
| 226 | + return padded_ragged_tensor, padded_indptr |
| 227 | + |
| 228 | + |
| 229 | +def pack_ragged_tensor( |
| 230 | + padded_tensor: torch.Tensor, |
| 231 | + padded_indptr: torch.Tensor, |
| 232 | + original_indptr: torch.Tensor, |
| 233 | + output_tensor: Optional[torch.Tensor] = None, |
| 234 | +) -> torch.Tensor: |
| 235 | + r"""Convert a padded ragged tensor back to packed format. |
| 236 | +
|
| 237 | + This function reverses the operation of pad_ragged_tensor_to_multiple_of by |
| 238 | + removing the padding and returning the original packed tensor. |
| 239 | +
|
| 240 | + Parameters |
| 241 | + ---------- |
| 242 | + padded_tensor: torch.Tensor |
| 243 | + The padded ragged tensor, expected shape: (padded_nnz, D) |
| 244 | + padded_indptr: torch.Tensor |
| 245 | + The padded indptr, expected shape: (n_rows + 1,) |
| 246 | + original_indptr: torch.Tensor |
| 247 | + The original indptr before padding, expected shape: (n_rows + 1,) |
| 248 | + output_tensor: Optional[torch.Tensor] |
| 249 | + If provided, the packed tensor will be stored in this tensor, |
| 250 | + otherwise a new tensor will be allocated. |
| 251 | +
|
| 252 | + Returns |
| 253 | + ------- |
| 254 | + packed_tensor: torch.Tensor |
| 255 | + The packed tensor with padding removed, expected shape: (original_nnz, D) |
| 256 | + """ |
| 257 | + # Get dimensions |
| 258 | + n_rows = padded_indptr.shape[0] - 1 |
| 259 | + dim = padded_tensor.shape[1] |
| 260 | + original_nnz = original_indptr[-1].item() |
| 261 | + |
| 262 | + # Allocate output tensor if not provided |
| 263 | + if output_tensor is None: |
| 264 | + packed_tensor = torch.empty( |
| 265 | + (original_nnz, dim), |
| 266 | + dtype=padded_tensor.dtype, |
| 267 | + device=padded_tensor.device, |
| 268 | + ) |
| 269 | + else: |
| 270 | + packed_tensor = output_tensor |
| 271 | + |
| 272 | + # Pack the tensor by removing padding |
| 273 | + _pack_ragged_tensor[(n_rows,)]( |
| 274 | + padded_tensor, |
| 275 | + packed_tensor, |
| 276 | + padded_indptr, |
| 277 | + original_indptr, |
| 278 | + n_rows, |
| 279 | + dim, |
| 280 | + BLOCK_SIZE=min(max_power_of_2_leq(dim), 16384), |
| 281 | + num_stages=2, |
| 282 | + ) |
| 283 | + |
| 284 | + return packed_tensor |
0 commit comments