Skip to content

Commit a97de2b

Browse files
authored
New, fast sampler for Metal! (#1327)
* Show TTFT * Use LRU prefix cacher * Faster prefix cacher * A bit of gpu sampling * Minp but cpu for now * Metal fast cumsum impl * Sampling with fast topp kernel * Hmm not perfect * Add metal sort kernels * Tmp * Add single block sort * Add most of multi block sort, just need copy op * Add copy kernels * Expose kernels * Add a test * Ok it works * Structure things * Add caching * Rename * Cpu is default * CUDA case * Topk * Refactor Option references for model paths (#1347) * refactor: use Option refs in model path helpers * Format * Add a script for server benchmarking (#1355) * Serde alias * Fix * Update for tie_word_embeddings * Print running/waiting * 30 users * Update num_users * Update dummy paged attn * Optimized Metal qmv_fast path (#1356) * Compile with lto * Tweak profiles * Fix topk * Penalties * Add logits processor, clippy fixes * Fix chat port
1 parent c89aa3c commit a97de2b

File tree

17 files changed

+4244
-389
lines changed

17 files changed

+4244
-389
lines changed

mistralrs-core/src/pipeline/mod.rs

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -300,19 +300,6 @@ impl ForwardInputsResult {
300300
}),
301301
}
302302
}
303-
304-
fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
305-
match self {
306-
Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
307-
logits: logits.to_device(device)?,
308-
}),
309-
Self::RawLogits { logits } => Ok(Self::RawLogits {
310-
logits: logits.to_device(device)?,
311-
}),
312-
Self::Image { .. } => Ok(self.clone()),
313-
Self::Speech { .. } => Ok(self.clone()),
314-
}
315-
}
316303
}
317304

318305
#[async_trait::async_trait]
@@ -445,11 +432,8 @@ pub trait Pipeline:
445432
let start = Instant::now();
446433
let logits = logits
447434
.into_iter()
448-
.map(|l| {
449-
l.expect("Did not get any inputs. This is shocking.")
450-
.to_device(&Device::Cpu)
451-
})
452-
.collect::<candle_core::Result<Vec<_>>>()?;
435+
.map(|l| l.expect("Did not get any inputs. This is shocking."))
436+
.collect::<Vec<_>>();
453437

454438
match &logits[0] {
455439
ForwardInputsResult::RawLogits { .. } => unreachable!(),
@@ -596,11 +580,8 @@ pub trait Pipeline:
596580
let start = Instant::now();
597581
let logits = logits
598582
.into_iter()
599-
.map(|l| {
600-
l.expect("Did not get any inputs. This is shocking.")
601-
.to_device(&Device::Cpu)
602-
})
603-
.collect::<candle_core::Result<Vec<_>>>()?;
583+
.map(|l| l.expect("Did not get any inputs. This is shocking."))
584+
.collect::<Vec<_>>();
604585

605586
match &logits[0] {
606587
ForwardInputsResult::RawLogits { .. } => unreachable!(),

mistralrs-core/src/sampler.rs

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use std::{
55
sync::{Arc, Mutex},
66
};
77

8-
use candle_core::{Device, Error, Result, Tensor, D};
8+
use candle_core::{DType, Device, Error, Result, Tensor, D};
9+
use mistralrs_quant::{CumSumOp, SortOp};
910
#[cfg(feature = "pyo3_macros")]
1011
use pyo3::pyclass;
1112

@@ -329,6 +330,160 @@ impl Sampler {
329330
})
330331
}
331332

