Fast sampler #1327

Conversation
Walkthrough

This update introduces extensive improvements to the quantization and Metal backend infrastructure, including new GPU kernels for scan, sort, and copy operations, and exposes new fast tensor-based sorting and cumulative sum operations at the Rust level. The model initialization logic is refactored to propagate quantization configuration to language model head layers, and serde aliasing is added for quantization fields across numerous model configs. Scheduler interfaces are updated to integrate interval logging of running and waiting sequences. Additional code style and refactoring changes improve clarity, consistency, and maintainability throughout the codebase.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Scheduler
    participant Logger
    participant Engine
    Engine->>Scheduler: schedule(&Logger)
    Scheduler->>Logger: set_num_running(running_count)
    Scheduler->>Logger: set_num_waiting(waiting_count)
    Scheduler-->>Engine: SchedulerOutput
```

```mermaid
sequenceDiagram
    participant Tensor
    participant MetalBackend
    participant User
    User->>Tensor: .fast_sort_asc(axis)
    Tensor->>MetalBackend: launch sort kernel
    MetalBackend-->>Tensor: sorted tensor
    Tensor-->>User: sorted tensor
```

```mermaid
sequenceDiagram
    participant Tensor
    participant MetalBackend
    participant User
    User->>Tensor: .fast_cumsum(axis)
    Tensor->>MetalBackend: launch scan kernel
    MetalBackend-->>Tensor: cumsum tensor
    Tensor-->>User: cumsum tensor
```
Code Metrics Report

```
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                3           62           53            0            9
 Dockerfile              1           41           22           10            9
 JSON                   12          107          106            0            1
 Makefile                1            6            5            0            1
 Python                 86         4042         3413          156          473
 Shell                   1           63           26           18           19
 Plain Text              3         3723            0         2413         1310
 TOML                   19          565          518            6           41
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       3            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          205          178            1           26
 (Total)                            282          210           32           40
-------------------------------------------------------------------------------
 Markdown               55         5012            0         3822         1190
 |- BASH                 8          104          101            0            3
 |- JSON                 1           12           12            0            0
 |- Python               7          121          109            0           12
 |- Rust                22          757          634            1          122
 |- TOML                 2           75           63            0           12
 (Total)                           6081          919         3823         1339
-------------------------------------------------------------------------------
 Rust                  378       126908       113288         2593        11027
 |- Markdown           171         2145           29         1913          203
 (Total)                         129053       113317         4506        11230
===============================================================================
 Total                 564       140550       117450         9020        14080
===============================================================================
```
Actionable comments posted: 17
🧹 Nitpick comments (11)
mistralrs-server/src/interactive_mode.rs (1)

242-245: Start the TTFT timer before `send` for accuracy (optional). The timer is initialised after the request is sent. If queueing on the Tokio channel becomes significant, TTFT will be underestimated. Consider moving `let start_ttft = Instant::now();` immediately before `sender.send(req).await?;`.
mistralrs-quant/src/metal_kernels/mod.rs (1)

1176-1199: Unused variable `_bm` and potential overflow in grid calc. `let _bm = 32;` is never used – remove or employ it. Additionally, multiplying `tmp_grid_dims.{width|height}` by `stride_blocks` can overflow `u64` for very large tensors. Use `checked_mul` and fall back to the alternate dimension on overflow.
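A minimal sketch of the overflow-safe variant (the function and field names are illustrative; the module's real grid type is Metal's `MTLSize`, and the fallback strategy is an assumption):

```rust
/// Scale the dispatch grid by `stride_blocks` without silently wrapping.
/// If widening `width` would overflow `u64`, fold the factor into `height`.
fn scale_grid(width: u64, height: u64, stride_blocks: u64) -> (u64, u64) {
    match width.checked_mul(stride_blocks) {
        Some(w) => (w, height),
        None => (
            width,
            height
                .checked_mul(stride_blocks)
                .expect("grid dimensions overflow u64 in both directions"),
        ),
    }
}
```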
mistralrs-core/src/sampler.rs (1)

353-366: Top-K threshold broadcasting may mis-align for batched logits. `unsqueeze(0)` produces a shape `[1, vocab]`. For batched inputs `[batch, vocab]` this works, but for a higher-rank tensor the broadcast is ambiguous. Use `unsqueeze(D::Minus1)` (or keep the original dim count via `expand`) to guarantee alignment with the last dimension irrespective of rank.
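A minimal sketch of the last-axis alignment, assuming the candle-style tensor API (`unsqueeze`, `broadcast_ge`, `where_cond`) used elsewhere in the sampler; the helper and its names are illustrative only:

```rust
use candle_core::{Result, Tensor, D};

/// Keep the top-k threshold aligned with the *last* dimension regardless of
/// the logits' rank. `topk_threshold` is a hypothetical per-row threshold of
/// shape `[..batch_dims]` before the unsqueeze.
fn mask_below_topk(logits: &Tensor, topk_threshold: &Tensor) -> Result<Tensor> {
    // Unsqueeze on the trailing axis so `[batch]` becomes `[batch, 1]`, which
    // broadcasts cleanly against `[batch, vocab]` (or higher ranks).
    let threshold = topk_threshold.unsqueeze(D::Minus1)?;
    let keep = logits.broadcast_ge(&threshold)?;
    // Zero out everything below the threshold.
    keep.where_cond(logits, &logits.zeros_like()?)
}
```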
mistralrs-quant/src/utils/ops.rs (2)

1011-1020: Wrong error message and unreachable match arm. The fallback arm returns `Err(Error::UnsupportedDTypeForOp(DType::F32, "cumsum"))` even for many other dtypes, making debugging harder. Replace with the actual `s1.dtype()`:

```diff
- _ => Err(Error::UnsupportedDTypeForOp(DType::F32, "cumsum")),
+ _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
```
904-909: Naming & visibility. `CumSum` is private but the trait `CumSumOp` relies on users calling `fast_cumsum[_config]`. Consider making the struct `pub(crate)` and prefixing internal helpers with `_` to clarify intent.
mistralrs-quant/src/metal_kernels/bf16.metal (1)

10-13: Duplicate typedef of `bfloat16_t`. Lines 10 and 12 both declare `typedef bfloat bfloat16_t;`. Remove the second to avoid "redefinition" warnings.
mistralrs-quant/src/metal_kernels/scan.metal (1)

52-64: Macro explosion may hit Metal compiler limits. The 8×9 explicit `instantiate_scan_helper(...)` calls generate 70+ specialized kernels per op (sum/prod/…) and per layout. Several Apple-silicon tool-chains start failing once the number of functions in a single `.metal` file approaches ~1 000 because the internal IR hits the 64 Ki symbol-table limit. Consider splitting large groups into separate translation units, or gating instantiation with `#ifdef`s that match the actual runtime usage set, to keep compile time and binary size reasonable.
mistralrs-core/src/prefix_cacher.rs (1)

110-152: Potentially expensive full-vector clone in hot path. `seq.normal_cache().to_vec()` and `seq.image_hashes().map(|v| v.to_vec())` allocate and copy on every call even when the cache entry for the sequence already exists. You can avoid the allocation for updates by checking `if nb.cache.is_none()` before cloning:

```rust
let (data, img_hashes) = if nb.cache.is_none() {
    (seq.normal_cache().to_vec(), seq.image_hashes().map(|v| v.to_vec()))
} else {
    (Vec::new(), None) // nothing needed – will be overwritten below
};
```
mistralrs-quant/src/metal_kernels/utils.metal (2)

12-16: Duplicate typedef – remove the extra declaration. `typedef bfloat bfloat16_t;` appears twice back-to-back, which may trigger a "redefinition" warning on stricter Metal compilers.

```diff
-typedef bfloat bfloat16_t;
-typedef bfloat bfloat16_t;
+typedef bfloat bfloat16_t; // single definition is enough
```
1151-1163: Logical functors return `T` but evaluate to `bool`. `LogicalAnd`/`LogicalOr` compute `x && y`/`x || y` yet return a value of type `T`. For non-boolean `T` (e.g. `float`, `int`, `half`) this relies on implicit conversion from `bool` and silently narrows the result to `0` or `1`. Returning `bool` clarifies semantics and avoids accidental use in arithmetic code where a full-precision `T` is expected.

```diff
-template <typename T> T operator()(T x, T y) { return x && y; }
+template <typename T> bool operator()(T x, T y) { return x && y; }
```

Apply similarly to `LogicalOr`.
mistralrs-quant/src/metal_kernels/scan_impl.metal (1)

121-140: `CumLogaddexp` needs numerical stabilisation. `LogAddExp` for large negative inputs can underflow to `-∞`, breaking the scan when the tensor contains a wide dynamic range. Consider the classic stable implementation:

```cpp
U operator()(U a, T b) {
  U m = max(a, static_cast<U>(b));
  return m + log(exp(a - m) + exp(static_cast<U>(b) - m));
}
```

If `LogAddExp{}` already implements this, document it; otherwise replace the call with a numerically stable version.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
Cargo.lock is excluded by `!**/*.lock`
📒 Files selected for processing (16)

- mistralrs-core/Cargo.toml (1 hunks)
- mistralrs-core/src/pipeline/mod.rs (1 hunks)
- mistralrs-core/src/prefix_cacher.rs (2 hunks)
- mistralrs-core/src/sampler.rs (4 hunks)
- mistralrs-quant/build.rs (3 hunks)
- mistralrs-quant/src/lib.rs (1 hunks)
- mistralrs-quant/src/metal_kernels/bf16.metal (1 hunks)
- mistralrs-quant/src/metal_kernels/mod.rs (2 hunks)
- mistralrs-quant/src/metal_kernels/quantized.metal (0 hunks)
- mistralrs-quant/src/metal_kernels/scan.metal (1 hunks)
- mistralrs-quant/src/metal_kernels/scan_impl.metal (1 hunks)
- mistralrs-quant/src/metal_kernels/utils.metal (2 hunks)
- mistralrs-quant/src/metal_kernels/utils.rs (1 hunks)
- mistralrs-quant/src/utils/mod.rs (1 hunks)
- mistralrs-quant/src/utils/ops.rs (2 hunks)
- mistralrs-server/src/interactive_mode.rs (6 hunks)
💤 Files with no reviewable changes (1)
- mistralrs-quant/src/metal_kernels/quantized.metal
⏰ Context from checks skipped due to timeout of 90000ms (8)
- GitHub Check: Test Suite (macOS-latest, stable)
- GitHub Check: Docs
- GitHub Check: Test Suite (ubuntu-latest, stable)
- GitHub Check: Test Suite (windows-latest, stable)
- GitHub Check: Check (macOS-latest, stable)
- GitHub Check: Clippy
- GitHub Check: Check (windows-latest, stable)
- GitHub Check: Check (ubuntu-latest, stable)
🔇 Additional comments (7)
mistralrs-core/Cargo.toml (1)
92-94: Good additions for the new radix trie implementation. The three new dependencies added (`hashbrown`, `parking_lot`, and `ahash`) are appropriate choices for implementing the concurrent radix trie with LRU eviction mechanism. These libraries are well-maintained, provide good performance characteristics, and are commonly used together for high-performance concurrent data structures.
mistralrs-quant/src/utils/mod.rs (1)

8-8: Appropriate public export for the new CumSumOp trait. The `CumSumOp` trait is correctly added to the list of publicly exposed operations from the module. This change aligns with the module's existing pattern of re-exporting operation traits.
mistralrs-quant/src/lib.rs (1)

59-59: Consistent crate-level re-export of the new operation. The `CumSumOp` trait is properly added to the crate-level public exports alongside other operation traits, making it accessible to users of the library without requiring them to import from specific submodules.
mistralrs-core/src/pipeline/mod.rs (1)

423-427: ❓ Verification inconclusive

Simplified logits processing by removing explicit CPU device conversion. The code has been simplified by removing the `.to_device(&Device::Cpu)` call and the associated error handling, which may improve performance by avoiding unnecessary device transfers. However, there's an inconsistency since the similar code in the PagedAttention branch (lines 564-570) still performs this conversion. Please verify that operations downstream don't require the tensors to be on the CPU device. The inconsistency between this branch and the PagedAttention branch suggests there might be specific scenarios where device placement matters.
🏁 Script executed:
#!/bin/bash # Find where the logits are used in sample_causal_gen to check if they require CPU device rg -p "fn sample_causal_gen.*?\{\s*.*?\}" --multiline mistralrs-core/src/Length of output: 75
I’ll inspect the surrounding code to locate both branches and any remaining
.to_device(&Device::Cpu)
calls.#!/bin/bash set -euxo pipefail # 1. Show the logits processing snippet (around lines 423–427) sed -n '350,450p' mistralrs-core/src/pipeline/mod.rs # 2. Show the PagedAttention branch (around lines 564–570) sed -n '550,650p' mistralrs-core/src/pipeline/mod.rs # 3. Find all remaining `.to_device(&Device::Cpu)` occurrences rg -n ".to_device\s*\(\s*&Device::Cpu\s*\)" -n mistralrs-core/src/pipeline/mod.rs
To pinpoint the discrepancy and verify if CPU placement is still used elsewhere, let’s dump both branches and search for any remaining
.to_device(&Device::Cpu)
calls:#!/bin/bash set -euxo pipefail # 1. Show logits processing around lines 420–430 sed -n '410,440p' mistralrs-core/src/pipeline/mod.rs # 2. Show PagedAttention branch around lines 560–580 sed -n '550,590p' mistralrs-core/src/pipeline/mod.rs # 3. List all occurrences of to_device(&Device::Cpu) rg -n "to_device\s*\(\s*&Device::Cpu\s*\)" mistralrs-core/src/pipeline/mod.rs || true
🏁 Script executed:
#!/usr/bin/env bash set -euxo pipefail # 1. Show logits processing around lines 420–430 sed -n '410,440p' mistralrs-core/src/pipeline/mod.rs # 2. Show PagedAttention branch around lines 560–580 sed -n '550,590p' mistralrs-core/src/pipeline/mod.rs # 3. List all occurrences of to_device(&Device::Cpu) rg -n "to_device\s*\(\s*&Device::Cpu\s*\)" mistralrs-core/src/pipeline/mod.rs || trueLength of output: 3679
Inconsistent `.to_device(&Device::Cpu)` usage in logits processing

The normal CausalGeneration path (around lines 423–427) now drops the `.to_device(&Device::Cpu)` call, but the PagedAttention branch (line 568) still performs it, and there are other conversions at lines 384 and 540. Please verify whether downstream routines (e.g. `sample_causal_gen`) require CPU-resident tensors. Then unify both branches:

- Remove `.to_device(&Device::Cpu)` in the PagedAttention branch (and at lines 384/540) if GPU tensors are acceptable
- Otherwise, re-introduce the CPU conversion in the simplified branch

Files/locations to review:
- mistralrs-core/src/pipeline/mod.rs: lines 423–427
- mistralrs-core/src/pipeline/mod.rs: line 568
- mistralrs-core/src/pipeline/mod.rs: lines 384, 540
mistralrs-quant/build.rs (1)

170-174: Compiling header files may duplicate symbols. Header sources are now passed to the Metal compiler as standalone translation units:

```rust
for src in HEADER_SOURCES {
    println!("cargo:rerun-if-changed=src/metal_kernels/{src}.metal");
}
```

and later:

```rust
for metal_file in HEADER_SOURCES {
    compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
}
```

If these headers define templated kernels or helper functions without `static` linkage they may be emitted twice, producing duplicate-symbol errors at link time. Consider adding them with `-I`/`-include` instead of compiling them directly.
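A hedged sketch of the include-path approach; the `xcrun metal` invocation and the `kernel_sources` list are assumptions about how the build script shells out, not its actual code:

```rust
use std::{path::Path, process::Command};

/// Pass the kernel directory as an include path (`-I`) so headers are found
/// via `#include`, and compile only the real kernel sources to .air.
fn compile_air(sources: &Path, kernel_sources: &[&str]) -> std::io::Result<()> {
    let mut cmd = Command::new("xcrun");
    cmd.args(["metal", "-c"]).arg("-I").arg(sources);
    for metal_file in kernel_sources {
        cmd.arg(sources.join(format!("{metal_file}.metal")));
    }
    cmd.status().map(|_| ())
}
```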
mistralrs-quant/src/metal_kernels/bf16.metal (1)

47-56: Heavy use of C++17 type-traits may not compile under Metal. Metal shading language roughly aligns with C++14; helpers such as `is_convertible_v`, `is_same_v` and variable templates are C++17. Unless you compile with `-std=c++17-metal` (not universally available), this section will fail. If wide portability is desired, fall back to `std::is_same<…, bool>::value` + `enable_if_t`.
mistralrs-core/src/prefix_cacher.rs (1)

155-182: `evict_caches` can dereference a `None` head – guard early. If `current_on_device > n_on_device` gets out of sync (e.g. manual `current_on_device` mutation or future parallelisation), `lru_head` might be `None`, leading to an infinite loop (the `while` condition never falsifies and `head` is `None`). Add an assert or early-return before the loop:

```rust
if self.lru_head.is_none() {
    debug_assert_eq!(self.current_on_device, 0, "Inconsistent LRU list");
    break;
}
```
```rust
if let Some(ttft) = first_token_duration {
    println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
}
```
Invalid format string – code will not compile
println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
`{:?}` is the Debug formatter, which cannot be combined with a precision specifier (`.2`).
The compiler will emit:
format argument must not be followed by `?`
Replace the debug specifier with the normal Display formatter:
-println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
+println!("Time to first token: {:.2}s", ttft.as_secs_f32());
Apply the same change in the vision branch (lines 550-552).
```rust
let size = shape[axis];
let stride = strides[size];
let _bm = 32;
let bn = 32;
let stride_blocks = stride.div_ceil(bn);

encoder.set_buffer(0, Some(xs), xs_offset as u64);
encoder.set_buffer(1, Some(output), 0);

encoder.set_bytes(
    2,
    std::mem::size_of::<usize>() as u64,
    &size as *const usize as *const _,
);
encoder.set_bytes(
```
Array-index bug: using `size` instead of `axis`

`let stride = strides[size];`

`size` is the length of the scan dimension, not its index. Accessing `strides[size]` will:

- Panic with `index out of bounds` when `size ≥ strides.len()`.
- Read an unrelated stride when the scan dimension length happens to be a valid index.
Fix:
- let stride = strides[size];
+ let stride = strides[axis];
Add a unit test with a non-trivial stride to catch this class of error.
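A hedged sketch of such a test, assuming the `CumSumOp::fast_cumsum(axis)` surface described elsewhere in this review and an available Metal device; the expected values are for a 2×3 row-major tensor scanned over axis 0, where the axis stride is non-unit:

```rust
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};
    use mistralrs_quant::CumSumOp; // trait exposing `fast_cumsum`, per this PR

    // A scan over axis 0 of a (2, 3) tensor walks memory with stride 3, so an
    // indexing bug like `strides[size]` surfaces as a wrong result or a panic.
    #[test]
    fn cumsum_over_strided_axis() -> candle_core::Result<()> {
        let dev = Device::new_metal(0)?;
        let t = Tensor::arange(0f32, 6f32, &dev)?.reshape((2, 3))?;
        let got = t.fast_cumsum(0)?.to_vec2::<f32>()?;
        assert_eq!(got, vec![vec![0., 1., 2.], vec![3., 5., 7.]]);
        Ok(())
    }
}
```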
```rust
if divisor > 1 {
    if grid_x % divisor == 0 {
        grid_x /= divisor;
        divisor = 1;
    } else if grid_y % divisor == 0 {
        grid_y /= divisor;
        divisor = 1;
    }
}
```
🛠️ Refactor suggestion
Remaining divisor not validated

After the loop `divisor` may still be `> 1`, meaning the requested factorisation failed silently. This produces wrong grid sizes and hard-to-trace GPU crashes.
if divisor > 1 {
if grid_x % divisor == 0 {
grid_x /= divisor;
divisor = 1;
} else if grid_y % divisor == 0 {
grid_y /= divisor;
divisor = 1;
}
}
+
+ // Ensure the divisor was fully factored out
+ assert_eq!(divisor, 1, "Unable to factor divisor from shape");
```rust
}

let type_name = match ty {
    DType::F32 => "float32",
    DType::BF16 => "bfloat16",
    DType::F16 => "float16",
    DType::U8 => "uint8",
    DType::I16 => "int16",
    DType::I32 => "int32",
    DType::I64 => "int64",
    other => {
        return Err(MetalKernelError::DTypeMismatch {
            expected: vec![
                DType::F32,
                DType::F16,
                DType::BF16,
                DType::U8,
                DType::I16,
                DType::I32,
                DType::I64,
            ],
            got: other,
        })
    }
};
name.push_str(&format!("{type_name}_{type_name}"));
```
🛠️ Refactor suggestion
`LogAddExp` on integer dtypes is undefined

`ScanType::LogAddExp` is appended without restricting to floating-point types. Metal kernels will likely perform illegal operations on `u8`/`i16`/`i32`/`i64`.
Guard the match:
if matches!(op, ScanType::LogAddExp) && !ty.is_float() {
return Err(MetalKernelError::DTypeMismatch { … });
}
or restrict `expected` accordingly.
```cpp
static constexpr constant U init = static_cast<U>(1.0f);

template <typename T> U operator()(U a, T b) { return a * b; }

U simd_scan_impl(U x) { return simd_prefix_inclusive_product(x); }
```
🛠️ Refactor suggestion
`CumProd::init` uses a float literal – breaks for integral/half types

`static_cast<U>(1.0f)` is safe only when `U` is a floating-point type. For `int`, `uint`, `char`, `half`, etc., this introduces an implicit float-to-int conversion which is (a) UB for `bool`, (b) a silent narrowing conversion for integers, and (c) breaks constexpr evaluation on half types prior to macOS 14.
- static constexpr constant U init = static_cast<U>(1.0f);
+ static constexpr constant U init = static_cast<U>(1);
```cpp
template <typename U> struct CumMax {
  static constexpr constant U init = Limits<U>::min;

  template <typename T> U operator()(U a, T b) { return (a >= b) ? a : b; }

  U simd_scan(U x) {
```
Initial value for `CumMax` is wrong for floating point – should be `lowest()`, not `min()`

`Limits<U>::min` maps to the smallest positive value for floats (`FLT_MIN`), not the most negative. Using it makes every max scan return positive numbers even when the input contains negatives.
- static constexpr constant U init = Limits<U>::min;
+ static constexpr constant U init = Limits<U>::lowest; // requires Limits helper, or use numeric_limits<U>::lowest()
Do the same for `CumMin` with `Limits<U>::max` → `Limits<U>::highest`.
```cpp
if ((offset + N_READS) < axis_size) {
  load_unsafe<T, U, N_READS, reverse>(values,
                                      in + axis_size - offset - N_READS);
} else {
  load_safe<T, U, N_READS, reverse>(values,
                                    in + axis_size - offset - N_READS,
                                    offset, axis_size, Op::init);
}
} else {
```
Possible pointer underflow for reverse loads

When `offset >= axis_size`, the expression `in + axis_size - offset - N_READS` wraps around (unsigned underflow) and produces a huge positive address → undefined behaviour / GPU fault. The guard `if ((offset + N_READS) < axis_size)` is not sufficient once `offset == axis_size - N_READS`, because the `else` branch still applies the same arithmetic.
Fix by clamping before pointer math:
- load_safe<T, U, N_READS, reverse>(values,
- in + axis_size - offset - N_READS,
+ size_t base = (offset + N_READS > axis_size)
+ ? 0
+ : axis_size - offset - N_READS;
+ load_safe<T, U, N_READS, reverse>(values,
+ in + base,
offset, axis_size, Op::init);
```cpp
threadgroup U simdgroup_sums[32];
```
`simdgroup_sums[32]` can overflow when `lsize.x > 1024`

`simd_groups = lsize.x / 32`. Metal allows up to 1024 threads per thread-group on recent GPUs, but future devices (and Apple Silicon's M-series in compute mode) may expose 1536. If `lsize.x == 1536`, `simd_groups == 48` → out-of-bounds write.
- threadgroup U simdgroup_sums[32];
+ threadgroup U simdgroup_sums[simd_groups];
(or `threadgroup_array<U, MAX_SIMD_GROUPS>` with a static assert).
```cpp
constexpr int BN = 32;
constexpr int BN_pad = 32 + 16 / sizeof(U);
constexpr int n_simds = BN / N_READS;
```
🛠️ Refactor suggestion
`BN_pad` mis-aligns shared memory for 64-bit types

`BN_pad = 32 + 16 / sizeof(U)` yields 34 for 64-bit data, which breaks the 128-bit alignment required for `double`/`ulong` vectorised accesses and causes bank conflicts on A-series GPUs.
- constexpr int BN_pad = 32 + 16 / sizeof(U);
+ constexpr int BN_pad = 32 + (16 / sizeof(U)); // keep, *then*
+ static_assert((BN_pad % (16 / sizeof(U))) == 0, "BN_pad must align to 128-bit");
Alternatively compute `round_up(BN, 16/sizeof(U))`.
089b617 to 43dedaa (Compare)

- refactor: use Option refs in model path helpers
- Format
- Serde alias
- Fix
- Update for tie_word_embeddings
- Print running/waiting
- 30 users
- Update num_users
- Update dummy paged attn
Actionable comments posted: 10
♻️ Duplicate comments (13)
mistralrs-core/src/sampler.rs (3)
417-418: ⚠️ Potential issue

`sample_fast` drops probabilistic behaviour. The implementation always takes `argmax` of the filtered distribution; no weighted/multinomial sampling is performed. This makes the sampling deterministic even with non-zero temperature, which is inconsistent with the standard sampling behavior where temperature controls the randomness. Consider reinstating true multinomial sampling on the masked distribution:

```diff
- let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
+ // Sample from the distribution if temperature is set, otherwise use argmax
+ let next_token = if self.temperature.is_some() {
+     // Convert to CPU for sampling with WeightedIndex
+     let probs_vec: Vec<f32> = probs.to_vec1()?;
+     let distr = WeightedIndex::new(&probs_vec)?;
+     let mut rng_lock = rng.lock().expect("could not lock rng mutex");
+     distr.sample(&mut *rng_lock) as u32
+ } else {
+     probs.argmax(D::Minus1)?.to_scalar::<u32>()?
+ };
```
464-468: ⚠️ Potential issue

Incorrect log probability calculation. The implementation has two issues with log probabilities:

- When `return_logprobs` is false, it hardcodes logprob to 1.0 (line 467), which gives log(prob) = 0, breaking accuracy.
- When `return_logprobs` is true, it doesn't compute the actual log probability of the selected token.

This will cause incorrect log probability values to be returned, affecting applications that rely on accurate probability tracking. Apply this fix to correctly compute the log probability of the selected token:

```diff
- let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
+ // Find the actual logprob of the selected token
+ let logprob = result
+     .iter()
+     .find(|res| res.token == next_token)
+     .map(|res| res.logprob)
+     .unwrap_or_else(|| {
+         // If token not in top-k, compute its logprob directly
+         probs.i(next_token as i64)?.to_scalar::<f32>()?.log10()
+     });

- } else {
-     (None, 1.)
+ } else {
+     // Always compute actual logprob even when not returning top logprobs
+     let prob = probs.i(next_token as i64)?.to_scalar::<f32>()?;
+     (None, prob.log10())
```
389-390: 🛠️ Refactor suggestion

Missing probability re-normalization after masking. The implementation masks probabilities below thresholds to zero but doesn't re-normalize the resulting distribution. This is fine for argmax but would be incorrect for multinomial sampling, as sampling requires a properly normalized probability distribution. Add re-normalization after each masking operation:

```diff
  probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
+ // Re-normalize probabilities to sum to 1
+ let probs_sum = probs.sum_all(true)?;
+ probs = (probs / probs_sum)?;
  // Similar changes needed after Top-P and Min-P masking
```

Also applies to: 406-407, 414-415
mistralrs-quant/src/metal_kernels/utils.rs (1)
59-67
:⚠️ Potential issueRemaining divisor not validated
After the loop
divisor
may still be> 1
, meaning the requested factorization failed silently. This produces wrong grid sizes and hard-to-trace GPU crashes.if divisor > 1 { if grid_x % divisor == 0 { grid_x /= divisor; divisor = 1; } else if grid_y % divisor == 0 { grid_y /= divisor; divisor = 1; } } + + // Ensure the divisor was fully factored out + assert_eq!(divisor, 1, "Unable to factor divisor from shape");mistralrs-quant/src/metal_kernels/mod.rs (2)
1113-1120
:LogAddExp
accepts integer dtypes – undefined behaviour
ScanType::LogAddExp
is currently allowed for everyDType
in thematch
below.
The Metal implementation relies on floating-point math; passingu8/i32/…
will generate incorrect results or trigger invalid-op exceptions on GPU.Guard the call before building the kernel name:
if matches!(op, ScanType::LogAddExp) && !ty.is_float() { return Err(MetalKernelError::DTypeMismatch { expected: vec![DType::F32, DType::F16, DType::BF16], got: ty, }); }(or restrict the
type_name
match arm to float dtypes only).
[ suggest_essential_refactor ]Also applies to: 1121-1142
1189-1194
:⚠️ Potential issueStride-lookup still indexes with length instead of axis
stride = strides[size];
repeats the off-by-one/ OOB bug previously raised –size
is the length of the scan-axis, not its position.
This will (a) read the wrong stride for most tensors and (b) panic whensize ≥ strides.len()
.- let stride = strides[size]; + let stride = strides[axis];Please add a regression test with a non-unit stride to prevent future re-introductions.
mistralrs-quant/src/utils/ops.rs (1)
1322-1336
:⚠️ Potential issueCPU
CumSum
still assumes the scan axis is the innermost – same bug as flagged earlierThe inner loops step through the buffer with
let base = block * axis_len;
and then indexinput[base + j]
.
This is only correct when the axis has a stride of 1 (i.e. it is the last / innermost dimension).
For any other contiguous tensor (e.g. shape(4, 5)
andaxis = 0
), elements that belong to different columns are interleaved in memory, so the current algorithm mixes rows and produces wrong results.The earlier review already highlighted this; the implementation has not changed.
Suggested fixes:
- Compute the true stride of the axis from
l1.stride()[axis]
and use it when walking the tensor, or- Call
.contiguous()
/.permute()
to move the axis to the last position before performing the scan.Please add a regression test on a 2-D tensor with
axis = 0
andaxis = 1
to verify correctness.mistralrs-quant/src/metal_kernels/scan_impl.metal (5)
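For illustration, a stride-aware walk over an arbitrary axis might look like the sketch below (plain Rust over a flat buffer; names and the in-place output layout are assumptions, not the crate's actual `CumSum` code):

```rust
/// Inclusive cumulative sum along `axis` of a contiguous row-major buffer.
/// `shape`/`strides` describe the tensor layout; each "line" along the axis
/// is scanned with its true stride instead of assuming stride 1.
fn cumsum_along_axis(input: &[f32], out: &mut [f32], shape: &[usize], strides: &[usize], axis: usize) {
    let axis_len = shape[axis];
    let axis_stride = strides[axis];
    let total: usize = shape.iter().product();
    let lines = total / axis_len;
    for line in 0..lines {
        // Decompose `line` over the non-axis dimensions to find the base offset.
        let mut rem = line;
        let mut base = 0;
        for (d, (&len, &stride)) in shape.iter().zip(strides).enumerate().rev() {
            if d == axis {
                continue;
            }
            base += (rem % len) * stride;
            rem /= len;
        }
        // Scan along the chosen axis using its true stride.
        let mut acc = 0f32;
        for j in 0..axis_len {
            acc += input[base + j * axis_stride];
            out[base + j * axis_stride] = acc;
        }
    }
}
```

On a (4, 5) tensor this produces column-wise sums for `axis = 0` and row-wise sums for `axis = 1`, which is exactly the case the requested regression test should cover.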
53-55
:⚠️ Potential issue
CumProd::init
uses a float literal – breaks for integral/half types
static_cast<U>(1.0f)
converts through float, which is a narrowing / UB for some integer and half types and preventsconstexpr
evaluation.
Please use an integer literal instead:- static constexpr constant U init = static_cast<U>(1.0f); + static constexpr constant U init = static_cast<U>(1);
84-86
:⚠️ Potential issueInitial value for
CumMax
/CumMin
is wrong for floats
Limits<U>::min
andLimits<U>::max
return the smallest positive and largest positive values for floating-point types.
Uselowest()
/highest()
(ornumeric_limits<U>::lowest()
) to obtain the true extrema; otherwise negative inputs are mishandled.
245-253
:⚠️ Potential issuePointer underflow when
reverse == true
andoffset ≥ axis_size
in + axis_size - offset - N_READS
is evaluated before the bounds check in theelse
branch.
Whenoffset ≥ axis_size
this wraps around and produces an invalid device address → undefined behaviour / potential GPU fault.Refer to the previous review’s suggested fix that computes a clamped
base
before doing pointer arithmetic.
415-417
:⚠️ Potential issueIncorrect memory flag in
simdgroup_barrier
simdgroup_barrier
only acceptsmem_flags::mem_none
; passingmem_threadgroup
is rejected by the MSL compiler on macOS 14 and causes driver warnings on iOS 17.- simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_barrier(mem_flags::mem_none);
223-226
:⚠️ Potential issuePossible overflow:
simdgroup_sums[32]
is too small for large thread-groups
simd_groups = lsize.x / 32
, yet the scratch buffer is hard-coded to 32 entries.
If a future GPU (or a debug build) launches 1 024+ threads per thread-group, the array is written out of bounds.-threadgroup U simdgroup_sums[32]; +threadgroup U simdgroup_sums[simd_groups]; +static_assert(simd_groups <= 32, "Unexpectedly large simd_group count");(or allocate
MAX_SIMD_GROUPS
with an assertion).mistralrs-quant/src/metal_kernels/utils.metal (1)
988-994
: Bitwise '&' used instead of logical '&&' in template constraints.
enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
uses&
which happens to work but is semantically incorrect;&&
expresses intent and avoids subtle type-promotion surprises.- metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> + metal::enable_if_t<metal::is_integral_v<T> && !metal::is_signed_v<T>, T>- metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> + metal::enable_if_t<metal::is_integral_v<T> && metal::is_signed_v<T>, T>
🧹 Nitpick comments (10)
scripts/bench.py (4)
1-4: Improve script documentation with usage instructions. These commented commands provide execution examples but lack context. Consider converting them to proper docstring documentation with explanations of what each command does.
-# cargo run --release --features metal '--' --port 1234 --isq 8 --paged-attn --max-seqs 1000 plain -m ../hf_models/llama3.2_3b --max-seq-len 131072 -# cargo run --release --features metal '--' --port 1234 --paged-attn --max-seqs 1000 plain -m mlx-community/Mistral-7B-Instruct-v0.3-4bit --max-seq-len 131072 -# ./llama-server -m ../gguf_models/Llama-3.2-3B-Instruct-Q8_0.gguf -# mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit --port 8080 +""" +Benchmark script for asynchronous load testing against a local language model server. + +Example server startup commands: + +# mistral.rs with metal backend +cargo run --release --features metal '--' --port 1234 --isq 8 --paged-attn --max-seqs 1000 plain -m ../hf_models/llama3.2_3b --max-seq-len 131072 +cargo run --release --features metal '--' --port 1234 --paged-attn --max-seqs 1000 plain -m mlx-community/Mistral-7B-Instruct-v0.3-4bit --max-seq-len 131072 + +# llama.cpp server +./llama-server -m ../gguf_models/Llama-3.2-3B-Instruct-Q8_0.gguf + +# MLX server +mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit --port 8080 +"""
13-15: Make benchmark parameters configurable. Constants are hardcoded, limiting script flexibility. Consider making these configurable via command-line arguments.
-NUM_USERS = 8 -REQUESTS_PER_USER = 8 -PORT = 1234 +import argparse + +# Default configuration +DEFAULT_NUM_USERS = 8 +DEFAULT_REQUESTS_PER_USER = 8 +DEFAULT_PORT = 1234 + +# Parse command line arguments +def parse_args(): + parser = argparse.ArgumentParser(description="Benchmark a local LLM server with concurrent requests") + parser.add_argument("--users", type=int, default=DEFAULT_NUM_USERS, + help=f"Number of concurrent users (default: {DEFAULT_NUM_USERS})") + parser.add_argument("--requests", type=int, default=DEFAULT_REQUESTS_PER_USER, + help=f"Number of requests per user (default: {DEFAULT_REQUESTS_PER_USER})") + parser.add_argument("--port", type=int, default=DEFAULT_PORT, + help=f"Server port (default: {DEFAULT_PORT})") + return parser.parse_args()
42-43: Make OpenAI client configuration more flexible. The client is hardcoded with fixed values. Consider making the base URL configurable.
-# Use the async-capable client -client = AsyncOpenAI(api_key="foobar", base_url=f"http://localhost:{PORT}/v1/") +def create_client(port): + """Create an async OpenAI client configured for the local server.""" + return AsyncOpenAI( + api_key="foobar", # Dummy API key for local server + base_url=f"http://localhost:{port}/v1/" + )
81-102: Enhance main function with more detailed metrics. The current implementation provides basic metrics. Consider adding more detailed statistics like percentiles and throughput over time.
async def main() -> None: """ Computes and prints overall average request time, total requests, and average T/s. """ - system_prompt = None # "You are a helpful assistant." - user_message = "Say hello!" + args = parse_args() + system_prompt = None # "You are a helpful assistant." + user_message = "Say hello!" + client = create_client(args.port) - tasks = [user_task(client, system_prompt, user_message) for _ in range(NUM_USERS)] + print(f"Starting benchmark with {args.users} users, {args.requests} requests per user...") + start_time = time.perf_counter() + + tasks = [user_task(client, system_prompt, user_message) for _ in range(args.users)] all_results_nested = await asyncio.gather(*tasks) all_results = [item for sublist in all_results_nested for item in sublist] + + total_time = time.perf_counter() - start_time total_requests = len(all_results) - total_time = sum(elapsed for _, elapsed, _ in all_results) + request_times = [elapsed for _, elapsed, _ in all_results if _ is not None] + successful_requests = len(request_times) + avg_time = sum(request_times) / successful_requests if successful_requests else 0.0 total_tokens = sum(tokens for _, _, tokens in all_results) - avg_time = total_time / total_requests if total_requests else 0.0 - avg_tps = total_tokens / total_time if total_time > 0 else 0.0 + overall_tps = total_tokens / total_time if total_time > 0 else 0.0 + + # Calculate percentiles if we have results + if request_times: + request_times.sort() + p50 = request_times[len(request_times) // 2] + p90 = request_times[int(len(request_times) * 0.9)] + p99 = request_times[int(len(request_times) * 0.99)] + else: + p50 = p90 = p99 = 0 print(f"Total requests: {total_requests}") + print(f"Successful requests: {successful_requests}") + print(f"Success rate: {successful_requests/total_requests*100:.2f}%") print(f"Average request time: {avg_time:.2f}s") + print(f"Percentiles: p50={p50:.2f}s, p90={p90:.2f}s, p99={p99:.2f}s") print(f"Total tokens: {total_tokens}") - print(f"Average tokens per second (T/s): {avg_tps:.2f}") + print(f"Overall tokens per second: {overall_tps:.2f} T/s") + print(f"Total benchmark time: {total_time:.2f}s")mistralrs-quant/src/metal_kernels/copy.metal (1)
5-17
: Macro naming explosion – consider normalising generated kernel identifiersThe two macro layers concatenate
tname
twice (instantiate_copy_same(itname ##itname, …)
andinstantiate_copy_all(itname ##… )
).
For several invocations this produces identifiers such asgg1_copyfloat16float16
,g1_copybool_bool_
, etc.These very long (and sometimes duplicated) names are legal, but:
- They become cumbersome to read in build logs and GPU profiling tools.
- They increase the probability of hitting the 1024-byte symbol limit on some platforms.
- Future greps /
nm
searches become noisier.If keeping the verbose name is not a hard requirement, you could spare a few characters:
- instantiate_copy_same(itname ##itname, itype) + instantiate_copy_same(itname, itype)and likewise drop the double prefix in
instantiate_copy_all
.
No functional change – purely a QoL / maintainability tweak.Also applies to: 37-51
mistralrs-quant/src/metal_kernels/quantized.metal (1)
1614-1623
: Minor: remove redundantthread
qualifierDeclaring
wl_ptrs/…
asthread const device uint8_t *
is legal, butthread
is implicit for local variables; omitting it shortens the type without changing semantics:-thread const device uint8_t *wl_ptrs[results_per_simdgroup]; +const device uint8_t *wl_ptrs[results_per_simdgroup];Same for
sl_ptrs
,bl_ptrs
.mistralrs-core/src/pipeline/paths.rs (1)
511-544
: Avoid cloning large fallback strings unnecessarily
chat_template_fallback.cloned()
creates an ownedString
even though the result is only used for a read.
Borrowing is sufficient and avoids an allocation:-match chat_template_fallback.cloned() { - Some(t) => { /* uses t */ } +match chat_template_fallback { + Some(t) => { /* uses *t */ }A tiny optimisation, but worth it in hot-start scenarios where templates are large.
mistralrs-quant/src/metal_kernels/scan.metal (1)
65-77
: Inconsistent kernel namingprod_bool__bool_
The generated host-name
prod_bool__bool_
contains a double underscore and a trailing underscore, unlike the other instantiations (e.g.prod_uint8_uint8
).
If the Rust side ever tries to build the kernel name programmatically (prod_bool_bool
), it will miss this variant.Confirm that the extra underscores are intentional; otherwise rename to
prod_bool_bool
.mistralrs-quant/src/metal_kernels/sort_impl.metal (1)
26-29
:Limits<T>::max
as sentinel breaks descending inputs for floating-point typesThe sentinel
init = Limits<T>::max
isFLT_MAX
/DBL_MAX
, not+∞
.
ForLessThan
this is acceptable, but if the comparison order is ever flipped (e.g. descending sort), the sentinel will incorrectly dominate the result set.Consider using
numeric_limits<T>::infinity()
or specialising for float types.mistralrs-quant/src/metal_kernels/copy_impl.metal (1)
211-214
: Pointer arithmetic ondevice
memory defeats alias analysis
src += src_offset; dst += dst_offset;
Modifying the raw pointer hides the original base address from the compiler, restricting optimisation opportunities and sometimes violating Metal’s “do not form derived pointers” guidance.
Consider keeping the bases immutable and offsetting in the index math instead:
- src += src_offset; - dst += dst_offset; … - dst[idx.y] = src[idx.x]; + dst[dst_offset + idx.y] = src[src_offset + idx.x];This is also consistent with the fixed-stride variants above.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting
📒 Files selected for processing (65)
Cargo.toml
(1 hunks)examples/server/chat.py
(1 hunks)mistralrs-core/src/dummy_paged_attention/scheduler.rs
(5 hunks)mistralrs-core/src/engine/logger.rs
(5 hunks)mistralrs-core/src/engine/mod.rs
(2 hunks)mistralrs-core/src/kv_cache/rotating_cache.rs
(1 hunks)mistralrs-core/src/kv_cache/single_cache.rs
(1 hunks)mistralrs-core/src/models/deepseek2.rs
(2 hunks)mistralrs-core/src/models/deepseek3.rs
(2 hunks)mistralrs-core/src/models/gemma.rs
(1 hunks)mistralrs-core/src/models/gemma2.rs
(1 hunks)mistralrs-core/src/models/llama.rs
(2 hunks)mistralrs-core/src/models/mistral.rs
(2 hunks)mistralrs-core/src/models/mixtral.rs
(2 hunks)mistralrs-core/src/models/phi2.rs
(2 hunks)mistralrs-core/src/models/phi3.rs
(1 hunks)mistralrs-core/src/models/phi3_5_moe.rs
(2 hunks)mistralrs-core/src/models/qwen2.rs
(2 hunks)mistralrs-core/src/models/qwen3.rs
(2 hunks)mistralrs-core/src/models/qwen3_moe.rs
(2 hunks)mistralrs-core/src/models/starcoder2.rs
(1 hunks)mistralrs-core/src/paged_attention/scheduler.rs
(5 hunks)mistralrs-core/src/pipeline/ggml.rs
(1 hunks)mistralrs-core/src/pipeline/gguf.rs
(1 hunks)mistralrs-core/src/pipeline/macros.rs
(2 hunks)mistralrs-core/src/pipeline/mod.rs
(2 hunks)mistralrs-core/src/pipeline/normal.rs
(1 hunks)mistralrs-core/src/pipeline/paths.rs
(5 hunks)mistralrs-core/src/pipeline/vision.rs
(1 hunks)mistralrs-core/src/sampler.rs
(4 hunks)mistralrs-core/src/scheduler/default_scheduler.rs
(7 hunks)mistralrs-core/src/scheduler/mod.rs
(2 hunks)mistralrs-core/src/vision_models/gemma3/config.rs
(1 hunks)mistralrs-core/src/vision_models/gemma3/text.rs
(1 hunks)mistralrs-core/src/vision_models/llama4/config.rs
(1 hunks)mistralrs-core/src/vision_models/llama4/text.rs
(1 hunks)mistralrs-core/src/vision_models/mllama/config.rs
(1 hunks)mistralrs-core/src/vision_models/mllama/text.rs
(1 hunks)mistralrs-core/src/vision_models/phi3/mod.rs
(2 hunks)mistralrs-core/src/vision_models/phi4/config.rs
(1 hunks)mistralrs-core/src/vision_models/phi4/mod.rs
(1 hunks)mistralrs-core/src/vision_models/qwen2_5_vl/config.rs
(1 hunks)mistralrs-core/src/vision_models/qwen2_5_vl/text.rs
(1 hunks)mistralrs-core/src/vision_models/qwen2vl/config.rs
(1 hunks)mistralrs-core/src/vision_models/qwen2vl/text.rs
(1 hunks)mistralrs-quant/build.rs
(3 hunks)mistralrs-quant/kernels/marlin/marlin_kernel.cu
(47 hunks)mistralrs-quant/src/lib.rs
(1 hunks)mistralrs-quant/src/metal_kernels/bf16.metal
(1 hunks)mistralrs-quant/src/metal_kernels/bitwise.metal
(1 hunks)mistralrs-quant/src/metal_kernels/copy.metal
(1 hunks)mistralrs-quant/src/metal_kernels/copy_impl.metal
(1 hunks)mistralrs-quant/src/metal_kernels/mod.rs
(2 hunks)mistralrs-quant/src/metal_kernels/quantized.metal
(2 hunks)mistralrs-quant/src/metal_kernels/scan.metal
(1 hunks)mistralrs-quant/src/metal_kernels/scan_impl.metal
(1 hunks)mistralrs-quant/src/metal_kernels/sort.metal
(1 hunks)mistralrs-quant/src/metal_kernels/sort_impl.metal
(1 hunks)mistralrs-quant/src/metal_kernels/utils.metal
(2 hunks)mistralrs-quant/src/metal_kernels/utils.rs
(1 hunks)mistralrs-quant/src/safetensors.rs
(1 hunks)mistralrs-quant/src/utils/mod.rs
(1 hunks)mistralrs-quant/src/utils/ops.rs
(5 hunks)scripts/bench.py
(1 hunks)scripts/convert_awq_marlin.py
(8 hunks)
✅ Files skipped from review due to trivial changes (20)
- examples/server/chat.py
- mistralrs-quant/src/metal_kernels/bitwise.metal
- mistralrs-core/src/models/phi3.rs
- mistralrs-core/src/vision_models/gemma3/config.rs
- mistralrs-core/src/pipeline/vision.rs
- mistralrs-core/src/vision_models/qwen2vl/config.rs
- mistralrs-core/src/models/gemma.rs
- mistralrs-core/src/models/gemma2.rs
- mistralrs-core/src/vision_models/mllama/config.rs
- mistralrs-core/src/vision_models/llama4/config.rs
- mistralrs-core/src/models/starcoder2.rs
- mistralrs-core/src/vision_models/phi4/config.rs
- mistralrs-core/src/vision_models/qwen2_5_vl/config.rs
- Cargo.toml
- mistralrs-quant/src/lib.rs
- mistralrs-core/src/pipeline/ggml.rs
- mistralrs-core/src/pipeline/gguf.rs
- mistralrs-core/src/pipeline/normal.rs
- mistralrs-quant/kernels/marlin/marlin_kernel.cu
- mistralrs-quant/src/metal_kernels/sort.metal
🚧 Files skipped from review as they are similar to previous changes (4)
- mistralrs-core/src/pipeline/mod.rs
- mistralrs-quant/build.rs
- mistralrs-quant/src/utils/mod.rs
- mistralrs-quant/src/metal_kernels/bf16.metal
🧰 Additional context used
🧬 Code Graph Analysis (6)
mistralrs-core/src/kv_cache/rotating_cache.rs (1)
mistralrs-core/src/kv_cache/single_cache.rs (1)
all_data
(41-43)
mistralrs-core/src/dummy_paged_attention/scheduler.rs (3)
mistralrs-core/src/scheduler/mod.rs (1)
schedule
(55-55)mistralrs-core/src/scheduler/default_scheduler.rs (2)
schedule
(206-301)schedule
(311-315)mistralrs-core/src/paged_attention/scheduler.rs (2)
schedule
(67-239)schedule
(369-373)
mistralrs-core/src/pipeline/macros.rs (1)
mistralrs-core/src/pipeline/paths.rs (1)
get_xlora_paths
(55-307)
mistralrs-core/src/sampler.rs (6)
mistralrs-core/src/pipeline/mod.rs (7)
logits
(433-436)logits
(443-455)logits
(465-480)logits
(581-584)logits
(591-601)logits
(611-626)tokenizer
(153-153)mistralrs-core/src/sequence.rs (1)
return_logprobs
(785-787)mistralrs-core/src/pipeline/ggml.rs (1)
tokenizer
(500-502)mistralrs-core/src/pipeline/gguf.rs (1)
tokenizer
(661-663)mistralrs-core/src/pipeline/normal.rs (1)
tokenizer
(960-962)mistralrs-core/src/pipeline/vision.rs (1)
tokenizer
(819-821)
mistralrs-core/src/kv_cache/single_cache.rs (1)
mistralrs-core/src/kv_cache/rotating_cache.rs (1)
all_data
(47-49)
mistralrs-core/src/pipeline/paths.rs (2)
mistralrs-core/src/model_loader.rs (3)
quantized_filename
(337-340)quantized_filename
(362-365)quantized_filename
(395-398)mistralrs-core/src/toml_selector.rs (3)
quantized_filename
(726-729)quantized_filename
(753-756)quantized_filename
(786-789)
🪛 Ruff (0.11.9)
scripts/convert_awq_marlin.py
13-13: Undefined name List
(F821)
16-16: Undefined name List
(F821)
⏰ Context from checks skipped due to timeout of 90000ms (5)
- GitHub Check: Test Suite (macOS-latest, stable)
- GitHub Check: Test Suite (windows-latest, stable)
- GitHub Check: Docs
- GitHub Check: Test Suite (ubuntu-latest, stable)
- GitHub Check: Check (windows-latest, stable)
🔇 Additional comments (74)
mistralrs-core/src/kv_cache/rotating_cache.rs (1)
47-48
: Improved API ergonomics with better return typeThe change from returning
&Option<Tensor>
toOption<&Tensor>
is a significant improvement. Usingself.all_data.as_ref()
is the idiomatic Rust approach for this conversion, making the API more intuitive for callers who can now directly use the option without additional dereferencing.mistralrs-core/src/kv_cache/single_cache.rs (1)
41-42
: Improved API ergonomics with better return typeThis change correctly transforms the return type from
&Option<Tensor>
toOption<&Tensor>
, making the API more ergonomic. The implementation usingself.all_data.as_ref()
follows Rust best practices. This change maintains consistency with the parallel modification inrotating_cache.rs
.scripts/convert_awq_marlin.py (5)
79-81
: Improved function signature format.The reformatted function signature for
marlin_zero_points
is now more readable with parameters on separate lines.
102-104
: Improved function signature format.The reformatted function signature for
awq_to_marlin_zero_points
is now more readable with parameters on separate lines.
56-60
: Improved assertion message formatting.The assertion message is now properly enclosed in parentheses, which aligns with Python style best practices for multi-line statements.
198-206
: Consistent assertion message formatting.All assertion messages are now properly enclosed in parentheses, which improves code consistency and readability.
177-191
: Improved argument parser formatting.The argument parser configuration is now more consistently formatted with clear indentation, making it easier to read.
scripts/bench.py (2)
18-40
: Consider simplifying the log_response functionThe function is comprehensive but verbose. For benchmarking purposes, a more concise logging approach might be sufficient.
104-105
: 🛠️ Refactor suggestionUpdate script execution to use the new argument parsing
Update the main entry point to use the new command-line argument parsing.
if __name__ == "__main__": asyncio.run(main())Likely an incorrect or invalid review comment.
mistralrs-core/src/vision_models/qwen2vl/text.rs (1)
395-396
: Good update to use the model's quantization configurationUpdating the
lm_head
initialization to use the model's quantization configuration ensures consistent quantization handling across all model components. This change properly propagates the configuration instead of using a hardcodedNone
.mistralrs-core/src/vision_models/phi4/mod.rs (1)
412-413
: Good update to use the model's quantization configurationUpdating the
lm_head
initialization to use the model's quantization configuration ensures consistent quantization handling across all model components. This alignment with other model parts improves consistency in the quantization behavior.mistralrs-core/src/vision_models/mllama/text.rs (1)
580-581
: Good update to use the model's quantization configurationUpdating the
lm_head
initialization to use the model's quantization configuration ensures consistent quantization handling across all model components. This change correctly propagates the quantization settings to all relevant layers.mistralrs-core/src/vision_models/gemma3/text.rs (1)
494-494
: Improved quantization support for language model head.This change properly propagates the quantization configuration to the language model head, ensuring consistent quantization behavior across all model components. Previously, the language model head was always initialized with
None
for quantization, preventing it from being quantized even when other layers were.mistralrs-core/src/vision_models/llama4/text.rs (1)
614-614
: Enabled quantization for language model head.This change properly propagates the quantization configuration to the language model head, ensuring consistent quantization behavior across all model components. This aligns with similar changes across multiple model implementations in the codebase.
mistralrs-core/src/vision_models/qwen2_5_vl/text.rs (1)
399-399
: Applied consistent quantization to language model head.This change ensures the language model head properly receives the model's quantization configuration rather than hardcoding
None
. This allows the head layer to be quantized consistently with the rest of the model.mistralrs-core/src/models/phi3_5_moe.rs (2)
49-50
: Added serde alias for improved config compatibility.The serde alias
"quantization"
allows the model to be loaded from JSON configurations that use either "quantization_config" or "quantization" as the field name, providing backward compatibility with different model serialization formats.
661-661
: Propagated quantization config to language model head.This change enables proper quantization of the language model head using the same configuration as the rest of the model, ensuring consistent quantization behavior.
mistralrs-core/src/models/mixtral.rs (2)
48-49
: LGTM: Added serde alias "quantization" for the quantization_config field.This enhances deserialization flexibility by allowing model configs to use either "quantization_config" or "quantization" key.
584-584
: LGTM: Now propagating quantization config to language model head.This change enables consistent quantization by passing the configuration to the LM head instead of hardcoded
None
.mistralrs-core/src/models/qwen3.rs (2)
59-60
: LGTM: Added serde alias "quantization" for the quantization_config field.This enhances deserialization flexibility by allowing model configs to use either "quantization_config" or "quantization" key.
492-492
: LGTM: Now propagating quantization config to language model head.This change enables consistent quantization by passing the configuration to the LM head instead of hardcoded
None
.mistralrs-core/src/models/llama.rs (2)
47-48: LGTM: Added serde alias "quantization" for the quantization_config field.
This enhances deserialization flexibility by allowing model configs to use either the "quantization_config" or the "quantization" key.
385-385: LGTM: Now propagating quantization config to language model head.
This change enables consistent quantization by passing the configuration to the LM head instead of a hardcoded None.
mistralrs-core/src/models/deepseek2.rs (2)
96-97: LGTM: Added serde alias "quantization" for the quantization_config field.
This enhances deserialization flexibility by allowing model configs to use either the "quantization_config" or the "quantization" key.
792-792: LGTM: Now propagating quantization config to language model head.
This change enables consistent quantization by passing the configuration to the LM head instead of a hardcoded None.
mistralrs-quant/src/safetensors.rs (1)
209-209: Code simplification improves readability.
Simplified the tensor loading logic by directly delegating to the underlying load implementation with the provided dtype. This removes an unnecessary intermediate step where a tensor was first loaded and then potentially converted to a requested dtype.
mistralrs-core/src/vision_models/phi3/mod.rs (2)
78-79: Adds serialization flexibility with serde alias.
Adding the alias = "quantization" attribute allows the model to deserialize configuration that uses either quantization_config or quantization as the field name. This improves compatibility with different model file formats.
1030-1031: Propagates quantization configuration to language model head.
The language model head now receives the proper quantization configuration instead of None. This ensures consistent quantization behavior across all model components, including the output layer.
mistralrs-core/src/engine/mod.rs (2)
17-17: Makes IntervalLogger publicly accessible.
Changed the IntervalLogger import to be public, enabling its use outside the engine module. This supports the integration of logging into other components like the scheduler.
175-175: Passes logger to scheduler for enhanced metrics tracking.
Updated to pass the engine's logger to the scheduler, enabling the scheduler to update metrics such as the number of running and waiting sequences. This change aligns with the updated scheduler interface.
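For illustration, a minimal sketch of how such a logger and scheduler hook can fit together; the setter names and trait shape are assumptions based on the description above, not the exact mistral.rs definitions:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical logger: atomic counters can be updated from the scheduler and
// read by a background logging thread without locking.
pub struct IntervalLogger {
    num_running: Arc<AtomicUsize>,
    num_waiting: Arc<AtomicUsize>,
}

impl IntervalLogger {
    pub fn set_num_running(&self, n: usize) {
        self.num_running.store(n, Ordering::Relaxed);
    }
    pub fn set_num_waiting(&self, n: usize) {
        self.num_waiting.store(n, Ordering::Relaxed);
    }
}

pub trait Scheduler {
    // The logger is threaded through each scheduling pass so the periodic log
    // line can report queue depth alongside throughput.
    fn schedule(&mut self, logger: &IntervalLogger);
}
```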
mistralrs-core/src/scheduler/mod.rs (2)
9-9: Adds import for IntervalLogger.
Imported IntervalLogger from the engine module to support the updated scheduler interface.
55-55: Updates Scheduler trait to include logging capability.
The schedule
method now requires an IntervalLogger reference, enabling all scheduler implementations to update metrics about running and waiting sequences. This enhances observability of the scheduling process.mistralrs-core/src/models/qwen3_moe.rs (2)
57-58
: Added serde alias for backward compatibility.The
#[serde(alias = "quantization")]
attribute allows the field to be deserialized from JSON/YAML with either "quantization_config" or "quantization" as the key, improving backward compatibility with existing model configurations.
695-696
: Propagating quantization config to language model head.The change ensures that quantization settings are consistently applied to the language model head (
lm_head
), where previously a hardcodedNone
was used. This allows for proper quantization across all model components.mistralrs-core/src/models/qwen2.rs (2)
43-44
: Added serde alias for backward compatibility.The
#[serde(alias = "quantization")]
attribute allows deserializing the quantization configuration from either "quantization_config" or "quantization" field names, enhancing compatibility with various model configuration formats.
423-424
: Propagating quantization config to language model head.This change replaces a hardcoded
None
with&cfg.quantization_config
, ensuring that the language model head properly respects the quantization settings provided to the model.mistralrs-core/src/models/deepseek3.rs (2)
96-97
: Added serde alias for backward compatibility.The
#[serde(alias = "quantization")]
attribute enables deserialization from either "quantization_config" or "quantization" field names, improving compatibility with different model configuration formats.
845-846
: Propagating quantization config to language model head.This change ensures that the language model head (
lm_head
) consistently uses the same quantization configuration as the rest of the model components, replacing a hardcodedNone
reference.mistralrs-core/src/engine/logger.rs (7)
15-16
: Added tracking for running and waiting sequences.These new atomic counters enable monitoring the number of sequences in different states in the scheduler, which is valuable for debugging and performance monitoring.
26-28
: Initialized atomic counters for sequence tracking.Proper initialization of the atomic counters for tracking running and waiting sequences, consistent with the initialization pattern used for other counters.
33-34
: Added thread-local clones of atomic counters.These thread-local clones enable the background logging thread to safely access the sequence state counters without risking thread safety issues.
46-47
: Reading scheduler statistics in the logging thread.This code loads the current values of running and waiting sequences to be displayed in periodic log messages, providing runtime visibility into scheduler state.
51-52
: Enhanced logging output with scheduler information.The log format now includes the count of running and waiting sequences, providing runtime visibility into the scheduler's workload and potential bottlenecks.
64-65
: Added fields to struct for tracking sequence counts.These fields complete the implementation of sequence tracking by storing the atomic counters in the
IntervalLogger
struct.
86-92
: Added public setters for scheduler metrics.These methods allow the scheduler to update the number of running and waiting sequences, completing the integration between the scheduler and logging components.
mistralrs-core/src/models/mistral.rs (2)
45-45
: Adds serde deserialization compatibility for legacy field name.Adding the
alias = "quantization"
attribute allows the field to be deserialized from JSON that uses either"quantization_config"
or the legacy"quantization"
name, improving backward compatibility.
451-452
: Enables quantization for the language model head layer.Now the
lm_head
initialization properly uses the model's quantization configuration instead of passingNone
. This change ensures consistent quantization behavior across all model components.mistralrs-core/src/models/phi2.rs (2)
54-55
: Adds serde deserialization compatibility for legacy field name.Adding the
alias = "quantization"
attribute allows the field to be deserialized from JSON that uses either"quantization_config"
or the legacy"quantization"
name, improving backward compatibility.
530-531
: Enables quantization for the language model head layer.Now the
lm_head
initialization properly uses the model's quantization configuration instead of passingNone
. This change ensures consistent quantization behavior across all model components.mistralrs-core/src/pipeline/macros.rs (4)
99-101
: Updates argument passing style to match function signatures.Changed from passing references to Options (
&Option<T>
) to passing optional references (Option<&T>
) using.as_ref()
, aligning with updated function signatures inpaths.rs
.
107-112
: Updates xlora-related parameters to use optional references.Changed from passing references to Options to passing optional references using
.as_ref()
, creating a more idiomatic API design and matching updated function signatures.
286-288
: Updates argument passing style for model paths in GGUF context.Changed from passing
Option<T>
directly to wrapping values inSome(&T)
, ensuring type compatibility with updated function signatures expectingOption<&T>
parameters.
295-300
: Updates xlora-related parameter passing in GGUF context.Changed from passing Options directly to using
.as_ref()
for consistent argument passing style that aligns with the updated function signatures inpaths.rs
.mistralrs-core/src/paged_attention/scheduler.rs (5)
20-21
: Adds IntervalLogger import for enhanced monitoring.Added import for the
IntervalLogger
that will be used to track and report scheduler metrics.
67-67
: Updates scheduler interface to integrate logging.The
schedule
method now accepts a reference to anIntervalLogger
parameter to enable reporting of scheduler metrics.
123-125
: Adds logging for scheduler status after promotion phase.After sequences are either scheduled or ignored during promotion from waiting to running, the logger is updated with current counts of running and waiting sequences, enhancing observability.
230-232
: Adds logging for scheduler status after main scheduling logic.The logger is updated with the final counts of running and waiting sequences after all scheduling decisions have been made, providing visibility into the scheduler's state.
369-372
: Updates Scheduler trait implementation to propagate logger.The trait implementation now passes the logger through to the concrete implementation, maintaining consistent logging behavior.
mistralrs-core/src/dummy_paged_attention/scheduler.rs (1)
67-67
: LGTM! Consistent logger integrationThe IntervalLogger integration is well-implemented, with appropriate updates at key points in the scheduling pipeline. This change aligns perfectly with the pattern established in other scheduler implementations.
Also applies to: 123-124, 230-231, 369-371
mistralrs-core/src/scheduler/default_scheduler.rs (1)
206-206
: LGTM! Thorough logger integrationThe IntervalLogger integration is well-implemented with updates at all key decision points in the scheduling pipeline. The logger calls are strategically placed to ensure accurate reporting of running and waiting sequence counts throughout the execution path.
Also applies to: 218-219, 233-234, 248-249, 284-286, 311-313
mistralrs-core/src/sampler.rs (1)
794-802
: LGTM! Feature-gated fast samplerThe conditional return based on the
metal
feature flag is a good approach. This properly preserves the original sampling path for non-metal builds while enabling the optimized path for metal, avoiding the unreachable code issue that would occur with an unconditional return.mistralrs-quant/src/metal_kernels/utils.rs (1)
95-137
: LGTM! Well-implemented thread-group sizing functionThe
get_block_dims
function provides efficient 3D thread-group sizing with power-of-two thread counts while respecting dimension extents. This is a good port of MLX's approach and will help optimize kernel launch configurations.mistralrs-quant/src/metal_kernels/copy.metal (1)
52-63
: LGTM – exhaustive instantiation achievedThe final block instantiates the kernels for every primitive/bfloat16/half type across the three common group sizes.
No gaps spotted, order is consistent with the rest of the code-base.mistralrs-quant/src/metal_kernels/quantized.metal (1)
1598-1644
: Pointer-increment logic: verify overflow on partial K blocksThe new cache-friendly loop assumes that
in_vec_size
is an exact multiple ofblock_size
(values_per_thread * SIMD_SIZE
).
If this is not guaranteed, the final iteration will read past the end ofx
,w
,scales
, andbiases
.The slow path (
qmv_impl
) handles the “tail” withload_vector_safe
, butqmv_fast_impl
does not.Please double-check callers to ensure the fast kernel is only launched with aligned sizes, or add a guarded tail loop similar to the reference implementation.
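One way to honor that requirement is a host-side dispatch guard; the sketch below is illustrative (the function and kernel names are assumptions, not the actual dispatch code in mistralrs-quant):

```rust
/// Pick the fast kernel only when the reduction dimension is an exact multiple
/// of the per-block work; otherwise fall back to the tail-safe variant.
fn pick_qmv_kernel(in_vec_size: usize, values_per_thread: usize, simd_size: usize) -> &'static str {
    let block_size = values_per_thread * simd_size;
    if in_vec_size % block_size == 0 {
        "qmv_fast" // every SIMD group reads full blocks only
    } else {
        "qmv" // slow path handles the ragged tail with bounds-checked loads
    }
}
```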
mistralrs-quant/src/metal_kernels/utils.metal (10)
6-7
: Good addition of#pragma once
directive.This is a standard practice to prevent multiple inclusions of this header file, which is especially important for utility headers that might be included in multiple translation units.
616-622
: Clean template struct implementation for work distribution.The
WorkPerThread
struct elegantly scales work based on data type size, which is a good practice for optimizing GPU workloads.
627-672
: Well-structured type limits implementation.The
Limits
struct and its specializations provide a clean abstraction for accessing numeric limits of different types, with proper handling for both integral and floating-point types.
678-767
: Comprehensive indexing utilities for array traversal.The indexing utilities provide efficient functions for converting linear indices to memory locations for arrays of different dimensions, which is crucial for GPU kernel performance.
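As a language-agnostic illustration of what such an indexing helper does, here is a small Rust analogue of mapping a row-major linear index through arbitrary strides (this mirrors the idea only; the Metal utilities operate on device pointers):

```rust
// Convert a row-major linear element index into an offset via per-dimension strides.
fn elem_to_loc(mut elem: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut loc = 0;
    for (&dim, &stride) in shape.iter().zip(strides.iter()).rev() {
        loc += (elem % dim) * stride;
        elem /= dim;
    }
    loc
}

fn main() {
    // A 2x3 tensor stored transposed (column-major): strides are [1, 2].
    // Linear index 4 is element (1, 1), which lands at offset 1*1 + 1*2 = 3.
    assert_eq!(elem_to_loc(4, &[2, 3], &[1, 2]), 3);
}
```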
773-847
: Well-designed recursive template for multi-dimensional iteration.The
LoopedElemToLoc
template provides an elegant solution for traversing multi-dimensional arrays, with specialized implementations for different dimensionalities.
869-897
: Numerically stable calculation utilities.The
ceildiv
andlog1p
implementations follow best practices for numerical stability, including proper handling of edge cases.
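For reference, the intent of ceildiv and a stable log1p can be sketched in a few lines of Rust; this illustrates the math, not the Metal implementation:

```rust
// Round an integer division up: how many fixed-size blocks cover n elements.
fn ceildiv(n: usize, d: usize) -> usize {
    (n + d - 1) / d
}

fn main() {
    assert_eq!(ceildiv(10, 4), 3); // 10 elements in blocks of 4 -> 3 blocks
    // Rust's standard library already exposes a numerically stable ln(1 + x):
    let x = 1e-10_f64;
    println!("ln(1+x) = {}", x.ln_1p()); // ~1e-10, no catastrophic cancellation
}
```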
903-956
: Effective SIMD shuffle operations for 64-bit types and booleans.These implementations properly wrap Metal's SIMD intrinsics, extending them to handle 64-bit integers and boolean types.
957-1087
: Comprehensive set of templated operator functors.These functors provide a clean, uniform interface for arithmetic, comparison, and logical operations, with appropriate specializations for different types.
1090-1107
: Efficient power implementation for integral types.The
Power
functor's specialization for integral types uses the efficient binary exponentiation algorithm (also known as exponentiation by squaring), which is optimal for computing integer powers.
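A short Rust sketch of exponentiation by squaring, the algorithm referenced above (illustrative only):

```rust
// Binary exponentiation: O(log exp) multiplications instead of O(exp).
fn ipow(mut base: u64, mut exp: u32) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc.wrapping_mul(base); // multiply in the current bit
        }
        base = base.wrapping_mul(base); // square for the next bit
        exp >>= 1;
    }
    acc
}

fn main() {
    assert_eq!(ipow(3, 7), 2187);
}
```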
1153-1159
: Useful kernel instantiation macro.The
instantiate_kernel
macro provides a convenient way to instantiate templated Metal kernels with specified host names and template parameters.
scale_perm: List[int] = []
for i in range(8):
    scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
Missing import for List from the typing module.
The code uses List type annotations, but the typing module that provides List isn't imported.
Add the following import at the top of the file:
import os
import shutil
+from typing import List
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
scale_perm: List[int] = [] | |
for i in range(8): | |
scale_perm.extend([i + 8 * j for j in range(8)]) | |
scale_perm_single: List[int] = [] | |
import os | |
import shutil | |
from typing import List | |
# ... rest of scripts/convert_awq_marlin.py ... |
🧰 Tools
🪛 Ruff (0.11.9)
13-13: Undefined name List
(F821)
16-16: Undefined name List
(F821)
🤖 Prompt for AI Agents
In scripts/convert_awq_marlin.py around lines 13 to 16, the code uses the List
type annotation but does not import List from the typing module. To fix this,
add the import statement "from typing import List" at the top of the file to
ensure the List type is recognized.
async def timed_chat(client: AsyncOpenAI, messages): | ||
""" | ||
Send one chat completion request and return (completion, elapsed_seconds, completion_tokens). | ||
""" | ||
start = time.perf_counter() | ||
completion = await client.chat.completions.create( | ||
model="mlx-community/Mistral-7B-Instruct-v0.3-4bit", | ||
messages=messages, | ||
max_tokens=256, | ||
frequency_penalty=1.0, | ||
top_p=0.1, | ||
temperature=0, | ||
) | ||
elapsed = time.perf_counter() - start | ||
# Safely get number of completion tokens, default to 0 if missing | ||
completion_tokens = getattr(completion.usage, "completion_tokens", 0) | ||
return completion, elapsed, completion_tokens |
🛠️ Refactor suggestion
Parameterize the timed_chat function
Model name and generation parameters are hardcoded, limiting the benchmark's flexibility for testing different configurations.
-async def timed_chat(client: AsyncOpenAI, messages):
+async def timed_chat(
+ client: AsyncOpenAI,
+ messages,
+ model_name="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+ max_tokens=256,
+ frequency_penalty=1.0,
+ top_p=0.1,
+ temperature=0
+):
"""
Send one chat completion request and return (completion, elapsed_seconds, completion_tokens).
"""
start = time.perf_counter()
completion = await client.chat.completions.create(
- model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+ model=model_name,
messages=messages,
- max_tokens=256,
- frequency_penalty=1.0,
- top_p=0.1,
- temperature=0,
+ max_tokens=max_tokens,
+ frequency_penalty=frequency_penalty,
+ top_p=top_p,
+ temperature=temperature,
)
elapsed = time.perf_counter() - start
# Safely get number of completion tokens, default to 0 if missing
completion_tokens = getattr(completion.usage, "completion_tokens", 0)
return completion, elapsed, completion_tokens
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
async def timed_chat(client: AsyncOpenAI, messages): | |
""" | |
Send one chat completion request and return (completion, elapsed_seconds, completion_tokens). | |
""" | |
start = time.perf_counter() | |
completion = await client.chat.completions.create( | |
model="mlx-community/Mistral-7B-Instruct-v0.3-4bit", | |
messages=messages, | |
max_tokens=256, | |
frequency_penalty=1.0, | |
top_p=0.1, | |
temperature=0, | |
) | |
elapsed = time.perf_counter() - start | |
# Safely get number of completion tokens, default to 0 if missing | |
completion_tokens = getattr(completion.usage, "completion_tokens", 0) | |
return completion, elapsed, completion_tokens | |
async def timed_chat( | |
client: AsyncOpenAI, | |
messages, | |
model_name="mlx-community/Mistral-7B-Instruct-v0.3-4bit", | |
max_tokens=256, | |
frequency_penalty=1.0, | |
top_p=0.1, | |
temperature=0 | |
): | |
""" | |
Send one chat completion request and return (completion, elapsed_seconds, completion_tokens). | |
""" | |
start = time.perf_counter() | |
completion = await client.chat.completions.create( | |
model=model_name, | |
messages=messages, | |
max_tokens=max_tokens, | |
frequency_penalty=frequency_penalty, | |
top_p=top_p, | |
temperature=temperature, | |
) | |
elapsed = time.perf_counter() - start | |
# Safely get number of completion tokens, default to 0 if missing | |
completion_tokens = getattr(completion.usage, "completion_tokens", 0) | |
return completion, elapsed, completion_tokens |
🤖 Prompt for AI Agents
In scripts/bench.py around lines 46 to 62, the timed_chat function has hardcoded
model name and generation parameters, reducing flexibility. Modify the function
signature to accept model name and generation parameters as arguments with
default values, then use these parameters inside the function instead of
hardcoded values. This will allow benchmarking with different configurations
without changing the function code.
async def user_task(client: AsyncOpenAI, system_prompt: str, user_message: str): | ||
""" | ||
Returns list of (completion, elapsed_seconds, completion_tokens). | ||
""" | ||
results = [] | ||
base_messages = [] | ||
if system_prompt: | ||
base_messages.append({"role": "system", "content": system_prompt}) | ||
|
||
for _ in range(REQUESTS_PER_USER): | ||
messages = base_messages + [{"role": "user", "content": user_message}] | ||
completion, elapsed, completion_tokens = await timed_chat(client, messages) | ||
results.append((completion, elapsed, completion_tokens)) | ||
return results |
🛠️ Refactor suggestion
Add error handling to user_task
The function lacks error handling, which could lead to unhandled exceptions if requests fail.
async def user_task(client: AsyncOpenAI, system_prompt: str, user_message: str):
"""
Returns list of (completion, elapsed_seconds, completion_tokens).
"""
results = []
base_messages = []
if system_prompt:
base_messages.append({"role": "system", "content": system_prompt})
for _ in range(REQUESTS_PER_USER):
- messages = base_messages + [{"role": "user", "content": user_message}]
- completion, elapsed, completion_tokens = await timed_chat(client, messages)
- results.append((completion, elapsed, completion_tokens))
+ try:
+ messages = base_messages + [{"role": "user", "content": user_message}]
+ completion, elapsed, completion_tokens = await timed_chat(client, messages)
+ results.append((completion, elapsed, completion_tokens))
+ except Exception as e:
+ print(f"Error during request: {e}")
+ # Add a failed result with zero tokens and high elapsed time to indicate failure
+ results.append((None, 10.0, 0))
return results
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
async def user_task(client: AsyncOpenAI, system_prompt: str, user_message: str): | |
""" | |
Returns list of (completion, elapsed_seconds, completion_tokens). | |
""" | |
results = [] | |
base_messages = [] | |
if system_prompt: | |
base_messages.append({"role": "system", "content": system_prompt}) | |
for _ in range(REQUESTS_PER_USER): | |
messages = base_messages + [{"role": "user", "content": user_message}] | |
completion, elapsed, completion_tokens = await timed_chat(client, messages) | |
results.append((completion, elapsed, completion_tokens)) | |
return results | |
async def user_task(client: AsyncOpenAI, system_prompt: str, user_message: str): | |
""" | |
Returns list of (completion, elapsed_seconds, completion_tokens). | |
""" | |
results = [] | |
base_messages = [] | |
if system_prompt: | |
base_messages.append({"role": "system", "content": system_prompt}) | |
for _ in range(REQUESTS_PER_USER): | |
try: | |
messages = base_messages + [{"role": "user", "content": user_message}] | |
completion, elapsed, completion_tokens = await timed_chat(client, messages) | |
results.append((completion, elapsed, completion_tokens)) | |
except Exception as e: | |
print(f"Error during request: {e}") | |
# Add a failed result with zero tokens and high elapsed time to indicate failure | |
results.append((None, 10.0, 0)) | |
return results |
🤖 Prompt for AI Agents
In scripts/bench.py around lines 65 to 78, the user_task function does not
handle errors, risking unhandled exceptions during requests. Add try-except
blocks around the await timed_chat call to catch exceptions, log or handle them
appropriately, and ensure the function continues or fails gracefully without
crashing.
uint simd_gid [[simdgroup_index_in_simdgroup]], | ||
uint simd_lid [[thread_index_in_simdgroup]]) { | ||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; |
Attribute typo breaks compilation
simdgroup_index_in_simdgroup is not a valid Metal shading language attribute – the correct attribute for "SIMD-group index inside the current thread-group" is simdgroup_index_in_threadgroup. This line will make the entire Metal library fail to compile on Apple tools ≥ Xcode 9.
- uint simd_gid [[simdgroup_index_in_simdgroup]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
Please adjust here and in any other newly-added kernels to restore successful compilation.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
uint simd_gid [[simdgroup_index_in_simdgroup]], | |
uint simd_lid [[thread_index_in_simdgroup]]) { | |
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; | |
uint simd_gid [[simdgroup_index_in_threadgroup]], | |
uint simd_lid [[thread_index_in_simdgroup]]) { | |
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; |
🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/quantized.metal at lines 1578 to 1580, the
attribute simdgroup_index_in_simdgroup is incorrect and causes compilation
failure. Replace simdgroup_index_in_simdgroup with the correct attribute
simdgroup_index_in_threadgroup to fix the compilation error. Also, check any
other new kernel code for the same typo and correct it similarly.
match quantized_filename { | ||
Some(names) => { | ||
let id = quantized_model_id.as_ref().unwrap(); | ||
let id = quantized_model_id.unwrap(); | ||
let mut files = Vec::new(); | ||
|
||
for name in names { |
unwrap() may panic when only filenames are supplied
Inside get_model_paths the Some(names) arm unconditionally unwraps quantized_model_id:
let id = quantized_model_id.unwrap();
If the caller provides an explicit quantized_filename list but forgets the corresponding quantized_model_id, the process will panic instead of returning an error.
Consider returning an anyhow::Error instead:
- let id = quantized_model_id.unwrap();
+ let id = quantized_model_id
+ .ok_or_else(|| anyhow::anyhow!("`quantized_model_id` must be supplied when `quantized_filename` is provided"))?;
This keeps the function panic-free and surfaces configuration errors gracefully.
🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/paths.rs around lines 318 to 323, the code
unconditionally unwraps quantized_model_id inside the Some(names) match arm,
which can cause a panic if quantized_model_id is None. Modify the code to check
if quantized_model_id is Some before unwrapping; if it is None, return an
anyhow::Error indicating the missing quantized_model_id instead of panicking.
This change will make the function panic-free and handle configuration errors
gracefully.
impl SortOp for Tensor { | ||
fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> { | ||
if self.device().is_cpu() || self.device().is_cuda() { | ||
return self.arg_sort_last_dim(true); | ||
} | ||
self.apply_op1_no_bwd(&ArgSort { | ||
axis: axis.to_index(self.shape(), "argsort")?, | ||
}) | ||
} | ||
|
||
fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> { | ||
if self.device().is_cpu() || self.device().is_cuda() { | ||
return Ok(self.sort_last_dim(true)?.0); | ||
} |
🛠️ Refactor suggestion
fast_argsort_asc / fast_sort_asc ignore the supplied axis when running on CPU/CUDA
/ sort_last_dim
, which always act on the last dimension, even if the caller asked for another axis.
This yields silently-wrong results for any multi-dimensional tensor whose last dimension is not the sort axis.
- if self.device().is_cpu() || self.device().is_cuda() {
- return self.arg_sort_last_dim(true);
- }
+ if self.device().is_cpu() || self.device().is_cuda() {
+ // Candle already provides arg_sort()/sort() that respect an axis.
+ // Preserve the public behaviour by forwarding the user-requested axis.
+ let axis = axis.to_index(self.shape(), "argsort")?;
+ return self.arg_sort(axis, true); // or `self.sort(axis, true)?`
+ }
(The same change is required for fast_sort_asc
a few lines below).
Please ensure the accompanying unit-tests cover non-trailing axes so the issue is caught automatically in the future.
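A hedged sketch of such a test follows; the exact trait import and tensor construction depend on how SortOp is exposed, so the assertions are left as comments rather than presented as the crate's real test:

```rust
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};

    // Outline of a regression test for sorting on a non-trailing axis.
    #[test]
    fn sort_respects_non_trailing_axis() -> candle_core::Result<()> {
        let t = Tensor::new(&[[3u32, 1], [2, 4]], &Device::Cpu)?;
        // Expected: sorting along axis 0 gives [[2, 1], [3, 4]].
        // let sorted = t.fast_sort_asc(0)?;
        // assert_eq!(sorted.to_vec2::<u32>()?, vec![vec![2, 1], vec![3, 4]]);
        let _ = t;
        Ok(())
    }
}
```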
🤖 Prompt for AI Agents
In mistralrs-quant/src/utils/ops.rs around lines 1091 to 1104, the methods
fast_argsort_asc and fast_sort_asc ignore the supplied axis parameter when
running on CPU or CUDA by always sorting along the last dimension. To fix this,
modify these methods to use the provided axis for sorting instead of defaulting
to the last dimension, ensuring the axis is converted properly using
axis.to_index with the tensor shape. Additionally, update or add unit tests to
cover sorting on non-last axes to prevent this issue from recurring.
auto idx = elem_to_loc_2_nd<IdxT>({N * index.x, index.y, index.z}, src_shape, | ||
src_strides, dst_strides, ndim); | ||
if (N == 1) { | ||
dst[idx.y] = static_cast<U>(src[idx.x]); | ||
return; | ||
} | ||
IdxT src_xstride = src_strides[ndim - 1]; | ||
IdxT dst_xstride = dst_strides[ndim - 1]; | ||
auto xshape = src_shape[ndim - 1]; | ||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { | ||
dst[idx.y] = static_cast<U>(src[idx.x]); | ||
idx.x += src_xstride; | ||
idx.y += dst_xstride; | ||
} |
Missing bounds check in strided gather→gather copies
copy_gg
writes
dst[idx.y] = static_cast<U>(src[idx.x]);
inside a loop where both idx.x
and idx.y
are manually incremented.
Neither value is compared against the logical src/dst sizes, so a malformed shape/stride pair can lead to OOB writes and kernel crashes.
Unless you have stronger invariants on the Rust side, protect the loop:
if (idx.x < src_len && idx.y < dst_len) { … }
Even a single device_assert would greatly improve debuggability.
🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal around lines 143 to 156,
the loop writing to dst and reading from src increments idx.x and idx.y without
bounds checks, risking out-of-bounds memory access. To fix this, add conditional
checks inside the loop to ensure idx.x is less than the source length and idx.y
is less than the destination length before performing the copy. Optionally,
include device_assert statements to catch violations during debugging. This will
prevent OOB writes and improve kernel stability.
for (int i = 0; i < N && (index + i) < size; ++i) { | ||
dst[index + i] = static_cast<U>(src[0]); | ||
} |
🛠️ Refactor suggestion
Avoid signed/unsigned comparisons to prevent subtle out-of-bounds bugs
All inner loops use int i
while the limit expressions (index
, size
, xshape
) are uint
or int64_t
.
Because of the usual arithmetic conversions the compiler silently casts operands back and forth, which can bite once size ≥ 2³¹
or when negative intermediates are introduced.
A minimal change keeps the code safe and helps the optimiser:
-for (int i = 0; i < N && (index + i) < size; ++i) {
+for (uint i = 0u; i < uint(N) && (index + i) < size; ++i) {
Please repeat the same change in every loop highlighted above (copy_s*
, copy_v*
, copy_g
, copy_gg
, copy_gg_dynamic
).
(The template parameter N
is still int
; only the loop counter needs to be unsigned.)
Also applies to: 18-20, 29-31, 40-42, 96-99, 152-156, 222-226
🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal at lines 8 to 10 and
similarly at lines 18-20, 29-31, 40-42, 96-99, 152-156, and 222-226, the loop
counter variable is declared as a signed int while the loop limit variables are
unsigned or int64_t, causing unsafe signed/unsigned comparisons. Change the loop
counter type from int to an appropriate unsigned integer type to match the limit
variables, ensuring safe comparisons and preventing potential out-of-bounds
errors. Apply this change consistently to all loops mentioned in the comment.
auto src_idx = elem_to_loc<IdxT>({N * index.x, index.y, index.z}, src_shape, | ||
src_strides, ndim); | ||
if (N == 1) { | ||
IdxT dst_idx = |
🛠️ Refactor suggestion
Brace-initialiser may fail to compile on some Metal toolchains
elem_to_loc
is called with {N * index.x, index.y, index.z}
which is not guaranteed to deduce to a uint3
/int3
in MSL.
Several versions of Xcode/metal-clang reject this syntax.
Be explicit to avoid portability issues:
-auto src_idx = elem_to_loc<IdxT>({N * index.x, index.y, index.z},
+auto src_idx = elem_to_loc<IdxT>(uint3(N * index.x, index.y, index.z),
src_shape, src_strides, ndim);
Also applies to: 95-97
🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal around lines 85 to 88 and
also lines 95 to 97, the call to elem_to_loc uses a brace-initializer {N *
index.x, index.y, index.z} which may not compile on some Metal toolchains due to
type deduction issues. To fix this, explicitly construct a uint3 or int3
(matching the expected type) with these values instead of using a
brace-initializer, ensuring compatibility across Metal versions.
uint2 grid_dim [[threads_per_grid]]) { | ||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); | ||
for (int i = 0; i < N && (offset + i) < size; ++i) { | ||
dst[offset + i] = static_cast<U>(src[0]); | ||
} | ||
} |
Potential 32-bit overflow in 2-D kernels
offset
is computed as
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
index.x
and grid_dim.x
are uint
, so the addition is performed in 32-bit before it is promoted to int64_t
by auto
.
With large grids (≥ 65 536 × 65 536) this silently wraps and corrupts memory.
Safer version:
-auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+auto offset = IdxT(N) * (IdxT(index.x) + IdxT(grid_dim.x) * IdxT(index.y));
where IdxT
is the same 64-bit type used by your helpers.
Also applies to: 39-42
🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal around lines 27 to 32, the
calculation of offset uses 32-bit arithmetic for the addition of index.x and
grid_dim.x * index.y before converting to 64-bit, which can cause overflow with
large grid sizes. To fix this, cast index.x and grid_dim.x to the 64-bit integer
type (IdxT) before performing the addition and multiplication to ensure all
arithmetic is done in 64-bit and prevent overflow. Apply the same fix to the
similar code block at lines 39 to 42.
* Fix handling of Metal fused attn head dims (EricLBuehler#1234) * Fix handling of metal attn head dims * Fix handling of gemma3 1b when images * Tweak default for paged attn builder * Support paged attn for vision model rust api (EricLBuehler#1235) * [Breaking] Support setting HF cache path (EricLBuehler#1237) * Add it internally * Add the apis * Support tool calling for DeepSeek models (EricLBuehler#1239) * Support tool calling for deepseek models * Format * Fix deepseek * Server image processing refactor and fixes (EricLBuehler#1244) * Fix strict gemma3 case * Accept multiple images in the content array * Fix multiple images in one array ct * Add it to the python api * Typos * Optimized CUDA RoPE kernels (EricLBuehler#1247) * Add the kernels * It works * Works * Buulds * Typo fix (add_speial_tokens to add_special_tokens) (EricLBuehler#1246) * Fix typo * Update mistralrs.pyi * Fixes for UQFF + distributed layers (EricLBuehler#1250) * Fixes for uqff + distributed layers * Typo * Automatic agentic search integration (`web_search_options`) (EricLBuehler#1243) * Add the tool * Actually search * Clippy * Sort of works * Remove some debuggers * tweak * Add some rules * Works great * Tweak 'system' prompt * Update mistralrs-core/src/search/mod.rs Co-authored-by: Copilot <[email protected]> * Typo * Add it to all the apis * Add bert model for similarity reranking * Typos * Early detection of tools * Alias max_tokens -> max_completion_tokens too * Customizable bert model * Flip the enabler around * Add docs * Update readme * Typo --------- Co-authored-by: Copilot <[email protected]> * Format kernels (EricLBuehler#1251) * Update readme * Update readme * Remove test * Add quantize guards for uqff deserialize (EricLBuehler#1252) * Refactor cuBLASlt-related code (EricLBuehler#1253) * Centralize cublaslt into mistralrs-quant * Use cublaslt in unquant layer * Use beautiful trait constants for simpler code * Move tests * Dispatch to unquant for cublaslt * Dispatch to unquant for cublaslt * Fix feature * Add convert_to_gptq script * Update deps, bump pyo3 version (EricLBuehler#1259) * Faster cuda FP8 performance (EricLBuehler#1257) * Avoid fp8 sync * Fix dtype * Rust 1.86 clippy (EricLBuehler#1260) * Rust 1.86 clippy * Clippy * Refactor engine arch (EricLBuehler#1262) * Refactor engine add_request * Don't recompile regex * Clippy * Revamped LoRA support - removing the Ordering system! (EricLBuehler#1263) * Play with varbuilder lifetimes * Merge lora weights * Clippy * Lora works * Support multiple loras * Cleanup, remove adapter activation * Complete merge * Fast Metal-specific quantization method: AFQ (EricLBuehler#1264) * Add mlx quantized kernels * Add mlx quantized kernels * Kernel launcher * Add AFQ isq quant and dequant * Some quantmethod things * Begin to implement the qmm caller * Clippy * Much faster * Cache kernels * Docs * Clippy * Add it to uqff * Support prequantized models from MLX (EricLBuehler#1265) * Refactor quantizedconfig * Support AFQ prequantized * Update docs * Update docs * Automatic ISQ to select fastest & most accurate method (EricLBuehler#1266) * Automatic isq * typo * Doc * Improved usage metrics (EricLBuehler#1267) * Fix cuda * Bump tokio from 1.44.1 to 1.44.2 (EricLBuehler#1270) Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.44.1 to 1.44.2. 
- [Release notes](https://github.com/tokio-rs/tokio/releases) - [Commits](tokio-rs/tokio@tokio-1.44.1...tokio-1.44.2) --- updated-dependencies: - dependency-name: tokio dependency-version: 1.44.2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Gather MM ops in mistralrs-quant (EricLBuehler#1272) * Update the caller * Wire things up * Broadcase for afq gathermm * Broadcase for afq gathermm * Clippy * Improve performance of deepseek models * Typo fix * BincountOp not used * Implement Llama 4! (EricLBuehler#1268) * Implement Llama 4 * Implement the main changes for the text model * Make chunked mask * Wire things up * Add some EP * Initial sketch of inputs processor * Runs * Progress * all reduce moes * It works! * Some cleanup * Faster moe block * Add device map * Make chunked matrix * Fully working now! * Reactivate cublaslt * Fix shared mlp cublaslt * Refactor to packed experts * Complete merge * It is a normal model now * Fixes * Set device for moe * ISQ fixes * Much faster sort kernel * Faster loading! * Faster loading! * Fp8 cpu copy ops in candle backend * Add the vision model * Add mmproj layer * Actually merge the inputs * Sketch most of the image processor * Add the rest of the image processor * Implement the whole processor * Add the loader * Some fixes * A batch of fixes * Some fixes * tmp * Actually support isq * Ok it works a bit * Fix norm device * It works * A bit cleaner * Support residul tensors * Remove text loader * Implement the device mapping system * Fix auto device map * Add examples * Add model card * Typo * Remove superflous logging * Fixes for Llama 4 UQFF loading (EricLBuehler#1275) * Support sharding for UQFF (EricLBuehler#1276) * Serialize sharded uqff files * Loading * Fix base64 * Fix bug for group-topk (group_limited_greedy) in deepseek models (EricLBuehler#1278) * Support the DeepCoder model (EricLBuehler#1279) * Add faq for metal not found * Improved PagedAttn scheduling accuracy (EricLBuehler#1282) * Scheduler ops by reference * Ensure scheduler gets correct prompts * Fix cuda build for copy_blocks * Fixes for scheduling image seqs with pagedattn (EricLBuehler#1283) * update to llguidance 0.7.16 (EricLBuehler#1284) * update llguidance to 0.7.16 from crates.io; use ParserFactory * add lark_llg.py example * use new llguidance::Matcher APIs * rework spec-decoding with llg * more work on spec sampling * check for parser stop * fix clippy * remove unneeded rollback * update build_llg_factory to return Result * Update dependencies (EricLBuehler#1286) * Much faster image inputs processing (EricLBuehler#1289) * Add more SDPA head dims for much faster SigLIP (EricLBuehler#1290) * More sdpa head dims, faster vision models * Move nonzero to above for faster metal synch * Doc * Update valid head dims * Show throughput in interactive mode (EricLBuehler#1291) * Update interactive mode throughput stats * Accurate prompt t/s * Accurate prompt t/s for usage * Unify bitwise operations (EricLBuehler#1288) * Unify bitwise ops * Tests pass * Fix cuda build * Clippy * Multimodal prefix caching support! (EricLBuehler#1298) * Initial progress * Support vision prefix caching * Update docs * Add multimodal data abstraction * Interactive mode improvements (EricLBuehler#1299) * More ergonomic image url parsing * Add option to clear * Add the Qwen 3 and Qwen 3 MoE models! 
(EricLBuehler#1285) * Add qwen3 model * Add enable_thinking * Add initial qwen3 moe * Add the moe model * Format * Fix order of norm * Fix expert shapes * Fix reverse * Fix norm device for isq * Fix nonzero when no nonzero * Moe model runs * Working qwen3 moe * Add metal fp8 blockwise dequant * Clean * Typo * Enable tool calling * Streamlined ux * Add some examples * Add docs * Fix dead link * Remove interactive mode max_len * Update QWEN3.md * Hotfix for vision mode clear * Revamped and streaming web search support (EricLBuehler#1301) * Streaming web search * Refactor a bit * More refactoring * Add some logging, parallelize some things * Allow url * Suppress warning, allow multi-turn searching * Batch compute_similarities * Cap content len * Typos * Doc * Handle vision messages or different tool call prefixes (EricLBuehler#1302) * Fix cuda * Tune web search budget * Simplify prefix cacher (EricLBuehler#1305) * Use rustyline to handle non-ascii in interactive mode (EricLBuehler#1306) The io::stdin().read_line() cannot handle non-ascii input, which caused crash when use backspace to delete non-ascii characters. Introduce rustyline to the interactive mode to solve the problem. Plus it can bring more editing features in the future. Close EricLBuehler#1140 * Add more tools for automatic search (EricLBuehler#1307) * Add interactive mode history * Add a website extraction tool * Pass toks by reference * Optimize prompt chunking * Fix CPU hogging in interactive mode (EricLBuehler#1309) The log enabler should be checked after the sleep instead of a busy loop checking. Since the interactive mode always disables the token speed logger, 100% CPU was taken by this loop always. * Add Metal precompilation support (EricLBuehler#1311) * Add metal precompilation for paged attn * Add for mistralrs-quant * Better constructor * Dont always build * Fix name for paged attn rebuild * Reduce thrashing of Metal autorelease (EricLBuehler#1313) * Reduce calls to autorelease * Optimize clone_in_cache * Refactor float8 * make `AdapterPaths` and `LoraAdapterPaths` public (EricLBuehler#1314) Make `AdapterPaths` and `LoraAdapterPaths` public so `LocalModelPaths` can be constructed outside of `mistralrs-core`. * Refactor KV cache manager (EricLBuehler#1315) * Refactor kv cache * Refactor caches * Fix some overflows * Add `Audio` and `Speech` model categories (EricLBuehler#1317) * add `Audio` to `ModelCategory` * add `Speech` to `ModelCategory` * fix to go back to PartialEq having an exhaustiveness check * Remove has_conv2d from vision model API (EricLBuehler#1318) * Unified/automatic flash attention enabler (EricLBuehler#1319) * Remove from sdpa params * Fix errors * No warnings * Log * Clippy * Fix cublaslt 4d mask (EricLBuehler#1320) * Fix cublaslt 4d mask * Clippy * Keep caches on gpu * Qwen VL models fixes (EricLBuehler#1322) * Add some defaults * Fix * Fix one thing * 2.5 vl works * Use caching again * Fix v2 * Move index inside loop * Offset in ropeidx * Default support for vision prefix caching is false * Fixes for all vision models (EricLBuehler#1323) * Fix phi input processor? 
* Fix phi input processor * Handle no_prefix_cache from pipeline * Phi models confirmed 👍 * Fixed for phi inputs processors * Fixed for phi4 * Llama 3 confirmed 😀 * Mistral 3 confirmed 😃 * Idefics 2/3 fixes * Some fixes * Remove unsafety * Improved+faster LRU prefix cacher (EricLBuehler#1321) * Show TTFT * Use LRU prefix cacher * Faster prefix cacher * Inplace ISQ support and default to mmap (EricLBuehler#1277) * Initial impl of immediate isq * Immediate isq -> !loading_isq * Varbuiler utils always using mmap! * Log * Add for packed experts * Afq without copy * Clarify * Clippy * Apple immediate isq * Better logic for loading_isq * Support showing ttft * Rename * Shared quantize guard * Parallel progress bar * Parallel loading for progress bars * Actual ISQ support * Conditional parallelism for NiceProgressBar * Use conditional iterator * Warn once * Predicate for applying immediate isq * Allow parallel * Remove debug print * Remove debug print * Remove debug print * Fix typos (EricLBuehler#1329) * Fix Idefics 3 arch chat templating (EricLBuehler#1330) * Update inputs merger * Fix * Better warning * Better warning * Better warning * Nonzero ahead of time * No f32 * Clippy * Optimize get_logprobs * Fix packed experts * Update masking * Use Sdpa in idefics3 * QuantMethod in idefics3 vision * Remove a .contiguous * Remove two space from PR comment (EricLBuehler#1331) * Add automatic vision loader type (EricLBuehler#1332) * Add automatic vision loader * Remove references to --arch * Update examples * Add the Dia 1.6b TTS model! (EricLBuehler#1304) * Add loading * Add rope, mlp, most of attn * Add encoder + encoder layer, decoder layer forwards * Add decoder forwards * Add prepare_audio_prompt * prepare_generation mostly done * Add a proper dia kvcache * Add most of decoder_step * Add the sampler * Add the generation loop * Wire things up * Add speech pipeline * Fixes * Loads * Some fixes * f32 * Some progress * Ok it runs upto dac decoding * Add dac part loading * Loads and runs at least * Remove encodec * Debugging * Debugging * Huh * Complete merge * Interactive * Confirmed dac works at least * Looks like encoder works * Much progress * Hmm * Sampling * Almost there * Sampler * Sampler * Bf16 support * Response * Use it in interactive mode * Fix oneshot * Add openai api * Add openai api * Refactor loading * Use naive sdpa for inplace * Factor out * Clippy * Clippy * Config * Refactor config * Metal clippy * Fix t/s * ISQ support * Some fixes, nits * Fix cuda * Clippy * Inhibit cublaslt for cuda * Add server example * Add python example * Add rust api * Add docs * Update config.toml * Fix .pyi * Update readme * config.toml tweak * config.toml tweak * config.toml tweak * config.toml tweak * config.toml tweak * config.toml tweak * config.toml tweak * config.toml tweak * config.toml tweak * update `llguidance` to `0.7.20` (EricLBuehler#1334) Update `llguidance` from `0.7.16` to `0.7.20` so that it has guidance-ai/llguidance#172 which is a fix for building on GCC 15. 
* Add model category <> messages check (EricLBuehler#1335) * Verify model category matches the messages * Add vision chat * Fixes * Add element-wise normalization check (EricLBuehler#1340) * Fix streaming example print statement (EricLBuehler#1339) * Fix normalization formula in comment (EricLBuehler#1338) * Fix image_to_pixels to handle non-RGB images (EricLBuehler#1337) * Fix typo in expect messages (EricLBuehler#1342) * Don't use mmap on cuda (EricLBuehler#1336) * No mmap on cuda * Simplify streaming tool call logic * Remove debug * Support AWQ format models (EricLBuehler#1350) * Support AWQ format models * Clippy fix * Fix uqff dummy layer ISQ application (EricLBuehler#1351) * Disable immediate isq if write_uqff (EricLBuehler#1352) * Fixes for UQFF loading on CUDA, ISQ pack factor (EricLBuehler#1354) * Fix logic for uqff on cuda * Updated pack_factor * Refactor Option references for model paths (EricLBuehler#1347) * refactor: use Option refs in model path helpers * Format * Add a script for server benchmarking (EricLBuehler#1355) * Serde alias * Fix * Update for tie_word_embeddings * Print running/waiting * 30 users * Update num_users * Update dummy paged attn * Optimized Metal qmv_fast path (EricLBuehler#1356) * Compile with lto * Tweak profiles * New, fast sampler for Metal! (EricLBuehler#1327) * Show TTFT * Use LRU prefix cacher * Faster prefix cacher * A bit of gpu sampling * Minp but cpu for now * Metal fast cumsum impl * Sampling with fast topp kernel * Hmm not perfect * Add metal sort kernels * Tmp * Add single block sort * Add most of multi block sort, just need copy op * Add copy kernels * Expose kernels * Add a test * Ok it works * Structure things * Add caching * Rename * Cpu is default * CUDA case * Topk * Refactor Option references for model paths (EricLBuehler#1347) * refactor: use Option refs in model path helpers * Format * Add a script for server benchmarking (EricLBuehler#1355) * Serde alias * Fix * Update for tie_word_embeddings * Print running/waiting * 30 users * Update num_users * Update dummy paged attn * Optimized Metal qmv_fast path (EricLBuehler#1356) * Compile with lto * Tweak profiles * Fix topk * Penalties * Add logits processor, clippy fixes * Fix chat port * Remove warning * Fix chat port * Fix metal parallel sampling (EricLBuehler#1357) * Cpu if parallel for now * Tweak bench script * Add immediate isq predicates for qwen3 (EricLBuehler#1358) * Add immediate isq predicates for qwen3 * Fix parsing of "parse_isq_value" depedent of device * Typo * Fix gemma3 logging * Regressions fixes (EricLBuehler#1359) * Fix regression for mmap * Revert EricLBuehler#1321 * Refactored matching_cache impl * Clippy * Revamped and smaller readme (EricLBuehler#1360) * Expandable detail sections * Refactor using derivative model * Tweak quick examples * Update llama * Update llama * Supported accelerators is a table * Update installation guides * Tweak apis * Remove --port in quick examples * Add demo gif * Add gif in readme * Update demo gif * Update demo gif * Update demo gif * Add gif in readme * Add gif in readme * Add a web chat app! (EricLBuehler#1362) * Initial * Markdown * Copy code * Add model loading sidebar * Support vision models * Tweak isq * Links go to another page * Clear when switch model * Fix html tags * Add image support! 
* More then one images * Fix * Improved textarea * Tab for switching between vision and text * No paged attn for now * Prettier format * Multiple models at once * Better switching, clearing ability * Mobile support * Inline markdown parser * Update examples * Typos * Support specifying isq * Fix mobile * Fixes * Fix button on mobile * Image height is capped * Thumbnail * Fix rotating kv cache edge case * Add drag and drop for images * Small things * Sidebar is frozen now * Better listner * Add readme * Tweak readme * Add chat history support to web chat app (EricLBuehler#1363) * Add chat history * Support renaming * Start immediately with new chat * Add timestamp * Prettier chat list * Style * Delete chat * Fix copy button * Fix markdown rendering * Store things in cache * Store things in cache * Refactor web chat, fix multichat image restore (EricLBuehler#1364) * Fix multichat image restoration. * Clippy * Refactor * Refactor frontent * Fix repeated immediate isq init (EricLBuehler#1365) * Add images_ref * Add debug impl * Fix the bug * Tweak style of buttons * Add a spinner * Move spinner * Tweak emoji * Add gif * Tweak initial gif * Include vision tower tensors in Mistral3 UQFF (EricLBuehler#1366) * Fix mistral 3 uqff resitdual tensors for vision * Rolling shard creation for uqff files (EricLBuehler#1367) * Fix occasional unstability during isq of afq (EricLBuehler#1368) * Fix unstability during isq of afq * Clippy * Fix web chat installation * Support web chat file uploading (EricLBuehler#1370) * Web chat fixes * Fix thumbnail in message, reuse blank chat * Add file uploading support * Fix scroll * Allowed extensions * Preserve files as literals * Support multiple clients * Add a stop button * New cache dir * New cache dir * Fix * Refactor * Update readme * Tweak drag-and-drop css * Add speech generation support to the web chat! (EricLBuehler#1373) * Initial speech gen support for web chat * Tweak ui * Update docs * Prefix caching for PagedAttention! (EricLBuehler#1369) * Exposing some things for logical token blocks * Prefix cache manager has the scheduler * Refactor * Get logical and physical blocks into the prefix cacher * Hash and cache * Pass physical block prefill * Allocation of prefilled block tables * Temp * Dont always use 2 * Hmm * Hmm * It mostly works * Increment refcount * Support images! * Add to dummy paged attn * Fix some clippy * Clippy * More checks * Include EricLBuehler#1371, closes EricLBuehler#1371 * Typos * Update docs * Metal PagedAttention accuracy improvements (EricLBuehler#1374) * Fix subtle bug * Fix half sum bug * Format metal paged attention * Handle images in paged attn scheduler (EricLBuehler#1375) * Include schemas needed for chatcompletions endpoint (EricLBuehler#1353) * EricLBuehler#1326: WIP include schemas needed for chat completions endpoint Conflicts: Cargo.lock mistralrs-server/src/main.rs * EricLBuehler#1326: WIP define utoipa as a workspace dep since core and server both need it * EricLBuehler#1326: first draft of handling schemas that use Either * EricLBuehler#1326: first draft of handling schema for Grammar * EricLBuehler#1326: Add in other endpoints to API docs. 
* EricLBuehler#1326: Adjust code comments * EricLBuehler#1326: Implement coderabbitai suggestions - EricLBuehler#1353 (review) - EricLBuehler#1353 (comment) * Fix constraints with metal sampler * Revert EricLBuehler#1375 * Fix case where prefix cacher returns no toks (EricLBuehler#1377) * Fix AFQ UQFF serialization * Faster UQFF serialization (EricLBuehler#1379) * Faster UQFF serialization * Fix uqff gemma3 * Improve gemma3 auto loader names * UQFF creation for AFQ on CPU support (EricLBuehler#1380) * Add afq cpu quantize/dequantize * Clippy * Improved device for afq quantize * Improved dtype handling for cpu afq (de)quantize * Improved generate_uqff_card * Add fused CPU attention kernel! (EricLBuehler#1382) * Working * Fix warnings * Allow mask * Support bf16, f16 * Handle striding * Parallelized * Add initial vector flash attn * Avoid repeated allocations * Tiled kv * Apply some clippy * Some small fixes * Chunked vec_dot * Clipy * Use T::zero * Refactor attention backends (EricLBuehler#1384) * Refactor attention code * Refactor attention code * Move into backends * Set macOS thread affinity for CPU attn (EricLBuehler#1385) * Use lazylock * Format * Fix metal warn build * Faster Qwen 3 MoE support on Metal (EricLBuehler#1387) * Fix load * Use afq gather qmm * Well it runs * It works * Polish * Fast and slow options * Remove quantized.rs * Polish some more * Refactor * Add isq * Update load in parallel * Support fp8 * Refactor for FusedExperts * Clippy * Handle pack factor when loading prequantized models * Use f32 only in moe * Avoid using f32 so much * Avoid using f32 so much * Fix PagedAttention block leaks (EricLBuehler#1388) * Warn and ignore if ignored * Fix a block allocation leak * Update bench.py * Fix double free in block engine * Do not apply ISQ if loading a prequantized model * Fix cuda build again (EricLBuehler#1389) * Fix cuda build * Fix * Format * Fixes for cuda docker * Update dockerfiles * Bump version to 0.6.0 (EricLBuehler#1390) * Bump version to 0.6.0 * Remove lower_level api * Make a static dir * Update deps * Fix routing for static handler in web chat * Fewer .contiguous calls for qwen3 moe (EricLBuehler#1391) * Allow speech models to accept batched inputs (EricLBuehler#1393) * Allow speech models to accept batched inputs * Clippy * Ring distributed backend for heterogeneous TP (EricLBuehler#1238) * Begin work on ring distributed backend for Metal * Add the actual ring functionality * It loads and kind of runs * It works * Optimize buffer allocation * Avoid copy * It works * Add allgather * Fix load * Ping-pong * Small things * Add config json * Allow different ip address * Read config once * Read config when appropriate * Replicate requests * Small fix * Fix small compat with openai * Clippy * Update docs * Add deepseek tool calling chat template * Add auto loader for vision/text detection! 
(EricLBuehler#1402) * Add auto loader for vision/text detection * Build fixes * Add model loader * Update docs * Format * Create Mistral.rs Server Core Lib: `mistralrs-server-core` (EricLBuehler#1346) * First draft of exposing mistral server routes as lib * make arg struct fields pub * Take base path so utoipa swagger route can properly redirect * Expose swagger routes and make it configurable * Add base path option for swagger docs * More work on modularizing mistralrs server * Sync fork (+1 squashed commit) Squashed commits: [169ae9e] Sync fork * Adjust fn params to use refs / individual params instead of args * Start breaking down controller actions into smaller pieces * Continue refactoring * Make mods pub so they can be used outside crate * Allow chat completion streamer to take a callback so that you can get the complete response when finished WIP (+3 squashed commits) Squashed commits: [0061d87] WIP [c484d56] WIP [16f8a60] WIP * Sync fork * Adjust callback type * Remove throughput_log arg that was removed in 26afcc3 * Implement defaults for Args (and use for Clap) * Small code formatting tweaks * Rename callback to match SSE event and code clean up * Sync fork * WIP: first very rough draft of server core builder. Doesn't meet parity with old functional approach yet (slower / unstable?). * Clean up (+4 squashed commits) Squashed commits: [e1cff387] Sync fork [d8301025] WIP debugging [1ea9f8c8] Sync fork [4fe28cf5] WIP: debug function * WIP server core builders * Code clean up * Add on_chunk callback * Code clean up * First draft of creating version of mistral-server that uses server-core Code clean up (+1 squashed commit) Squashed commits: [adea1693] * Sync fork * Add helper methods to builder to make optional args more ergonomic (since .build validates params) * Start adding docs * Start cleaning up crates deps * Example commit of mistral-server with implementing server-core * Start addressing CodeRabbit feedback * Fix comment typo * Tweak doc blocks * - Update type alias naming for clarity (MistralRs instead of Mistral) - CodeRabbit, don't use eprintln for lib (use trace) - Allow buffer size to be passed in and default to Constant - Allow router body limit to be passed in and default to Constant - Update doc examples * Typo * Address CoderRabbitAI feedback * Support linear rope for llama3 (EricLBuehler#1408) * Hotfix for loading * Fix vllama4 uqff loading (EricLBuehler#1409) * Fix vllama4 uqff loading * Fix regex * Fix regex * Maybe a fix * Gracefully handle receiver disconnects (EricLBuehler#1410) * Handle receiver disconnects * Format * Fix Qwen3 MoE device mapping irregularities (EricLBuehler#1411) * Fix bias * Fix lm_head packing case * Account for gate * Fix head dim * Fix interactive mode URL parsing (EricLBuehler#1412) * fix url regex in vision interactive mode * Fix regex * Clippy * Refactor auto device map (EricLBuehler#1413) * Refactor auto device map * Refactor a bit more * Clippy * Enable runtime sampling tweaks in interactive mode (EricLBuehler#1414) * Document runtime sampling commands * Fix readme * Tweak * Bounds checking * Tweak temp bounds * Send streaming tokens every time * Gumbel sampling for fast sampler (EricLBuehler#1416) * Improved handling for initialize_logging * Improved CPU flash attention accuracy & performance (EricLBuehler#1417) * Downcast correctly * Operate internally in f32 * Avoid some casts and striding * Prefetch * Provide chat_templates to container users (EricLBuehler#1419) Models often come without chat templates requiring mapping them from 
the source repository into a container for access by the mistralrs-server. Copy the templates from the build tree into the root of the image to permit use via `--chat-template /chat_templates/something.json` TODO: With the increase in quantized models and support for other formats, the initial benchmark run during model load can be used to qualify/select existing chat templates embedded into the binary for models which do not come with any (to include output of the functional failures in each test allowing users to modify the ones already provided correctly to suit the model being loaded). Co-authored-by: RageLtMan <rageltman [at] sempervictus> * Faster cpu flash attn (EricLBuehler#1418) * Faster cpu flash attn * Prefetch * Clippy * Add some tests * Add softcap tests * Fix test_parse_image_url test * Update tests * Update tests * Web search improvements (bm25, web chat) (EricLBuehler#1420) * Fix web search blocking case * Web search support in web chat * Tweak ui * Support fallback to bm25 * Clippy * Reinject descriptions * Propely handle consecutive searches (EricLBuehler#1421) * Update extraction tool reinjection * Looped * Update docs (EricLBuehler#1422) - lib.rs: clean up example var names and match logging change from EricLBuehler@201d6be - server_builder: fix typo - READMEs: link to crate docs * Better tool call detection logic (EricLBuehler#1424) * Add web search hook callbacks (EricLBuehler#1426) * feat: add customizable search hook * Move to builder * Update docs * Fix CUDA context switching, bind thread on CudaStorage drop (EricLBuehler#1428) * Add CUDA context helper and use in Llama forward * No flashparams? * working * Tweak * Update to use dep * conditionally build flash attention inputs (EricLBuehler#1429) * Add AGENTS.md (EricLBuehler#1430) * Support Qwen3 GGUF model (EricLBuehler#1432) * Support QWen3 GGUF model * Clippy fix * cargo fmt * Improved paged attn prefix caching (EricLBuehler#1434) * Improved paged attn prefix caching * Disable * Clippy * Temporary fix for qwen3 gguf tokenizer (EricLBuehler#1433) * Temporary fix for qwen3 gguf tokenizer * Typo fix * Add tool callback support (EricLBuehler#1427) * Add tool callback support * Fixes * Support named tool callbacks * Update examples * Update docs * Clippy * Centralize crate dependencies (EricLBuehler#1438) * chore: centralize dependencies * Format * Fix bug in tokenizer created with gguf metadata (EricLBuehler#1440) * Fix bug in tokenizer created with gguf metadata * Clippy fix * Update deps (EricLBuehler#1441) * Small things * Update deps * Update deps * Update breaking changes * Doc fixes (EricLBuehler#1442) * Mention uqff_maker * Downgrade rustyline 16.0.0 -> 15.0.0 (EricLBuehler#1444) * Add max_completion_tokens alias for server (EricLBuehler#1451) * Audio input support (Phi 4 multimodal) (EricLBuehler#1448) * Deps * Add conformer * Nemo loading * Position embeds * Load t5 attn bias * Attn and feed forward * Add conv module and glu pointwise * Implement relative attn bias * Add the forward methods * Add encoder embedding * Fix oproj * Some loading * Conformer loads! * Fully loading speech stack * Merger * Dont need that * First pass at audio processing * Read samples * Optional * Small loading fix * Runs but not correct yet * Improved audio processing? * Works with this * Fix t5 attn bias * It works! 
* Comment * Use some other crates * Clippy * Allow bf16 on metal * Add prefix_audio * Remove unused * Typo * User specified * Add audio url parsing * AudioProjectionMode -> InputMode * Audio prefix caching * Fix bug in audio prefix caching * Support both at the same time! * Tweak logging * Support stereo * Add mistralrs-audio * Support batching * Add server and rust api example * Add python api * Fix add_multimodal_message * Fix unfold for conformer * Streaming example * Add web chat support * Add modalities registry * Fix offline cache issue for gguf models (EricLBuehler#1452) * Add MCP server endpoints (EricLBuehler#1453) * feat(server): add MCP server support * Add mcp docs * Add handle_list_tools_request * Better launch, tool handling * Tmp state * Ok works * Handle modalities * Update docs * Add ping * Tweak temperature bounds, args * MCP documentation pass (EricLBuehler#1455) * Fix table * Update mcp docs * Improve readme header * Improve readme header * Integrate an MCP client (EricLBuehler#1456) * Add builtin mcp client * Use async loader * Add headers * Handle sse * More flexible search request * Add tool callbacks with tools, for mcp * Add bearer token support * Add websocket support * Update docs * Add python api * Clippy * Add http api, docs * Tests pass * Make these configs actually work * Add docs * Make mistralrs-mcp * Refactor examples * Update examples * Add defaults * Add defaults * Add defaults * Update docs * Improved docs * Add -y to npx usages * Even better examples * Update generate_wheels * Update generate_wheels * Update generate_wheels * Fix Dockerfile.cuda-all * Improve automatic tool call (EricLBuehler#1460) * Improved auto tool call * Add logging * chore: `Dockerfile.cuda-all` configurable threads (EricLBuehler#1458) * chore: `Dockerfile.cuda-all` - Merge `RUN` for `apt-get install` (EricLBuehler#1459) * Add fallback definition for isnan (EricLBuehler#1463) * chore: `Dockerfile` - Drop runtime rayon thread ENV (EricLBuehler#1465) * chore: Dockerfile - Remove rayon threads env * chore: Dockerfile - Improve formatting for `apt-get` * Remove duplicate calls for api_dir_list (EricLBuehler#1474) * Remove duplicate calls for api_dir_list * Support local cache for api_dir_list * Fix home folder for metal * Capitalized * Fix transient pyo3 dep (EricLBuehler#1478) Co-authored-by: Eric Buehler <[email protected]> * Fix objc dep with non macos (EricLBuehler#1480) * Fix phi 3/4 + nccl issue (EricLBuehler#1481) * Fix log * Fix n kv heads * Fix phi3.5 moe (EricLBuehler#1482) * Fix phi3.5 moe accum device * Fix again * Fix again * Support GLM4 model! 
(EricLBuehler#1437) * Support GLM4 model * Mention GLM4 model in ReadMe * glm4 type hint * Typo fix * Fix unsupported chat_template function * Clippy fix * Refactor distributed backend (EricLBuehler#1484) * Refactor distributed backend, check power of 2 * Fix compilation * Cap metal paged attn kv allocation (EricLBuehler#1485) * Better paged attn metal cap (EricLBuehler#1486) * Better paged attn metal cap * Small fix * Comment * Small fix * Refactor * Server core: consolidate and unify route handlers and API surface (EricLBuehler#1423) * Start working on consolidating completion and chat_completion underlying implementations * Move response channel to util mod for now (since it's used with streaming and non streaming) * More work on consolidating completions and chat completions * More WIP consolidation of server core handlers * More WIP consolidation of server core handlers * More WIP consolidation of server core handlers * Update docs and restrict completion core visibility * CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this * Use consistent var name for completions mod * Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub Make lib.rs example compile checked and update example * Code formatting * Typo * Sync fork * Sync fork * Docs example fix * Support qwen3 gguf (EricLBuehler#1488) * Add qwen3 gguf * Template fixup * Make bos/eos token IDs optional (EricLBuehler#1493) * Remove python deps from CUDA dockerfiles (EricLBuehler#1487) * Handle noncontiguous v in naive_sdpa (EricLBuehler#1499) Co-authored-by: Eric Buehler <[email protected]> * Server Core: refactor Paged Attention configuration (EricLBuehler#1500) * Use StorageModePrivate for Metal PA kv cache (EricLBuehler#1506) * Fix OpenAI stream: emit field in tool-call deltas for schema compliance (EricLBuehler#1507) * FP8 KV-cache quantization for PagedAttention (EricLBuehler#1400) * Add most of paged attn kv quant * It builds a bit * All the functionality at least * Small fix * Add a scale * Fix bf16 usage * Make k_v_scale optional * Collector * Tweak collection * Refactor * Add to apis * Add cuda impl * Fix compilation * Fixes * Handle ENABLE_FP8 * Format * Tweak * Fix scaled_convert usage * Fix cache_t size * Fixed scale collection * Actual fix * Fix fp8 for CC<8 * Fix the usual String != &str bit (EricLBuehler#1483) Co-authored-by: RageLtMan <rageltman [at] sempervictus> * chore: `Dockerfile` - Drop runtime rayon thread ENV (EricLBuehler#1465) * chore: Dockerfile - Remove rayon threads env * chore: Dockerfile - Improve formatting for `apt-get` * Remove duplicate calls for api_dir_list (EricLBuehler#1474) * Remove duplicate calls for api_dir_list * Support local cache for api_dir_list * Fix home folder for metal * Capitalized * Fix transient pyo3 dep (EricLBuehler#1478) Co-authored-by: Eric Buehler <[email protected]> * Fix objc dep with non macos (EricLBuehler#1480) * Fix phi 3/4 + nccl issue (EricLBuehler#1481) * Fix log * Fix n kv heads * Fix phi3.5 moe (EricLBuehler#1482) * Fix phi3.5 moe accum device * Fix again * Fix again * Support GLM4 model! 
(EricLBuehler#1437) * Support GLM4 model * Mention GLM4 model in ReadMe * glm4 type hint * Typo fix * Fix unsupported chat_template function * Clippy fix * Refactor distributed backend (EricLBuehler#1484) * Refactor distributed backend, check power of 2 * Fix compilation * Cap metal paged attn kv allocation (EricLBuehler#1485) * Better paged attn metal cap (EricLBuehler#1486) * Better paged attn metal cap * Small fix * Comment * Small fix * Refactor * Server core: consolidate and unify route handlers and API surface (EricLBuehler#1423) * Start working on consolidating completion and chat_completion underlying implementations * Move response channel to util mod for now (since it's used with streaming and non streaming) * More work on consolidating completions and chat completions * More WIP consolidation of server core handlers * More WIP consolidation of server core handlers * More WIP consolidation of server core handlers * Update docs and restrict completion core visibility * CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this * Use consistent var name for completions mod * Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub Make lib.rs example compile checked and update example * Code formatting * Typo * Sync fork * Sync fork * Docs example fix * Support qwen3 gguf (EricLBuehler#1488) * Add qwen3 gguf * Template fixup * Make bos/eos token IDs optional (EricLBuehler#1493) * Remove python deps from CUDA dockerfiles (EricLBuehler#1487) * Handle USE_FP8 for cuda * Fix cuda warn * Add readme * Saturating sub in sequence state --------- Co-authored-by: Eric Buehler <[email protected]> Co-authored-by: RageLtMan <[email protected]> Co-authored-by: Brennan Kinney <[email protected]> Co-authored-by: Guoqing Bao <[email protected]> Co-authored-by: Matthew Haynes <[email protected]> * Validate model name in OpenAI API (EricLBuehler#1509) * Validate model name in openai api * Add docs, allow 'ignore' * Updated examples for EricLBuehler#1509 * Fix mcp import in doc string (EricLBuehler#1510) * Add multi-model support! (EricLBuehler#1512) * Refactor MistralRs * Working multi-model! * Add mutli-model docs initially * Update mistralrs-pyo3, mistralrs-bench, mistralrs * Update apis for consistency * API tweaks * Logging tweaks * Add examples, tweak cli * Clearer pipeline id * Fix config key semantics * Format and clippy * Tweak logging, fix example * Clippy refactor * Update examples * Remove unused multi model docs * Replace 'ignore' with 'default' * Update docs * Add stars label to readme (EricLBuehler#1513) * Add CLAUDE.md * Handle base_model.model case in lora (EricLBuehler#1514) * Add thread_local! for engine-specific const/static (EricLBuehler#1517) * Fix MCP doc test (EricLBuehler#1511) * Allow disabling metal precompilation (EricLBuehler#1518) * Allow disabling metal precompilation * Simple preprocessor * Simple docs --------- Co-authored-by: Eric Buehler <[email protected]> * Rust 1.88 clippy (EricLBuehler#1522) * Rust 1.88 clippy * Format * Fix cuda warnings (EricLBuehler#1526) * Avoid panic decoding tokens on error (EricLBuehler#1527) * Split Marlin and Paged Attention kernels for faster build (EricLBuehler#1525) * Split Marlin and Paged Attention kernels for faster build * Typo fix * chore: update llguidance (EricLBuehler#1535) * chore: update llguidance * chore: remove unused import * Add the SmolLM3 model! 
(EricLBuehler#1501) * Add model * Update loader * Fix llama config usage * Docs * Fix config no_rope_layers * Fix tie_word_embeddings default * Add chat template * Embed the chat templates * Fix embedding template * enable_thinking default true * Update examples * XML tools for smollm3 * Add smollm3 docs * Fix openai examples * Clippy --------- Co-authored-by: Eric Buehler <[email protected]> * Add full Gemma 3n support! (EricLBuehler#1519) * Add initial * Loading for text model * Add ple embeddings * Add altup, laurel block * Update rmsnorm * Add mlp * Update attn norm application * Currently no kv shared * Wire it up * It runs * Fix bf16 * Fix scaled embd * Fixes for mean * tmp * Attn confirmed * Fix target_magnitude * Add shared kv * Ok it works * Remove npy * Fix streaming * Remove warnings * Remove paged attn * Refactor rope * Add immediate isq * Add vision & mproj * Update image processor * Vision merge runs, not correct * Remove * Add mobilenet v5 * Add multimodal vision embedding * Fix load * runs * Fix gamma * Works but just not vision tower * It works!! * Tweak * Fix warnings * Move vision tower * Fix warn * Update cache manager things * Refactor * Add audio model, it loads * Add audio processing * It runs at least * tmp * A bit better * Audio works!!!! * Fused attn in vision * Clippy * Update audio runner * Optimized audio model * Remove unused things * Fix inputs processor bug * Remove comments * Clippy * Small optimizations * Format * Correctly register modalities * Add docs * Update readme * Runs there * Fixed padding from Blaizzy/mlx-vlm#410 * Add better checks * Fix sdpa n_kv_groups * Vision encoder works! * Rotate image * Clippy * Fix cuda loading * Updated device mapper * Fix overflow * Fix dtype errors * Refactor image/audio embeddings * Fix metal * Fix dtype mismatch * Audio processing fixes * Audio processing fixes * Works * Audio is good * Fix boi/eoi too * Embed the chat templates * Better embedding accuracy in non f32 * More f32 * Support bf16 on metal * Add more ISQ * Fixed device map * Clippy * Gemma3n no paged attn * Fix saturating sub * Faster rmsnorm * Use sdpa for vision model * Fix ple bug * Fix name * Fix multiaudio * Add matformer config loading * Add docs * Add support for matformer in auto device mapper * Update docs * Typos * Tweak * Tweak * Fix multidevice * Fix gemma3n text model auto device map * Fix dims3 * Fix auto devic emap vision * Non-metal keeps PLE on cpu * Complete merge * Vision dtype f16 -> f32 * Fix metal nm device * Fix uqff * Typos * Reference uqff * Fix tests * Fix sequence length check (EricLBuehler#1546) * update candle version (EricLBuehler#1545) Co-authored-by: AlpineVibrations <[email protected]> * add ios target to metal deps (EricLBuehler#1548) --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: Eric Buehler <[email protected]> Co-authored-by: Eric Buehler <[email protected]> Co-authored-by: edwko <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Guoqing Bao <[email protected]> Co-authored-by: Michał Moskal <[email protected]> Co-authored-by: Chen Mulong <[email protected]> Co-authored-by: Steph Wolski <[email protected]> Co-authored-by: omahs <[email protected]> Co-authored-by: Viktor Szépe <[email protected]> Co-authored-by: Matthew Haynes <[email protected]> Co-authored-by: RageLtMan <[email protected]> Co-authored-by: Brennan Kinney <[email protected]> Co-authored-by: Eric Buehler <[email 
protected]> Co-authored-by: Sbargaoui <[email protected]> Co-authored-by: Gaétan Lepage <[email protected]> Co-authored-by: Ammar Elsabe <[email protected]> Co-authored-by: luke <[email protected]> Co-authored-by: AlpineVibrations <[email protected]> Co-authored-by: Michael Tissen <[email protected]>
- Top k
- Top p
- Min p
- Frequency penalty
- Presence penalty
- (?) DRY penalty
- Metal argsort
- CUDA argsort
- CPU

A minimal sketch of how the top-k / top-p / min-p filtering steps fit together is shown below.
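The sketch below is an illustrative CPU-only reference for the filtering pipeline named in the list above, under stated assumptions: top-k and top-p operate on a descending sort of the probabilities (the role the Metal/CUDA argsort kernels would play on GPU), min-p prunes tokens relative to the most probable one, and the final draw uses the Gumbel-max trick mentioned in the commit log. All names here (`softmax`, `sample_filtered`) are hypothetical and not the crate's API; frequency/presence/DRY penalties, which would adjust the logits before this step, are omitted.

```rust
// Hypothetical reference implementation of the filtering steps, not the crate's API.

fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

fn sample_filtered(logits: &[f32], top_k: usize, top_p: f32, min_p: f32) -> usize {
    let probs = softmax(logits);

    // Sort token indices by probability, descending (what the argsort kernels do on GPU).
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_unstable_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());

    // Top-k: keep only the k most probable tokens.
    order.truncate(top_k.max(1));

    // Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    // (a running cumulative sum over the sorted probabilities).
    let mut cumulative = 0.0f32;
    let mut cutoff = order.len();
    for (i, &idx) in order.iter().enumerate() {
        cumulative += probs[idx];
        if cumulative >= top_p {
            cutoff = i + 1;
            break;
        }
    }
    order.truncate(cutoff);

    // Min-p: drop tokens whose probability is below min_p * p_max.
    let p_max = probs[order[0]];
    order.retain(|&idx| probs[idx] >= min_p * p_max);

    // Gumbel-max: adding independent Gumbel(0,1) noise to the log-probabilities and
    // taking the argmax is an exact sample from the filtered categorical distribution.
    let mut state = 0x9E37_79B9_7F4A_7C15u64; // fixed seed, sketch only
    let mut next_uniform = || {
        // xorshift step; fine for an illustration, not a production RNG
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        (state >> 11) as f64 / (1u64 << 53) as f64
    };
    let perturbed: Vec<(usize, f32)> = order
        .iter()
        .map(|&idx| {
            let u = next_uniform().max(1e-12) as f32;
            (idx, probs[idx].ln() - (-u.ln()).ln())
        })
        .collect();
    perturbed
        .into_iter()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(idx, _)| idx)
        .unwrap()
}

fn main() {
    let logits = [2.0f32, 1.0, 0.5, -1.0, -3.0];
    let token = sample_filtered(&logits, 3, 0.9, 0.05);
    println!("sampled token index: {token}");
}
```

The appeal of the Gumbel-max formulation for a fast sampler is that it replaces an explicit CDF build plus multinomial draw with a single elementwise perturbation and an argmax, both of which map cleanly onto GPU kernels; the sorted probabilities and cumulative sum needed for top-p are exactly where backend sort and scan primitives come in.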
Summary by CodeRabbit

- New Features
- Improvements
- Bug Fixes
- Style
- Documentation