30 | 30 |
31 | 31 |
32 | 32 | def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
33 |    | -     r"""Category sampling from probabilities.
   | 33 | +     r"""Fused GPU kernel for category sampling from probabilities.
34 | 34 |
35 | 35 |     Parameters
36 | 36 |     ----------
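
A minimal usage sketch (editor's illustration, not part of the diff; the import path and the shape of ``uniform_samples`` are assumptions, since this excerpt cuts off before the parameter table):

    import torch
    # Assumed import path; the diff does not show the package name.
    from flashinfer.sampling import sampling_from_probs

    batch_size, vocab_size = 4, 32000
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
    # Assumed: one uniform draw per row drives the categorical sample.
    uniform_samples = torch.rand(batch_size, device="cuda")
    token_ids = sampling_from_probs(probs, uniform_samples)  # expected shape: (batch_size,)
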
@@ -75,8 +75,11 @@ def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
75 | 75 | def top_p_sampling_from_probs(
76 | 76 |     probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
77 | 77 | ):
78 |    | -     r"""Top-p sampling (nucleus sampling) from probabilities, this operator implements
79 |    | -     GPU-based rejection sampling without explicit sorting.
   | 78 | +     r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities.
   | 79 | +     This operator implements GPU-based rejection sampling without explicit sorting.
   | 80 | +
   | 81 | +     The multiple rounds of rejection sampling are fused into a single CUDA kernel,
   | 82 | +     which is more efficient than a naive implementation that launches a series of kernels.
80 | 83 |
81 | 84 |     Parameters
82 | 85 |     ----------
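
A hedged sketch of calling the fused top-p kernel (editor's illustration). Because rejection sampling runs multiple rounds, ``uniform_samples`` plausibly carries one value per round per row; that shape, the import path, and the ``(samples, success)`` return pair are assumptions, as the parameter table is truncated in this excerpt:

    import torch
    from flashinfer.sampling import top_p_sampling_from_probs  # assumed path

    batch_size, vocab_size, max_rounds = 4, 32000, 32
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
    # Assumed shape: one uniform sample per rejection round per row.
    uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")
    samples, success = top_p_sampling_from_probs(probs, uniform_samples, top_p=0.9)
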
@@ -134,8 +137,11 @@ def top_p_sampling_from_probs(
134 | 137 | def top_k_sampling_from_probs(
135 | 138 |     probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
136 | 139 | ):
137 |     | -     r"""Top-k sampling from probabilities, this operator implements GPU-based rejection sampling
138 |     | -     without explicit sorting.
    | 140 | +     r"""Fused GPU kernel for top-k sampling from probabilities.
    | 141 | +     This operator implements GPU-based rejection sampling without explicit sorting.
    | 142 | +
    | 143 | +     The multiple rounds of rejection sampling are fused into a single CUDA kernel,
    | 144 | +     which is more efficient than a naive implementation that launches a series of kernels.
139 | 145 |
140 | 146 |     Parameters
141 | 147 |     ----------
@@ -188,3 +194,96 @@ def top_k_sampling_from_probs(
188 | 194 |     implementation usually uses much fewer rounds for rejection sampling because of early stopping.
189 | 195 |     """
190 | 196 |     return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)
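
The top-k variant mirrors the top-p call, with only the threshold argument changing; the shapes, import path, and return values here are the same assumptions as in the top-p sketch above:

    import torch
    from flashinfer.sampling import top_k_sampling_from_probs  # assumed path

    batch_size, vocab_size, max_rounds = 4, 32000, 32
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
    uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")  # assumed shape
    samples, success = top_k_sampling_from_probs(probs, uniform_samples, top_k=50)
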
    | 197 | +
    | 198 | +
    | 199 | +def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
    | 200 | +    r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
    | 201 | +
    | 202 | +    Parameters
    | 203 | +    ----------
    | 204 | +    probs: torch.Tensor
    | 205 | +        Probabilities, shape ``(batch_size, num_classes)``.
    | 206 | +    top_p: float
    | 207 | +        The threshold for re-normalizing probabilities, which should be in ``(0, 1)``.
    | 208 | +        We mask out probabilities below a threshold chosen such that the cumulative sum
    | 209 | +        of ``probs[probs >= threshold]`` equals ``top_p``, then renormalize the remaining probabilities.
    | 210 | +    eps: float
    | 211 | +        The epsilon value for numerical stability.
    | 212 | +
    | 213 | +    Returns
    | 214 | +    -------
    | 215 | +    renorm_probs: torch.Tensor
    | 216 | +        Renormalized probabilities, shape ``(batch_size, num_classes)``.
    | 217 | +
    | 218 | +    This combination of ``top_p_renorm_prob`` and ``sampling_from_probs`` should be equivalent to
    | 219 | +    ``top_p_sampling_from_probs``.
    | 220 | +    """
    | 221 | +    return _kernels.top_p_renorm_prob(probs, top_p, eps)
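
The equivalence stated in the docstring suggests a two-step alternative to the fused top-p sampler; a sketch under the same import-path assumption as above:

    import torch
    from flashinfer.sampling import top_p_renorm_prob, sampling_from_probs  # assumed path

    batch_size, vocab_size = 4, 32000
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
    # Zero out the tail outside the top-p nucleus, then rescale each row to sum to 1.
    renorm = top_p_renorm_prob(probs, top_p=0.9)
    token_ids = sampling_from_probs(renorm, torch.rand(batch_size, device="cuda"))
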
    | 222 | +
    | 223 | +
    | 224 | +def top_k_renorm_prob(probs: torch.Tensor, top_k: int, eps: float = 1e-5):
    | 225 | +    r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding.
    | 226 | +
    | 227 | +    Parameters
    | 228 | +    ----------
    | 229 | +    probs: torch.Tensor
    | 230 | +        Probabilities, shape ``(batch_size, num_classes)``.
    | 231 | +    top_k: int
    | 232 | +        The threshold for re-normalizing probabilities, which should be in ``(0, num_classes)``.
    | 233 | +        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
    | 234 | +    eps: float
    | 235 | +        The epsilon value for numerical stability.
    | 236 | +
    | 237 | +    Returns
    | 238 | +    -------
    | 239 | +    renorm_probs: torch.Tensor
    | 240 | +        Renormalized probabilities, shape ``(batch_size, num_classes)``.
    | 241 | +
    | 242 | +    Note
    | 243 | +    ----
    | 244 | +    This combination of ``top_k_renorm_prob`` and ``sampling_from_probs`` should be equivalent to
    | 245 | +    ``top_k_sampling_from_probs``.
    | 246 | +    """
    | 247 | +    return _kernels.top_k_renorm_prob(probs, top_k, eps)
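
For intuition, a pure-PyTorch reference of what the fused top-k renormalization computes, following the docstring's description directly (editor's sketch; the kernel's tie-breaking and ``eps`` handling may differ):

    import torch

    def top_k_renorm_prob_ref(probs: torch.Tensor, top_k: int) -> torch.Tensor:
        # Keep each row's top-k probabilities, zero out the rest, renormalize.
        vals, idx = torch.topk(probs, top_k, dim=-1)
        kept = torch.zeros_like(probs).scatter_(-1, idx, vals)
        return kept / kept.sum(dim=-1, keepdim=True)
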
    | 248 | +
    | 249 | +
    | 250 | +def chain_speculative_sampling(
    | 251 | +    draft_probs,
    | 252 | +    draft_token_ids,
    | 253 | +    uniform_samples,
    | 254 | +    target_probs,
    | 255 | +):
    | 256 | +    r"""Fused GPU kernel for speculative sampling for sequence generation (proposed in
    | 257 | +    the paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
    | 258 | +    where the draft model generates a sequence (chain) of tokens for each request.
    | 259 | +
    | 260 | +    Parameters
    | 261 | +    ----------
    | 262 | +    draft_probs: torch.Tensor
    | 263 | +        The probability over vocabulary generated by the draft model.
    | 264 | +        Shape: ``(batch_size, num_speculate_tokens, vocab_size)``
    | 265 | +    draft_token_ids: torch.Tensor
    | 266 | +        The draft model's generated token indices.
    | 267 | +        Shape: ``(batch_size, num_speculate_tokens)``
    | 268 | +    uniform_samples: torch.Tensor
    | 269 | +        The uniform samples used as needles for sampling, shape ``(batch_size, num_speculate_tokens + 1)``.
    | 270 | +        Expected to be uniformly distributed in ``[0, 1)``.
    | 271 | +    target_probs: torch.Tensor
    | 272 | +        The probability over vocabulary generated by the target model.
    | 273 | +        Compared to the input :attr:`draft_probs`, the target model's probability has one additional
    | 274 | +        slot at the end because the target model generates one more token than the draft model.
    | 275 | +        Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
    | 276 | +
    | 277 | +    Returns
    | 278 | +    -------
    | 279 | +    output_token_ids: torch.Tensor
    | 280 | +        The output token indices verified by the target model; rejected samples are
    | 281 | +        padded with ``-1``.
    | 282 | +        Compared to the input :attr:`draft_token_ids`, the output tensor has one additional
    | 283 | +        token index at the end for the final token. If all previous tokens are accepted,
    | 284 | +        an extra "bonus" token is sampled from the target model's probability.
    | 285 | +        Shape: ``(batch_size, num_speculate_tokens + 1)``
    | 286 | +    """
    | 287 | +    return _kernels.chain_speculative_sampling(
    | 288 | +        draft_probs, draft_token_ids, uniform_samples, target_probs
    | 289 | +    )
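
An end-to-end sketch of the verification call (editor's illustration; shapes follow the parameter table above, the import path is assumed):

    import torch
    from flashinfer.sampling import chain_speculative_sampling  # assumed path

    batch_size, num_speculate_tokens, vocab_size = 2, 4, 32000
    draft_probs = torch.softmax(
        torch.randn(batch_size, num_speculate_tokens, vocab_size, device="cuda"), dim=-1
    )
    # Token ids the draft model actually emitted for those distributions.
    draft_token_ids = torch.multinomial(
        draft_probs.view(-1, vocab_size), num_samples=1
    ).view(batch_size, num_speculate_tokens)
    target_probs = torch.softmax(
        torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device="cuda"), dim=-1
    )
    # One uniform needle per draft token, plus one for the possible bonus token.
    uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1, device="cuda")

    # Rejected positions in the result are padded with -1.
    output_token_ids = chain_speculative_sampling(
        draft_probs, draft_token_ids, uniform_samples, target_probs
    )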