@@ -5,7 +5,8 @@ use std::{
    sync::{Arc, Mutex},
};

-use candle_core::{Device, Error, Result, Tensor, D};
+use candle_core::{DType, Device, Error, Result, Tensor, D};
+use mistralrs_quant::{CumSumOp, SortOp};

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;
@@ -329,6 +330,160 @@ impl Sampler {
        })
    }

+    #[allow(unused)]
+    fn sample_fast(
+        &self,
+        logits: Tensor,
+        context: &[u32],
+        return_logprobs: bool,
+        top_k: i64,
+        top_p: f64,
+        min_p: f64,
+    ) -> Result<Logprobs> {
+        let mut probs = logits.to_dtype(DType::F32)?;
+
+        for processor in &self.logits_processors {
+            probs = processor.apply(&probs, context)?;
+        }
+
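+        // Count each context token's occurrences and derive a 0/1 presence mask for the penalties below.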
+        let context = Tensor::new(context, logits.device())?;
+        let mut counts = logits.zeros_like()?;
+        counts = counts.scatter_add(
+            &context,
+            &context.ones_like()?.to_dtype(counts.dtype())?,
+            D::Minus1,
+        )?;
+
+        let presence = counts
+            .gt(0.)?
+            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
+
+        match self.frequency_penalty {
+            Some(freq_penalty) if freq_penalty != 0. => {
+                probs = (probs - (freq_penalty as f64 * counts)?)?;
+            }
+            _ => (),
+        }
+
+        match self.presence_penalty {
+            Some(pres_penalty) if pres_penalty != 0. => {
+                probs = (probs - (pres_penalty as f64 * presence)?)?;
+            }
+            _ => (),
+        }
+
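+        // Temperature-scale the penalized logits and convert them to probabilities.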
+        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
+
+        // Top-K
+        if top_k > 0 {
+            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
+            let topk_values = sorted_values.narrow(
+                D::Minus1,
+                sorted_values.dim(D::Minus1)? - top_k as usize,
+                top_k as usize,
+            )?;
+
+            // select the kth largest value as threshold
+            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
+            let mask_topk = probs.broadcast_ge(&threshold)?;
+            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
+        }
+
+        // Top-P (nucleus)
+        if top_p > 0.0 && top_p < 1.0 {
+            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
+
+            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
+
+            let mask_topp = cumsum.le(top_p)?;
+
+            let masked_sorted =
+                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
+
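+            // The largest probability admitted by the cumulative-sum mask becomes the cutoff for the full distribution.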
+            let threshold = masked_sorted.max(D::Minus1)?;
+            let threshold = threshold.unsqueeze(D::Minus1)?;
+            let mask_full = probs.broadcast_ge(&threshold)?;
+            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
+        }
+
+        // Min-P
+        if min_p > 0.0 && min_p < 1.0 {
+            let max_vals = probs.max(D::Minus1)?;
+            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
+            let mask_minp = probs.broadcast_gt(&threshold_min)?;
+            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
+        }
+
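+        // Greedy pick: take the highest-probability token left after the filters above.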
+        let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
+
+        // Extract the top-n log-probs if the caller asked for them.
+        let (top_logprobs, logprob) = if return_logprobs {
+            let k = self.top_n_logprobs;
+
+            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
+            let topk_values = sorted_values
+                .narrow(D::Minus1, sorted_values.dim(D::Minus1)? - k, k)?
+                .to_vec1::<f32>()?;
+
+            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
+            let topk_idxs = sorted_idxs
+                .narrow(D::Minus1, sorted_values.dim(D::Minus1)? - k, k)?
+                .to_vec1::<u32>()?;
+
+            let mut result = Vec::with_capacity(k);
+            if let Some(tokenizer) = &self.tokenizer {
+                for (prob, token) in topk_values.iter().zip(topk_idxs) {
+                    let decoded = tokenizer
+                        .decode(&[token], false)
+                        .map_err(|e| Error::Msg(e.to_string()))?;
+                    result.push(TopLogprob {
+                        token,
+                        logprob: prob.log(10.0),
+                        bytes: Some(decoded),
+                    });
+                }
+            } else {
+                for (prob, token) in topk_values.iter().zip(topk_idxs) {
+                    result.push(TopLogprob {
+                        token,
+                        logprob: prob.log(10.0),
+                        bytes: None,
+                    });
+                }
+            }
+
+            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
+
+            (Some(result), logprob)
+        } else {
+            (None, 1.)
+        };
+
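+        // Decode the chosen token into text when a tokenizer is available.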
+        let bytes = if let Some(tokenizer) = &self.tokenizer {
+            Some(
+                tokenizer
+                    .decode(&[next_token], false)
+                    .map_err(|x| Error::Msg(x.to_string()))?,
+            )
+        } else {
+            None
+        };
+
+        Ok(Logprobs {
+            token: next_token,
+            logprob,
+            top_logprobs,
+            bytes,
+        })
+    }
+
    fn sample_speculative_top_kp_min_p(
        &self,
        logits: Tensor,
@@ -623,6 +778,7 @@ impl Sampler {
        Ok(())
    }

+    #[allow(unused)]
    /// Sample the provided tokens.
    ///
    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
@@ -635,6 +791,16 @@ impl Sampler {
        rng: Arc<Mutex<Isaac64Rng>>,
        sample_speculative: bool,
    ) -> Result<Logprobs> {
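+        // On Metal builds, route everything through the tensor-based fast path defined above.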
+        #[cfg(feature = "metal")]
+        return self.sample_fast(
+            logits,
+            context,
+            return_logprobs,
+            self.top_k,
+            self.top_p,
+            self.min_p,
+        );
+
        let logits = logits.to_vec1()?;
        let mut logits = self.apply_penalties(logits, context)?;
        for processor in &self.logits_processors {