333+
#[allow(unused)]
334+
fn sample_fast(
335+
&self,
336+
logits: Tensor,
337+
context: &[u32],
338+
return_logprobs: bool,
339+
top_k: i64,
340+
top_p: f64,
341+
min_p: f64,
342+
) -> Result<Logprobs> {
343+
let mut probs = logits.to_dtype(DType::F32)?;
344+
345+
for processor in &self.logits_processors {
346+
probs = processor.apply(&probs, context)?;
347+
}
348+
349+
let context = Tensor::new(context, logits.device())?;
350+
let mut counts = logits.zeros_like()?;
351+
counts = counts.scatter_add(
352+
&context,
353+
&context.ones_like()?.to_dtype(counts.dtype())?,
354+
D::Minus1,
355+
)?;
356+
357+
let presence = counts
358+
.gt(0.)?
359+
.where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
360+
361+
match self.frequency_penalty {
362+
Some(freq_penalty) if freq_penalty != 0. => {
363+
probs = (probs - (freq_penalty as f64 * counts)?)?;
364+
}
365+
_ => (),
366+
}
367+
368+
match self.presence_penalty {
369+
Some(pres_penalty) if pres_penalty != 0. => {
370+
probs = (probs - (pres_penalty as f64 * presence)?)?;
371+
}
372+
_ => (),
373+
}
374+
375+
probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
376+
377+
// Top-K
378+
if top_k > 0 {
379+
let sorted_values = probs.fast_sort_asc(D::Minus1)?;
380+
let topk_values = sorted_values.narrow(
381+
D::Minus1,
382+
sorted_values.dim(D::Minus1)? - top_k as usize,
383+
top_k as usize,
384+
)?;
385+
386+
// select the kth largest value as threshold
387+
let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
388+
let mask_topk = probs.broadcast_ge(&threshold)?;
389+
probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
390+
}
391+
392+
// Top-P (nucleus)
393+
if top_p > 0.0 && top_p < 1.0 {
394+
let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
395+
396+
let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
397+
398+
let mask_topp = cumsum.le(top_p)?;
399+
400+
let masked_sorted =
401+
mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
402+
403+
let threshold = masked_sorted.max(D::Minus1)?;
404+
let threshold = threshold.unsqueeze(D::Minus1)?;
405+
let mask_full = probs.broadcast_ge(&threshold)?;
406+
probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
407+
}
408+
409+
// Min-P
410+
if min_p > 0.0 && min_p < 1.0 {
411+
let max_vals = probs.max(D::Minus1)?;
412+
let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
413+
let mask_minp = probs.broadcast_gt(&threshold_min)?;
414+
probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
415+
}
416+
417+
let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
418+
419+
// Extract the top‑n log‑probs if the caller asked for them.
420+
let (top_logprobs, logprob) = if return_logprobs {
421+
let k = self.top_n_logprobs;
422+
423+
let sorted_values = probs.fast_sort_asc(D::Minus1)?;
424+
let topk_values = sorted_values
425+
.narrow(
426+
D::Minus1,
427+
sorted_values.dim(D::Minus1)? - top_k as usize,
428+
top_k as usize,
429+
)?
430+
.to_vec1::<f32>()?;
431+
432+
let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
433+
let topk_idxs = sorted_idxs
434+
.narrow(
435+
D::Minus1,
436+
sorted_values.dim(D::Minus1)? - top_k as usize,
437+
top_k as usize,
438+
)?
439+
.to_vec1::<u32>()?;
440+
441+
let mut result = Vec::with_capacity(k);
442+
if let Some(tokenizer) = &self.tokenizer {
443+
for (prob, token) in topk_values.iter().zip(topk_idxs) {
444+
let decoded = tokenizer
445+
.decode(&[token], false)
446+
.map_err(|e| Error::Msg(e.to_string()))?;
447+
result.push(TopLogprob {
448+
token,
449+
logprob: prob.log(10.0),
450+
bytes: Some(decoded),
451+
});
452+
}
453+
} else {
454+
for (prob, token) in topk_values.iter().zip(topk_idxs) {
455+
result.push(TopLogprob {
456+
token,
457+
logprob: prob.log(10.0),
458+
bytes: None,
459+
});
460+
}
461+
}
462+
463+
let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
464+
465+
(Some(result), logprob)
466+
} else {
467+
(None, 1.)
468+
};
469+
470+
let bytes = if let Some(tokenizer) = &self.tokenizer {
471+
Some(
472+
tokenizer
473+
.decode(&[next_token], false)
474+
.map_err(|x| Error::Msg(x.to_string()))?,
475+
)
476+
} else {
477+
None
478+
};
479+
480+
Ok(Logprobs {
481+
token: next_token,
482+
logprob,
483+
top_logprobs,
484+
bytes,
485+
})
486+
}
332487
fn sample_speculative_top_kp_min_p(
333488
&self,
334489
logits: Tensor,
@@ -623,6 +778,7 @@ impl Sampler {
623778
Ok(())
624779
}
625780

781+
#[allow(unused)]
626782
/// Sample the provided tokens.
627783
///
628784
/// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
@@ -635,6 +791,16 @@ impl Sampler {
635791
rng: Arc<Mutex<Isaac64Rng>>,
636792
sample_speculative: bool,
637793
) -> Result<Logprobs> {
794+
#[cfg(feature = "metal")]
795+
return self.sample_fast(
796+
logits,
797+
context,
798+
return_logprobs,
799+
self.top_k,
800+
self.top_p,
801+
self.min_p,
802+
);
803+
638804
let logits = logits.to_vec1()?;
639805
let mut logits = self.apply_penalties(logits, context)?;
640806
for processor in &self.logits_processors {

mistralrs-quant/build.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,23 @@ fn main() -> Result<(), String> {
155155
use std::process::Command;
156156
use std::{env, str};
157157

158-
const METAL_SOURCES: [&str; 5] = [
158+
const METAL_SOURCES: [&str; 8] = [
159159
"bitwise",
160160
"blockwise_fp8",
161161
"bnb_dequantize",
162162
"hqq_dequantize",
163163
"quantized",
164+
"scan",
165+
"sort",
166+
"copy",
164167
];
168+
const HEADER_SOURCES: [&str; 5] = ["utils", "bf16", "scan_impl", "sort_impl", "copy_impl"];
165169
for src in METAL_SOURCES {
166170
println!("cargo::rerun-if-changed=src/metal_kernels/{src}.metal");
167171
}
168-
println!("cargo::rerun-if-changed=src/metal_kernels/utils.metal");
172+
for src in HEADER_SOURCES {
173+
println!("cargo::rerun-if-changed=src/metal_kernels/{src}.metal");
174+
}
169175
println!("cargo::rerun-if-changed=build.rs");
170176

171177
enum Platform {
@@ -203,7 +209,9 @@ fn main() -> Result<(), String> {
203209
for metal_file in METAL_SOURCES {
204210
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
205211
}
206-
compile_air_cmd.arg(sources.join("utils.metal"));
212+
for metal_file in HEADER_SOURCES {
213+
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
214+
}
207215
compile_air_cmd
208216
.spawn()
209217
.expect("Failed to compile air")
@@ -247,7 +255,9 @@ fn main() -> Result<(), String> {
247255
for metal_file in METAL_SOURCES {
248256
compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
249257
}
250-
compile_metallib_cmd.arg(out_dir.join("utils.air"));
258+
for metal_file in HEADER_SOURCES {
259+
compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
260+
}
251261

252262
let mut child = compile_metallib_cmd
253263
.spawn()

mistralrs-quant/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub use lora::{
5858
};
5959
pub use unquantized::UnquantLinear;
6060
pub use utils::isq::apply_immediate_isq;
61-
pub use utils::{log, BitWiseOp, LeftshiftOp, NonZeroOp, UQFF_QUANT_TYPE_OFFSET};
61+
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
6262

6363
use candle_nn::{Linear, Module};
6464
use serde::{Deserialize, Deserializer, Serialize};

0 commit comments

Comments
 (0)