Skip to content

Commit f959354

Browse files
authored
typo: remove another uniform samples leftover (#937)
Looks like another typo in #912 - sorry for taking 3 PRs to fix one docstring! 🙄 ``` >>> # uniform samples for rejection sampling - >>> uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1).to(0) - tensor([[0.8823, 0.9150, 0.3829], device='cuda:0') >>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0) ```
1 parent d462a9d commit f959354

File tree

1 file changed

+0
-1
lines changed

1 file changed

+0
-1
lines changed

flashinfer/sampling.py

-1
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,6 @@ def chain_speculative_sampling(
12051205
>>> # token 2 was sampled from draft model for the first token, and
12061206
>>> # token 1 was sampled from draft model for the second token
12071207
>>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
1208-
>>> # uniform samples for rejection sampling
12091208
>>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
12101209
>>> output_token_ids, output_accepted_token_num, output_accepted_token_num =\
12111210
... flashinfer.sampling.chain_speculative_sampling(

0 commit comments

Comments
 (0)