Fast sampler #1327

Merged

merged 33 commits into from
May 21, 2025

Conversation

EricLBuehler (Owner) commented May 11, 2025

  • Top k

  • Top p

  • Min p

  • Frequency penalty

  • Presence penalty

  • (?) DRY penalty

  • Metal argsort

    • single block sort
    • multi block sort
  • CUDA argsort

    • Use CUB
  • CPU

Summary by CodeRabbit

  • New Features

    • Added fast, GPU-accelerated sorting and cumulative sum (scan) operations for tensors on Metal devices, enabling high-performance argsort, sort, and cumsum along arbitrary axes.
    • Introduced benchmarking script for load testing language model servers and reporting throughput statistics.
  • Improvements

    • Enhanced quantization support: quantization configuration is now recognized using both "quantization_config" and "quantization" keys across all supported models, and is properly applied to language model head layers.
    • Scheduler and engine now provide real-time logging of running and waiting sequence counts for improved monitoring.
    • Optimized sampling and tensor device handling for improved performance on Metal backends.
    • Expanded and refactored Metal GPU kernel suite to support efficient scan, sort, and copy operations for a wide range of data types.
  • Bug Fixes

    • Improved option handling and deserialization for quantization configuration fields in model configs.
  • Style

    • Code formatting and style improvements in CUDA and Python scripts for better readability and consistency.
  • Documentation

    • Enhanced inline documentation and test coverage for new tensor operations.


coderabbitai bot commented May 11, 2025

Walkthrough

This update introduces extensive improvements to the quantization and Metal backend infrastructure, including new GPU kernels for scan, sort, and copy operations, and exposes new fast tensor-based sorting and cumulative sum operations at the Rust level. The model initialization logic is refactored to propagate quantization configuration to language model head layers, and serde aliasing is added for quantization fields across numerous model configs. Scheduler interfaces are updated to integrate interval logging of running and waiting sequences. Additional code style and refactoring changes improve clarity, consistency, and maintainability throughout the codebase.
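
A minimal sketch of the serde aliasing pattern described above (the struct and the QuantConfig type are illustrative placeholders, not the exact definitions from the model configs in this PR):

use serde::Deserialize;

// Placeholder for the real quantization config type used by the models.
#[derive(Debug, Deserialize)]
struct QuantConfig {
    bits: usize,
}

#[derive(Debug, Deserialize)]
struct ModelConfig {
    // Accept both the "quantization_config" and the shorter "quantization" key
    // when deserializing a model's config.json.
    #[serde(alias = "quantization")]
    quantization_config: Option<QuantConfig>,
    // ... other fields elided
}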

Changes

File(s) Change Summary
Cargo.toml Adds [profile.release] and [profile.dev] sections with LTO and optimization settings.
examples/server/chat.py Changes OpenAI client base URL from port 1234 to 8080.
scripts/bench.py Adds a new async benchmarking script for load testing a language model server.
scripts/convert_awq_marlin.py Code style and formatting improvements; no logic changes.
mistralrs-core/src/engine/logger.rs Adds atomic counters and setter methods for running/waiting sequence logging in IntervalLogger.
mistralrs-core/src/engine/mod.rs Makes IntervalLogger public and updates scheduler call to accept a logger reference.
mistralrs-core/src/dummy_paged_attention/scheduler.rs, mistralrs-core/src/paged_attention/scheduler.rs, mistralrs-core/src/scheduler/default_scheduler.rs, mistralrs-core/src/scheduler/mod.rs Updates all scheduler schedule methods and trait to accept a logger reference and log running/waiting counts.
mistralrs-core/src/kv_cache/rotating_cache.rs, mistralrs-core/src/kv_cache/single_cache.rs Changes all_data methods to return Option<&Tensor> instead of &Option<Tensor>.
mistralrs-core/src/models/* (multiple files) Adds serde alias "quantization" to quantization config fields; propagates quantization config to lm_head layer construction.
mistralrs-core/src/vision_models/* (multiple files) Adds serde alias "quantization" to quantization config fields and passes quantization config to lm_head layer.
mistralrs-core/src/pipeline/mod.rs Removes to_device from ForwardInputsResult and simplifies logits collection logic.
mistralrs-core/src/pipeline/ggml.rs, mistralrs-core/src/pipeline/gguf.rs, mistralrs-core/src/pipeline/normal.rs, mistralrs-core/src/pipeline/vision.rs Refactors chat template path extraction to use local variables for clarity.
mistralrs-core/src/pipeline/macros.rs Updates macros to pass Option<&T> instead of &Option<T> for optional parameters.
mistralrs-core/src/pipeline/paths.rs Changes function signatures to use Option<&T> for optional parameters and updates internal logic accordingly.
mistralrs-core/src/sampler.rs Adds sample_fast method for fast tensor-based sampling; uses it under the metal feature.
mistralrs-quant/build.rs Expands Metal source/header file list and updates build script to handle all headers in compilation steps.
mistralrs-quant/src/lib.rs, mistralrs-quant/src/utils/mod.rs Re-exports new operators: CumSumOp and SortOp.
mistralrs-quant/src/utils/ops.rs Adds fast tensor-based Sort, ArgSort, and CumSum operations with Metal backend support and tests.
mistralrs-quant/src/safetensors.rs Simplifies dtype handling in tensor loading logic.
mistralrs-quant/src/metal_kernels/utils.rs Adds helpers for grid and block dimension calculations for Metal kernels.
mistralrs-quant/src/metal_kernels/utils.metal Adds extensive numeric, indexing, and SIMD utility templates for Metal kernels.
mistralrs-quant/src/metal_kernels/bitwise.metal Adjusts parameter indentation for style consistency.
mistralrs-quant/src/metal_kernels/quantized.metal Removes unused utility templates and optimizes pointer arithmetic in qmv_fast_impl.
mistralrs-quant/src/metal_kernels/bf16.metal New: Adds bfloat16 type support and arithmetic for Metal kernels.
mistralrs-quant/src/metal_kernels/scan.metal, mistralrs-quant/src/metal_kernels/scan_impl.metal New: Adds GPU scan (prefix sum/product/max/min/logaddexp) kernel instantiations and implementations for Metal.
mistralrs-quant/src/metal_kernels/sort.metal, mistralrs-quant/src/metal_kernels/sort_impl.metal New: Adds GPU block and multi-block sorting kernel instantiations and implementations for Metal.
mistralrs-quant/src/metal_kernels/copy.metal, mistralrs-quant/src/metal_kernels/copy_impl.metal New: Adds comprehensive GPU copy kernel instantiations and implementations for all type combinations and dimensionalities.
mistralrs-quant/src/metal_kernels/mod.rs Adds Rust-side dispatch logic for Metal scan and sort kernels, including caching and kernel launch helpers.
mistralrs-quant/kernels/marlin/marlin_kernel.cu Code style and formatting improvements; no logic changes.

Sequence Diagram(s)

sequenceDiagram
    participant Scheduler
    participant Logger
    participant Engine

    Engine->>Scheduler: schedule(&Logger)
    Scheduler->>Logger: set_num_running(running_count)
    Scheduler->>Logger: set_num_waiting(waiting_count)
    Scheduler-->>Engine: SchedulerOutput
sequenceDiagram
    participant Tensor
    participant MetalBackend
    participant User

    User->>Tensor: .fast_sort_asc(axis)
    Tensor->>MetalBackend: launch sort kernel
    MetalBackend-->>Tensor: sorted tensor
    Tensor-->>User: sorted tensor
sequenceDiagram
    participant Tensor
    participant MetalBackend
    participant User

    User->>Tensor: .fast_cumsum(axis)
    Tensor->>MetalBackend: launch scan kernel
    MetalBackend-->>Tensor: cumsum tensor
    Tensor-->>User: cumsum tensor
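
For reference, a minimal usage sketch of the new ops via the re-exported traits (assumes candle_core tensors, a Metal device, and the fast_sort_asc / fast_cumsum method names shown in the diagrams above; the axis argument type is an assumption):

use candle_core::{Device, Tensor};
use mistralrs_quant::{CumSumOp, SortOp};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_metal(0)?;
    let xs = Tensor::new(&[3f32, 1., 4., 1., 5.], &dev)?;
    // Fast tensor-based sort and cumulative sum along axis 0.
    let sorted = xs.fast_sort_asc(0)?;
    let prefix = xs.fast_cumsum(0)?;
    println!("sorted: {sorted}\nprefix sums: {prefix}");
    Ok(())
}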

Poem

In fields of code where kernels grow,
Metal bunnies sort and scan below.
Quantization configs hop with glee,
Through serde aliases, wild and free!
Schedulers now log their queue,
And sampling's faster, thanks to you.
🐇✨—the code hops on, renewed!


@EricLBuehler marked this pull request as draft May 11, 2025 13:45

github-actions bot commented May 11, 2025

Code Metrics Report
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                3           62           53            0            9
 Dockerfile              1           41           22           10            9
 JSON                   12          107          106            0            1
 Makefile                1            6            5            0            1
 Python                 86         4042         3413          156          473
 Shell                   1           63           26           18           19
 Plain Text              3         3723            0         2413         1310
 TOML                   19          565          518            6           41
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       3            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          205          178            1           26
 (Total)                            282          210           32           40
-------------------------------------------------------------------------------
 Markdown               55         5012            0         3822         1190
 |- BASH                 8          104          101            0            3
 |- JSON                 1           12           12            0            0
 |- Python               7          121          109            0           12
 |- Rust                22          757          634            1          122
 |- TOML                 2           75           63            0           12
 (Total)                           6081          919         3823         1339
-------------------------------------------------------------------------------
 Rust                  378       126908       113288         2593        11027
 |- Markdown           171         2145           29         1913          203
 (Total)                         129053       113317         4506        11230
===============================================================================
 Total                 564       140550       117450         9020        14080
===============================================================================

@coderabbitai bot left a comment

Actionable comments posted: 17

🧹 Nitpick comments (11)
mistralrs-server/src/interactive_mode.rs (1)

242-245: Start TTFT timer before send for accuracy (optional)

The timer is initialised after the request is sent.
If queueing on the Tokio channel becomes significant, TTFT will be underestimated.
Consider moving let start_ttft = Instant::now(); immediately before sender.send(req).await?;.

mistralrs-quant/src/metal_kernels/mod.rs (1)

1176-1199: Unused variable _bm and potential overflow in grid calc

let _bm = 32; is never used – remove or employ it.

Additionally, multiplying tmp_grid_dims.{width|height} by stride_blocks can overflow u64
for very large tensors. Use checked_mul and fall back to the alternate dimension on overflow.
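
A hedged sketch of that guard (the MTLSize-style width/height fields and the stride_blocks integer type are assumptions based on the snippet above):

// Scale the grid by stride_blocks without silently wrapping on overflow.
let (width, height) = match tmp_grid_dims.width.checked_mul(stride_blocks as u64) {
    Some(w) => (w, tmp_grid_dims.height),
    // Fall back to scaling the other dimension if the width would overflow.
    None => (
        tmp_grid_dims.width,
        tmp_grid_dims
            .height
            .checked_mul(stride_blocks as u64)
            .expect("grid dimensions overflow u64"),
    ),
};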

mistralrs-core/src/sampler.rs (1)

353-366: Top-K threshold broadcasting may mis-align for batched logits

unsqueeze(0) produces a shape [1, vocab].
For batched inputs [batch, vocab] this works, but for a higher-rank tensor the broadcast is ambiguous.
Use unsqueeze(D::Minus1) (or keep the original dim count via expand) to guarantee alignment with the last dimension irrespective of rank.
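
A small shape-only sketch of the difference (assumes candle_core; the shapes are chosen purely for illustration):

use candle_core::{DType, Device, Tensor, D};

fn main() -> candle_core::Result<()> {
    // Per-row thresholds for logits of shape [batch=2, beams=3, vocab=8].
    let thresholds = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // [1, 2, 3]: lines up with the *leading* axes, not the vocab axis.
    println!("{:?}", thresholds.unsqueeze(0)?.shape());
    // [2, 3, 1]: broadcasts along the last (vocab) dimension regardless of rank.
    println!("{:?}", thresholds.unsqueeze(D::Minus1)?.shape());
    Ok(())
}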

mistralrs-quant/src/utils/ops.rs (2)

1011-1020: Wrong error message and unreachable match arm

The fallback arm returns
Err(Error::UnsupportedDTypeForOp(DType::F32, "cumsum"))
even for many other dtypes, making debugging harder.

Replace with the actual s1.dtype():

-            _ => Err(Error::UnsupportedDTypeForOp(DType::F32, "cumsum")),
+            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),

904-909: Naming & visibility

CumSum is private but the trait CumSumOp relies on users calling fast_cumsum[_config].
Consider making the struct pub(crate) and prefixing internal helpers with _ to clarify intent.

mistralrs-quant/src/metal_kernels/bf16.metal (1)

10-13: Duplicate typedef of bfloat16_t

Lines 10 and 12 both declare typedef bfloat bfloat16_t;.
Remove the second to avoid “redefinition” warnings.

mistralrs-quant/src/metal_kernels/scan.metal (1)

52-64: Macro explosion may hit Metal compiler limits.

The 8×9 explicit instantiate_scan_helper(...) calls generate 70+ specialized kernels per op (sum/prod/…) and per layout.
Several Apple-silicon tool-chains start failing once the number of functions in a single .metal file approaches ~1,000 because the internal IR hits the 64 Ki symbol table limit.
Consider splitting large groups into separate translation units, or gating instantiation with #ifdefs that match the actual runtime usage set, to keep compile time and binary size reasonable.

mistralrs-core/src/prefix_cacher.rs (1)

110-152: Potentially expensive full-vector clone in hot path.

seq.normal_cache().to_vec() and seq.image_hashes().map(|v| v.to_vec()) allocate and copy on every call even when the cache entry for the sequence already exists.
You can avoid the allocation for updates by checking if nb.cache.is_none() before cloning:

let (data, img_hashes) = if nb.cache.is_none() {
    (seq.normal_cache().to_vec(),
     seq.image_hashes().map(|v| v.to_vec()))
} else {
    (Vec::new(), None) // nothing needed – will be overwritten below
};
mistralrs-quant/src/metal_kernels/utils.metal (2)

12-16: Duplicate typedef – remove the extra declaration.

typedef bfloat bfloat16_t; appears twice back-to-back which may trigger a
“redefinition” warning on stricter Metal compilers.

-typedef bfloat bfloat16_t;
-typedef bfloat bfloat16_t;
+typedef bfloat bfloat16_t;  // single definition is enough

1151-1163: Logical functors return T but evaluate to bool.

LogicalAnd / LogicalOr compute x && y / x || y yet return a value of
type T. For non-boolean T (e.g. float, int, half) this relies on
implicit conversion from bool and silently narrows the result to 0 or 1.
Returning bool clarifies semantics and avoids accidental use in arithmetic
code where a full-precision T is expected.

-template <typename T> T operator()(T x, T y) { return x && y; }
+template <typename T> bool operator()(T x, T y) { return x && y; }

Apply similarly to LogicalOr.

mistralrs-quant/src/metal_kernels/scan_impl.metal (1)

121-140: CumLogaddexp needs numerical stabilisation

LogAddExp for large negative inputs can underflow to -∞, breaking the
scan when the tensor contains a wide dynamic range. Consider the classic
stable implementation:

U operator()(U a, T b) {
    U m = max(a, static_cast<U>(b));
    return m + log(exp(a - m) + exp(static_cast<U>(b) - m));
}

If LogAddExp{} already implements this, document it; otherwise replace the
call with a numerically stable version.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0b540ea and 089b617.

⛔ Files ignored due to path filters (1)
  • Cargo.lock is excluded by !**/*.lock
📒 Files selected for processing (16)
  • mistralrs-core/Cargo.toml (1 hunks)
  • mistralrs-core/src/pipeline/mod.rs (1 hunks)
  • mistralrs-core/src/prefix_cacher.rs (2 hunks)
  • mistralrs-core/src/sampler.rs (4 hunks)
  • mistralrs-quant/build.rs (3 hunks)
  • mistralrs-quant/src/lib.rs (1 hunks)
  • mistralrs-quant/src/metal_kernels/bf16.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/mod.rs (2 hunks)
  • mistralrs-quant/src/metal_kernels/quantized.metal (0 hunks)
  • mistralrs-quant/src/metal_kernels/scan.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/scan_impl.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/utils.metal (2 hunks)
  • mistralrs-quant/src/metal_kernels/utils.rs (1 hunks)
  • mistralrs-quant/src/utils/mod.rs (1 hunks)
  • mistralrs-quant/src/utils/ops.rs (2 hunks)
  • mistralrs-server/src/interactive_mode.rs (6 hunks)
💤 Files with no reviewable changes (1)
  • mistralrs-quant/src/metal_kernels/quantized.metal
⏰ Context from checks skipped due to timeout of 90000ms (8)
  • GitHub Check: Test Suite (macOS-latest, stable)
  • GitHub Check: Docs
  • GitHub Check: Test Suite (ubuntu-latest, stable)
  • GitHub Check: Test Suite (windows-latest, stable)
  • GitHub Check: Check (macOS-latest, stable)
  • GitHub Check: Clippy
  • GitHub Check: Check (windows-latest, stable)
  • GitHub Check: Check (ubuntu-latest, stable)
🔇 Additional comments (7)
mistralrs-core/Cargo.toml (1)

92-94: Good additions for the new radix trie implementation.

The three new dependencies added (hashbrown, parking_lot, and ahash) are appropriate choices for implementing the concurrent radix trie with LRU eviction mechanism. These libraries are well-maintained, provide good performance characteristics, and are commonly used together for high-performance concurrent data structures.

mistralrs-quant/src/utils/mod.rs (1)

8-8: Appropriate public export for the new CumSumOp trait.

The CumSumOp trait is correctly added to the list of publicly exposed operations from the module. This change aligns with the module's existing pattern of re-exporting operation traits.

mistralrs-quant/src/lib.rs (1)

59-59: Consistent crate-level re-export of the new operation.

The CumSumOp trait is properly added to the crate-level public exports alongside other operation traits, making it accessible to users of the library without requiring them to import from specific submodules.

mistralrs-core/src/pipeline/mod.rs (1)

423-427:

❓ Verification inconclusive

Simplified logits processing by removing explicit CPU device conversion.

The code has been simplified by removing the .to_device(&Device::Cpu) call and the associated error handling, which may improve performance by avoiding unnecessary device transfers. However, there's an inconsistency since the similar code in the PagedAttention branch (lines 564-570) still performs this conversion.

Please verify that operations downstream don't require the tensors to be on the CPU device. The inconsistency between this branch and the PagedAttention branch suggests there might be specific scenarios where device placement matters.


🏁 Script executed:

#!/bin/bash
# Find where the logits are used in sample_causal_gen to check if they require CPU device
rg -p "fn sample_causal_gen.*?\{\s*.*?\}" --multiline mistralrs-core/src/

Length of output: 75


I’ll inspect the surrounding code to locate both branches and any remaining .to_device(&Device::Cpu) calls.

#!/bin/bash
set -euxo pipefail

# 1. Show the logits processing snippet (around lines 423–427)
sed -n '350,450p' mistralrs-core/src/pipeline/mod.rs

# 2. Show the PagedAttention branch (around lines 564–570)
sed -n '550,650p' mistralrs-core/src/pipeline/mod.rs

# 3. Find all remaining `.to_device(&Device::Cpu)` occurrences
rg -n ".to_device\s*\(\s*&Device::Cpu\s*\)" -n mistralrs-core/src/pipeline/mod.rs

To pinpoint the discrepancy and verify if CPU placement is still used elsewhere, let’s dump both branches and search for any remaining .to_device(&Device::Cpu) calls:

#!/bin/bash
set -euxo pipefail

# 1. Show logits processing around lines 420–430
sed -n '410,440p' mistralrs-core/src/pipeline/mod.rs

# 2. Show PagedAttention branch around lines 560–580
sed -n '550,590p' mistralrs-core/src/pipeline/mod.rs

# 3. List all occurrences of to_device(&Device::Cpu)
rg -n "to_device\s*\(\s*&Device::Cpu\s*\)" mistralrs-core/src/pipeline/mod.rs || true

🏁 Script executed:

#!/usr/bin/env bash
set -euxo pipefail

# 1. Show logits processing around lines 420–430
sed -n '410,440p' mistralrs-core/src/pipeline/mod.rs

# 2. Show PagedAttention branch around lines 560–580
sed -n '550,590p' mistralrs-core/src/pipeline/mod.rs

# 3. List all occurrences of to_device(&Device::Cpu)
rg -n "to_device\s*\(\s*&Device::Cpu\s*\)" mistralrs-core/src/pipeline/mod.rs || true

Length of output: 3679


Inconsistent .to_device(&Device::Cpu) usage in logits processing

The normal CausalGeneration path (around lines 423–427) now drops the .to_device(&Device::Cpu) call, but the PagedAttention branch (line 568) still performs it—and there are other conversions at lines 384 and 540. Please verify whether downstream routines (e.g. sample_causal_gen) require CPU-resident tensors. Then unify both branches:

• Remove .to_device(&Device::Cpu) in the PagedAttention branch (and at lines 384/540) if GPU tensors are acceptable
• Otherwise, re-introduce the CPU conversion in the simplified branch

Files/locations to review:

  • mistralrs-core/src/pipeline/mod.rs: lines 423–427
  • mistralrs-core/src/pipeline/mod.rs: line 568
  • mistralrs-core/src/pipeline/mod.rs: lines 384, 540
mistralrs-quant/build.rs (1)

170-174: Compiling header files may duplicate symbols

Header sources are now passed to the Metal compiler as standalone translation units:

for src in HEADER_SOURCES {
    println!("cargo:rerun-if-changed=src/metal_kernels/{src}.metal");
}

and later:

for metal_file in HEADER_SOURCES {
    compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
}

If these headers define templated kernels or helper functions without static linkage
they may be emitted twice, producing duplicate-symbol errors at link time.

Consider adding them with -I/-include instead of compiling them directly.
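
A hedged sketch of that alternative, reusing the command builder quoted above (KERNEL_SOURCES is a hypothetical list of the non-header kernel files; it also assumes the Metal compiler accepts a clang-style -I include path):

// Expose the headers via an include path instead of compiling them directly.
compile_air_cmd.arg("-I").arg(&sources);
// KERNEL_SOURCES: hypothetical list of the non-header .metal kernel files.
for metal_file in KERNEL_SOURCES {
    compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
}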

mistralrs-quant/src/metal_kernels/bf16.metal (1)

47-56: Heavy use of C++17 type-traits may not compile under Metal

Metal shading language roughly aligns with C++14; helpers such as is_convertible_v, is_same_v and variable templates are C++17.
Unless you compile with -std=c++17-metal (not universally available), this section will fail.

If wide portability is desired, fall back to std::is_same<…, bool>::value + enable_if_t.

mistralrs-core/src/prefix_cacher.rs (1)

155-182: evict_caches can dereference a None head – guard early.

If current_on_device > n_on_device gets out of sync (e.g. manual current_on_device mutation or future parallelisation), lru_head might be None, leading to an infinite loop (while condition never falsifies and head is None).
Add an assert or early-return before the loop:

if self.lru_head.is_none() {
    debug_assert_eq!(self.current_on_device, 0, "Inconsistent LRU list");
    break;
}

Comment on lines 303 to 305
if let Some(ttft) = first_token_duration {
    println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
}

⚠️ Potential issue

Invalid format string – code will not compile

println!("Time to first token: {:.2?}s", ttft.as_secs_f32());

{:?} is the Debug formatter, which cannot be combined with a precision specifier (.2).
The compiler will emit:

format argument must not be followed by `?`

Replace the debug specifier with the normal Display formatter:

-println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
+println!("Time to first token: {:.2}s", ttft.as_secs_f32());

Apply the same change in the vision branch (lines 550-552).

Comment on lines +1151 to +1165
let size = shape[axis];
let stride = strides[size];
let _bm = 32;
let bn = 32;
let stride_blocks = stride.div_ceil(bn);

encoder.set_buffer(0, Some(xs), xs_offset as u64);
encoder.set_buffer(1, Some(output), 0);

encoder.set_bytes(
    2,
    std::mem::size_of::<usize>() as u64,
    &size as *const usize as *const _,
);
encoder.set_bytes(

⚠️ Potential issue

Array-index bug: using size instead of axis

let stride = strides[size];

size is the length of the scan dimension, not its index.
Accessing strides[size] will:

  1. Panic with index out of bounds when size ≥ strides.len().
  2. Read an unrelated stride when the scan dimension length happens to be a valid index.

Fix:

-        let stride = strides[size];
+        let stride = strides[axis];

Add a unit test with a non-trivial stride to catch this class of error.
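
A hedged sketch of such a test (assumes the CumSumOp trait from mistralrs-quant, a Metal device, and that fast_cumsum takes the axis index; scanning axis 0 of a (2, 3) tensor gives the scan axis a stride of 3, which would trip the strides[size] lookup):

#[test]
fn cumsum_non_innermost_axis_metal() -> candle_core::Result<()> {
    use candle_core::{Device, Tensor};
    use mistralrs_quant::CumSumOp;

    let dev = Device::new_metal(0)?;
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // Cumulative sum down the columns; expected values worked out by hand.
    let got = xs.fast_cumsum(0)?.to_vec2::<f32>()?;
    assert_eq!(got, vec![vec![1., 2., 3.], vec![5., 7., 9.]]);
    Ok(())
}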

Comment on lines +59 to +67
if divisor > 1 {
    if grid_x % divisor == 0 {
        grid_x /= divisor;
        divisor = 1;
    } else if grid_y % divisor == 0 {
        grid_y /= divisor;
        divisor = 1;
    }
}

🛠️ Refactor suggestion

Remaining divisor not validated

After the loop divisor may still be > 1, meaning the requested factorisation failed
silently. This produces wrong grid sizes and hard-to-trace GPU crashes.

     if divisor > 1 {
         if grid_x % divisor == 0 {
             grid_x /= divisor;
             divisor = 1;
         } else if grid_y % divisor == 0 {
             grid_y /= divisor;
             divisor = 1;
         }
     }
+
+    // Ensure the divisor was fully factored out
+    assert_eq!(divisor, 1, "Unable to factor divisor from shape");

Comment on lines +1080 to +1106
}

let type_name = match ty {
    DType::F32 => "float32",
    DType::BF16 => "bfloat16",
    DType::F16 => "float16",
    DType::U8 => "uint8",
    DType::I16 => "int16",
    DType::I32 => "int32",
    DType::I64 => "int64",
    other => {
        return Err(MetalKernelError::DTypeMismatch {
            expected: vec![
                DType::F32,
                DType::F16,
                DType::BF16,
                DType::U8,
                DType::I16,
                DType::I32,
                DType::I64,
            ],
            got: other,
        })
    }
};
name.push_str(&format!("{type_name}_{type_name}"));


🛠️ Refactor suggestion

LogAddExp on integer dtypes is undefined

ScanType::LogAddExp is appended without restricting to floating-point types.
Metal kernels will likely perform illegal operations on u8/i16/i32/i64.

Guard the match:

if matches!(op, ScanType::LogAddExp) && !ty.is_float() {
    return Err(MetalKernelError::DTypeMismatch {});
}

or restrict expected accordingly.

Comment on lines +53 to +57
static constexpr constant U init = static_cast<U>(1.0f);

template <typename T> U operator()(U a, T b) { return a * b; }

U simd_scan_impl(U x) { return simd_prefix_inclusive_product(x); }

🛠️ Refactor suggestion

CumProd::init uses a float literal – breaks for integral/half types

static_cast<U>(1.0f) is safe only when U is a floating-point type.
For int, uint, char, half, etc., this introduces an implicit
float-to-int conversion which is (a) UB for bool, (b) a silent narrowing
conversion for integers, and (c) breaks constexpr evaluation on half types
prior to macOS 14.

-  static constexpr constant U init = static_cast<U>(1.0f);
+  static constexpr constant U init = static_cast<U>(1);

Comment on lines +83 to +88
template <typename U> struct CumMax {
  static constexpr constant U init = Limits<U>::min;

  template <typename T> U operator()(U a, T b) { return (a >= b) ? a : b; }

  U simd_scan(U x) {

⚠️ Potential issue

Initial value for CumMax is wrong for floating point – should be lowest() not min()

Limits<U>::min maps to the smallest positive value for floats
(FLT_MIN), not the most negative. Using it makes every max scan return
positive numbers even when the input contains negatives.

-  static constexpr constant U init = Limits<U>::min;
+  static constexpr constant U init = Limits<U>::lowest;   // requires Limits helper, or use numeric_limits<U>::lowest()

Do the same for CumMin with Limits<U>::max → Limits<U>::highest.

Comment on lines +245 to +253
if ((offset + N_READS) < axis_size) {
  load_unsafe<T, U, N_READS, reverse>(values,
                                      in + axis_size - offset - N_READS);
} else {
  load_safe<T, U, N_READS, reverse>(values,
                                    in + axis_size - offset - N_READS,
                                    offset, axis_size, Op::init);
}
} else {

⚠️ Potential issue

Possible pointer underflow for reverse loads

When offset >= axis_size, the expression
in + axis_size - offset - N_READS wraps around (unsigned underflow) and
produces a huge positive address → undefined behaviour / GPU fault.
The guard if ((offset + N_READS) < axis_size) is not sufficient once
offset == axis_size - N_READS because the else branch still applies the
same arithmetic.

Fix by clamping before pointer math:

-        load_safe<T, U, N_READS, reverse>(values,
-                                          in + axis_size - offset - N_READS,
+        size_t base = (offset + N_READS > axis_size)
+                          ? 0
+                          : axis_size - offset - N_READS;
+        load_safe<T, U, N_READS, reverse>(values,
+                                          in + base,
                                           offset, axis_size, Op::init);

Comment on lines +225 to +226
threadgroup U simdgroup_sums[32];


⚠️ Potential issue

simdgroup_sums[32] can overflow when lsize.x > 1024

simd_groups = lsize.x / 32. Metal allows up to 1024 threads per
thread-group on recent GPUs, but future devices (and Apple Silicon’s
M-series in compute mode) may expose 1536.
If lsize.x == 1536, simd_groups == 48 → out-of-bounds write.

-  threadgroup U simdgroup_sums[32];
+  threadgroup U simdgroup_sums[simd_groups];

(or threadgroup_array<U, MAX_SIMD_GROUPS> with a static assert).


Comment on lines +358 to +360
constexpr int BN = 32;
constexpr int BN_pad = 32 + 16 / sizeof(U);
constexpr int n_simds = BN / N_READS;

🛠️ Refactor suggestion

BN_pad mis-aligns shared memory for 64-bit types

BN_pad = 32 + 16 / sizeof(U) yields 34 for 64-bit data, which breaks
128-bit alignment required for double/ulong vectorised accesses and
causes bank conflicts on A-series GPUs.

-  constexpr int BN_pad = 32 + 16 / sizeof(U);
+  constexpr int BN_pad = 32 + (16 / sizeof(U));        // keep, *then*
+  static_assert((BN_pad % (16 / sizeof(U))) == 0, "BN_pad must align to 128-bit");

Alternatively compute round_up(BN, 16/sizeof(U)).


EricLBuehler and others added 7 commits May 21, 2025 12:38
@EricLBuehler marked this pull request as ready for review May 21, 2025 16:54
@coderabbitai bot left a comment

Actionable comments posted: 10

♻️ Duplicate comments (13)
mistralrs-core/src/sampler.rs (3)

417-418: ⚠️ Potential issue

sample_fast drops probabilistic behaviour

The implementation always takes argmax of the filtered distribution; no weighted/multinomial sampling is performed. This makes the sampling deterministic even with non-zero temperature, which is inconsistent with the standard sampling behavior where temperature controls the randomness.

Consider reinstating true multinomial sampling on the masked distribution:

-        let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
+        // Sample from the distribution if temperature is set, otherwise use argmax
+        let next_token = if self.temperature.is_some() {
+            // Convert to CPU for sampling with WeightedIndex
+            let probs_vec: Vec<f32> = probs.to_vec1()?;
+            let distr = WeightedIndex::new(&probs_vec)?;
+            let mut rng_lock = rng.lock().expect("could not lock rng mutex");
+            distr.sample(&mut *rng_lock) as u32
+        } else {
+            probs.argmax(D::Minus1)?.to_scalar::<u32>()?
+        };

464-468: ⚠️ Potential issue

Incorrect log probability calculation

The implementation has two issues with log probabilities:

  1. When return_logprobs is false, it hardcodes logprob to 1.0 (line 467), which gives log(prob) = 0, breaking accuracy.
  2. When return_logprobs is true, it doesn't compute the actual log probability of the selected token.

This will cause incorrect log probability values to be returned, affecting applications that rely on accurate probability tracking.

Apply this fix to correctly compute the log probability of the selected token:

-            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
+            // Find the actual logprob of the selected token
+            let logprob = result
+                .iter()
+                .find(|res| res.token == next_token)
+                .map(|res| res.logprob)
+                .unwrap_or_else(|| {
+                    // If token not in top-k, compute its logprob directly
+                    probs.i(next_token as i64)?.to_scalar::<f32>()?.log10()
+                });

-        } else {
-            (None, 1.)
+        } else {
+            // Always compute actual logprob even when not returning top logprobs
+            let prob = probs.i(next_token as i64)?.to_scalar::<f32>()?;
+            (None, prob.log10())

389-390: 🛠️ Refactor suggestion

Missing probability re-normalization after masking

The implementation masks probabilities below thresholds to zero but doesn't re-normalize the resulting distribution. This is fine for argmax but would be incorrect for multinomial sampling, as sampling requires a properly normalized probability distribution.

Add re-normalization after each masking operation:

        probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
+        // Re-normalize probabilities to sum to 1
+        let probs_sum = probs.sum_all(true)?;
+        probs = (probs / probs_sum)?;

        // Similar changes needed after Top-P and Min-P masking

Also applies to: 406-407, 414-415

mistralrs-quant/src/metal_kernels/utils.rs (1)

59-67: ⚠️ Potential issue

Remaining divisor not validated

After the loop divisor may still be > 1, meaning the requested factorization failed silently. This produces wrong grid sizes and hard-to-trace GPU crashes.

     if divisor > 1 {
         if grid_x % divisor == 0 {
             grid_x /= divisor;
             divisor = 1;
         } else if grid_y % divisor == 0 {
             grid_y /= divisor;
             divisor = 1;
         }
     }
+
+    // Ensure the divisor was fully factored out
+    assert_eq!(divisor, 1, "Unable to factor divisor from shape");
mistralrs-quant/src/metal_kernels/mod.rs (2)

1113-1120: LogAddExp accepts integer dtypes – undefined behaviour

ScanType::LogAddExp is currently allowed for every DType in the match below.
The Metal implementation relies on floating-point math; passing u8/i32/… will generate incorrect results or trigger invalid-op exceptions on GPU.

Guard the call before building the kernel name:

if matches!(op, ScanType::LogAddExp) && !ty.is_float() {
    return Err(MetalKernelError::DTypeMismatch {
        expected: vec![DType::F32, DType::F16, DType::BF16],
        got: ty,
    });
}

(or restrict the type_name match arm to float dtypes only).

Also applies to: 1121-1142


1189-1194: ⚠️ Potential issue

Stride-lookup still indexes with length instead of axis

stride = strides[size]; repeats the off-by-one/ OOB bug previously raised – size is the length of the scan-axis, not its position.
This will (a) read the wrong stride for most tensors and (b) panic when size ≥ strides.len().

-        let stride = strides[size];
+        let stride = strides[axis];

Please add a regression test with a non-unit stride to prevent future re-introductions.

mistralrs-quant/src/utils/ops.rs (1)

1322-1336: ⚠️ Potential issue

CPU CumSum still assumes the scan axis is the innermost – same bug as flagged earlier

The inner loops step through the buffer with
let base = block * axis_len; and then index input[base + j].
This is only correct when the axis has a stride of 1 (i.e. it is the last / innermost dimension).
For any other contiguous tensor (e.g. shape (4, 5) and axis = 0), elements that belong to different columns are interleaved in memory, so the current algorithm mixes rows and produces wrong results.

The earlier review already highlighted this; the implementation has not changed.

Suggested fixes:

  1. Compute the true stride of the axis from l1.stride()[axis] and use it when walking the tensor, or
  2. Call .contiguous() / .permute() to move the axis to the last position before performing the scan.

Please add a regression test on a 2-D tensor with axis = 0 and axis = 1 to verify correctness.
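
A hedged sketch of such a regression test for the CPU path (assumes the CumSumOp trait and that fast_cumsum takes the axis index; expected values computed by hand):

#[test]
fn cpu_cumsum_both_axes() -> candle_core::Result<()> {
    use candle_core::{Device, Tensor};
    use mistralrs_quant::CumSumOp;

    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // Axis 1 (innermost): each row is scanned independently.
    assert_eq!(
        xs.fast_cumsum(1)?.to_vec2::<f32>()?,
        vec![vec![1., 3., 6.], vec![4., 9., 15.]]
    );
    // Axis 0 (stride != 1): columns must not be interleaved.
    assert_eq!(
        xs.fast_cumsum(0)?.to_vec2::<f32>()?,
        vec![vec![1., 2., 3.], vec![5., 7., 9.]]
    );
    Ok(())
}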

mistralrs-quant/src/metal_kernels/scan_impl.metal (5)

53-55: ⚠️ Potential issue

CumProd::init uses a float literal – breaks for integral/half types

static_cast<U>(1.0f) converts through float, which is a narrowing / UB for some integer and half types and prevents constexpr evaluation.
Please use an integer literal instead:

-  static constexpr constant U init = static_cast<U>(1.0f);
+  static constexpr constant U init = static_cast<U>(1);

84-86: ⚠️ Potential issue

Initial value for CumMax / CumMin is wrong for floats

Limits<U>::min and Limits<U>::max return the smallest positive and largest positive values for floating-point types.
Use lowest() / highest() (or numeric_limits<U>::lowest()) to obtain the true extrema; otherwise negative inputs are mishandled.


245-253: ⚠️ Potential issue

Pointer underflow when reverse == true and offset ≥ axis_size

in + axis_size - offset - N_READS is evaluated before the bounds check in the else branch.
When offset ≥ axis_size this wraps around and produces an invalid device address → undefined behaviour / potential GPU fault.

Refer to the previous review’s suggested fix that computes a clamped base before doing pointer arithmetic.


415-417: ⚠️ Potential issue

Incorrect memory flag in simdgroup_barrier

simdgroup_barrier only accepts mem_flags::mem_none; passing mem_threadgroup is rejected by the MSL compiler on macOS 14 and causes driver warnings on iOS 17.

-    simdgroup_barrier(mem_flags::mem_threadgroup);
+    simdgroup_barrier(mem_flags::mem_none);

223-226: ⚠️ Potential issue

Possible overflow: simdgroup_sums[32] is too small for large thread-groups

simd_groups = lsize.x / 32, yet the scratch buffer is hard-coded to 32 entries.
If a future GPU (or a debug build) launches more than 1024 threads per thread-group, the array is written out of bounds.

-threadgroup U simdgroup_sums[32];
+threadgroup U simdgroup_sums[simd_groups];
+static_assert(simd_groups <= 32, "Unexpectedly large simd_group count");

(or allocate MAX_SIMD_GROUPS with an assertion).

mistralrs-quant/src/metal_kernels/utils.metal (1)

988-994: Bitwise '&' used instead of logical '&&' in template constraints.

enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> uses & which happens to work but is semantically incorrect; && expresses intent and avoids subtle type-promotion surprises.

-  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
+  metal::enable_if_t<metal::is_integral_v<T> && !metal::is_signed_v<T>, T>
-  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
+  metal::enable_if_t<metal::is_integral_v<T> && metal::is_signed_v<T>, T>
🧹 Nitpick comments (10)
scripts/bench.py (4)

1-4: Improve script documentation with usage instructions

These commented commands provide execution examples but lack context. Consider converting them to proper docstring documentation with explanations of what each command does.

-# cargo run --release --features metal '--' --port 1234 --isq 8 --paged-attn --max-seqs 1000 plain -m ../hf_models/llama3.2_3b --max-seq-len 131072
-# cargo run --release --features metal '--' --port 1234 --paged-attn --max-seqs 1000 plain -m mlx-community/Mistral-7B-Instruct-v0.3-4bit --max-seq-len 131072
-# ./llama-server -m ../gguf_models/Llama-3.2-3B-Instruct-Q8_0.gguf
-# mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit --port 8080
+"""
+Benchmark script for asynchronous load testing against a local language model server.
+
+Example server startup commands:
+
+# mistral.rs with metal backend
+cargo run --release --features metal '--' --port 1234 --isq 8 --paged-attn --max-seqs 1000 plain -m ../hf_models/llama3.2_3b --max-seq-len 131072
+cargo run --release --features metal '--' --port 1234 --paged-attn --max-seqs 1000 plain -m mlx-community/Mistral-7B-Instruct-v0.3-4bit --max-seq-len 131072
+
+# llama.cpp server
+./llama-server -m ../gguf_models/Llama-3.2-3B-Instruct-Q8_0.gguf
+
+# MLX server
+mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit --port 8080
+"""

13-15: Make benchmark parameters configurable

Constants are hardcoded, limiting script flexibility. Consider making these configurable via command-line arguments.

-NUM_USERS = 8
-REQUESTS_PER_USER = 8
-PORT = 1234
+import argparse
+
+# Default configuration
+DEFAULT_NUM_USERS = 8
+DEFAULT_REQUESTS_PER_USER = 8
+DEFAULT_PORT = 1234
+
+# Parse command line arguments
+def parse_args():
+    parser = argparse.ArgumentParser(description="Benchmark a local LLM server with concurrent requests")
+    parser.add_argument("--users", type=int, default=DEFAULT_NUM_USERS, 
+                        help=f"Number of concurrent users (default: {DEFAULT_NUM_USERS})")
+    parser.add_argument("--requests", type=int, default=DEFAULT_REQUESTS_PER_USER,
+                        help=f"Number of requests per user (default: {DEFAULT_REQUESTS_PER_USER})")
+    parser.add_argument("--port", type=int, default=DEFAULT_PORT,
+                        help=f"Server port (default: {DEFAULT_PORT})")
+    return parser.parse_args()

42-43: Make OpenAI client configuration more flexible

The client is hardcoded with fixed values. Consider making the base URL configurable.

-# Use the async-capable client
-client = AsyncOpenAI(api_key="foobar", base_url=f"http://localhost:{PORT}/v1/")
+def create_client(port):
+    """Create an async OpenAI client configured for the local server."""
+    return AsyncOpenAI(
+        api_key="foobar",  # Dummy API key for local server
+        base_url=f"http://localhost:{port}/v1/"
+    )

81-102: Enhance main function with more detailed metrics

The current implementation provides basic metrics. Consider adding more detailed statistics like percentiles and throughput over time.

 async def main() -> None:
     """
     Computes and prints overall average request time, total requests, and average T/s.
     """
-    system_prompt = None  # "You are a helpful assistant."
-    user_message = "Say hello!"
+    args = parse_args()
+    system_prompt = None  # "You are a helpful assistant."
+    user_message = "Say hello!"
+    client = create_client(args.port)
 
-    tasks = [user_task(client, system_prompt, user_message) for _ in range(NUM_USERS)]
+    print(f"Starting benchmark with {args.users} users, {args.requests} requests per user...")
+    start_time = time.perf_counter()
+    
+    tasks = [user_task(client, system_prompt, user_message) for _ in range(args.users)]
     all_results_nested = await asyncio.gather(*tasks)
     all_results = [item for sublist in all_results_nested for item in sublist]
+    
+    total_time = time.perf_counter() - start_time
 
     total_requests = len(all_results)
-    total_time = sum(elapsed for _, elapsed, _ in all_results)
+    request_times = [elapsed for _, elapsed, _ in all_results if _ is not None]
+    successful_requests = len(request_times)
+    avg_time = sum(request_times) / successful_requests if successful_requests else 0.0
     total_tokens = sum(tokens for _, _, tokens in all_results)
-    avg_time = total_time / total_requests if total_requests else 0.0
-    avg_tps = total_tokens / total_time if total_time > 0 else 0.0
+    overall_tps = total_tokens / total_time if total_time > 0 else 0.0
+    
+    # Calculate percentiles if we have results
+    if request_times:
+        request_times.sort()
+        p50 = request_times[len(request_times) // 2]
+        p90 = request_times[int(len(request_times) * 0.9)]
+        p99 = request_times[int(len(request_times) * 0.99)]
+    else:
+        p50 = p90 = p99 = 0
 
     print(f"Total requests: {total_requests}")
+    print(f"Successful requests: {successful_requests}")
+    print(f"Success rate: {successful_requests/total_requests*100:.2f}%")
     print(f"Average request time: {avg_time:.2f}s")
+    print(f"Percentiles: p50={p50:.2f}s, p90={p90:.2f}s, p99={p99:.2f}s")
     print(f"Total tokens: {total_tokens}")
-    print(f"Average tokens per second (T/s): {avg_tps:.2f}")
+    print(f"Overall tokens per second: {overall_tps:.2f} T/s")
+    print(f"Total benchmark time: {total_time:.2f}s")
mistralrs-quant/src/metal_kernels/copy.metal (1)

5-17: Macro naming explosion – consider normalising generated kernel identifiers

The two macro layers concatenate tname twice (instantiate_copy_same(itname ##itname, …) and instantiate_copy_all(itname ##… )).
For several invocations this produces identifiers such as gg1_copyfloat16float16, g1_copybool_bool_, etc.

These very long (and sometimes duplicated) names are legal, but:

  • They become cumbersome to read in build logs and GPU profiling tools.
  • They increase the probability of hitting the 1024-byte symbol limit on some platforms.
  • Future greps / nm searches become noisier.

If keeping the verbose name is not a hard requirement, you could spare a few characters:

- instantiate_copy_same(itname ##itname, itype)
+ instantiate_copy_same(itname, itype)

and likewise drop the double prefix in instantiate_copy_all.
No functional change – purely a QoL / maintainability tweak.

Also applies to: 37-51

mistralrs-quant/src/metal_kernels/quantized.metal (1)

1614-1623: Minor: remove redundant thread qualifier

Declaring wl_ptrs/… as thread const device uint8_t * is legal, but thread is implicit for local variables; omitting it shortens the type without changing semantics:

-thread const device uint8_t *wl_ptrs[results_per_simdgroup];
+const device uint8_t *wl_ptrs[results_per_simdgroup];

Same for sl_ptrs, bl_ptrs.

mistralrs-core/src/pipeline/paths.rs (1)

511-544: Avoid cloning large fallback strings unnecessarily

chat_template_fallback.cloned() creates an owned String even though the result is only used for a read.
Borrowing is sufficient and avoids an allocation:

-match chat_template_fallback.cloned() {
-    Some(t) => { /* uses t */ }
+match chat_template_fallback {
+    Some(t) => { /* uses *t */ }

A tiny optimisation, but worth it in hot-start scenarios where templates are large.

mistralrs-quant/src/metal_kernels/scan.metal (1)

65-77: Inconsistent kernel naming prod_bool__bool_

The generated host-name prod_bool__bool_ contains a double underscore and a trailing underscore, unlike the other instantiations (e.g. prod_uint8_uint8).
If the Rust side ever tries to build the kernel name programmatically (prod_bool_bool), it will miss this variant.

Confirm that the extra underscores are intentional; otherwise rename to prod_bool_bool.

mistralrs-quant/src/metal_kernels/sort_impl.metal (1)

26-29: Limits<T>::max as sentinel breaks descending inputs for floating-point types

The sentinel init = Limits<T>::max is FLT_MAX/DBL_MAX, not +∞.
For LessThan this is acceptable, but if the comparison order is ever flipped (e.g. descending sort), the sentinel will incorrectly dominate the result set.

Consider using numeric_limits<T>::infinity() or specialising for float types.

mistralrs-quant/src/metal_kernels/copy_impl.metal (1)

211-214: Pointer arithmetic on device memory defeats alias analysis

src += src_offset; dst += dst_offset;

Modifying the raw pointer hides the original base address from the compiler, restricting optimisation opportunities and sometimes violating Metal’s “do not form derived pointers” guidance.

Consider keeping the bases immutable and offsetting in the index math instead:

- src += src_offset;
- dst += dst_offset;
- dst[idx.y] = src[idx.x];
+ dst[dst_offset + idx.y] = src[src_offset + idx.x];

This is also consistent with the fixed-stride variants above.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 089b617 and dd59c70.

📒 Files selected for processing (65)
  • Cargo.toml (1 hunks)
  • examples/server/chat.py (1 hunks)
  • mistralrs-core/src/dummy_paged_attention/scheduler.rs (5 hunks)
  • mistralrs-core/src/engine/logger.rs (5 hunks)
  • mistralrs-core/src/engine/mod.rs (2 hunks)
  • mistralrs-core/src/kv_cache/rotating_cache.rs (1 hunks)
  • mistralrs-core/src/kv_cache/single_cache.rs (1 hunks)
  • mistralrs-core/src/models/deepseek2.rs (2 hunks)
  • mistralrs-core/src/models/deepseek3.rs (2 hunks)
  • mistralrs-core/src/models/gemma.rs (1 hunks)
  • mistralrs-core/src/models/gemma2.rs (1 hunks)
  • mistralrs-core/src/models/llama.rs (2 hunks)
  • mistralrs-core/src/models/mistral.rs (2 hunks)
  • mistralrs-core/src/models/mixtral.rs (2 hunks)
  • mistralrs-core/src/models/phi2.rs (2 hunks)
  • mistralrs-core/src/models/phi3.rs (1 hunks)
  • mistralrs-core/src/models/phi3_5_moe.rs (2 hunks)
  • mistralrs-core/src/models/qwen2.rs (2 hunks)
  • mistralrs-core/src/models/qwen3.rs (2 hunks)
  • mistralrs-core/src/models/qwen3_moe.rs (2 hunks)
  • mistralrs-core/src/models/starcoder2.rs (1 hunks)
  • mistralrs-core/src/paged_attention/scheduler.rs (5 hunks)
  • mistralrs-core/src/pipeline/ggml.rs (1 hunks)
  • mistralrs-core/src/pipeline/gguf.rs (1 hunks)
  • mistralrs-core/src/pipeline/macros.rs (2 hunks)
  • mistralrs-core/src/pipeline/mod.rs (2 hunks)
  • mistralrs-core/src/pipeline/normal.rs (1 hunks)
  • mistralrs-core/src/pipeline/paths.rs (5 hunks)
  • mistralrs-core/src/pipeline/vision.rs (1 hunks)
  • mistralrs-core/src/sampler.rs (4 hunks)
  • mistralrs-core/src/scheduler/default_scheduler.rs (7 hunks)
  • mistralrs-core/src/scheduler/mod.rs (2 hunks)
  • mistralrs-core/src/vision_models/gemma3/config.rs (1 hunks)
  • mistralrs-core/src/vision_models/gemma3/text.rs (1 hunks)
  • mistralrs-core/src/vision_models/llama4/config.rs (1 hunks)
  • mistralrs-core/src/vision_models/llama4/text.rs (1 hunks)
  • mistralrs-core/src/vision_models/mllama/config.rs (1 hunks)
  • mistralrs-core/src/vision_models/mllama/text.rs (1 hunks)
  • mistralrs-core/src/vision_models/phi3/mod.rs (2 hunks)
  • mistralrs-core/src/vision_models/phi4/config.rs (1 hunks)
  • mistralrs-core/src/vision_models/phi4/mod.rs (1 hunks)
  • mistralrs-core/src/vision_models/qwen2_5_vl/config.rs (1 hunks)
  • mistralrs-core/src/vision_models/qwen2_5_vl/text.rs (1 hunks)
  • mistralrs-core/src/vision_models/qwen2vl/config.rs (1 hunks)
  • mistralrs-core/src/vision_models/qwen2vl/text.rs (1 hunks)
  • mistralrs-quant/build.rs (3 hunks)
  • mistralrs-quant/kernels/marlin/marlin_kernel.cu (47 hunks)
  • mistralrs-quant/src/lib.rs (1 hunks)
  • mistralrs-quant/src/metal_kernels/bf16.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/bitwise.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/copy.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/copy_impl.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/mod.rs (2 hunks)
  • mistralrs-quant/src/metal_kernels/quantized.metal (2 hunks)
  • mistralrs-quant/src/metal_kernels/scan.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/scan_impl.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/sort.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/sort_impl.metal (1 hunks)
  • mistralrs-quant/src/metal_kernels/utils.metal (2 hunks)
  • mistralrs-quant/src/metal_kernels/utils.rs (1 hunks)
  • mistralrs-quant/src/safetensors.rs (1 hunks)
  • mistralrs-quant/src/utils/mod.rs (1 hunks)
  • mistralrs-quant/src/utils/ops.rs (5 hunks)
  • scripts/bench.py (1 hunks)
  • scripts/convert_awq_marlin.py (8 hunks)
✅ Files skipped from review due to trivial changes (20)
  • examples/server/chat.py
  • mistralrs-quant/src/metal_kernels/bitwise.metal
  • mistralrs-core/src/models/phi3.rs
  • mistralrs-core/src/vision_models/gemma3/config.rs
  • mistralrs-core/src/pipeline/vision.rs
  • mistralrs-core/src/vision_models/qwen2vl/config.rs
  • mistralrs-core/src/models/gemma.rs
  • mistralrs-core/src/models/gemma2.rs
  • mistralrs-core/src/vision_models/mllama/config.rs
  • mistralrs-core/src/vision_models/llama4/config.rs
  • mistralrs-core/src/models/starcoder2.rs
  • mistralrs-core/src/vision_models/phi4/config.rs
  • mistralrs-core/src/vision_models/qwen2_5_vl/config.rs
  • Cargo.toml
  • mistralrs-quant/src/lib.rs
  • mistralrs-core/src/pipeline/ggml.rs
  • mistralrs-core/src/pipeline/gguf.rs
  • mistralrs-core/src/pipeline/normal.rs
  • mistralrs-quant/kernels/marlin/marlin_kernel.cu
  • mistralrs-quant/src/metal_kernels/sort.metal
🚧 Files skipped from review as they are similar to previous changes (4)
  • mistralrs-core/src/pipeline/mod.rs
  • mistralrs-quant/build.rs
  • mistralrs-quant/src/utils/mod.rs
  • mistralrs-quant/src/metal_kernels/bf16.metal
🧰 Additional context used
🧬 Code Graph Analysis (6)
mistralrs-core/src/kv_cache/rotating_cache.rs (1)
mistralrs-core/src/kv_cache/single_cache.rs (1)
  • all_data (41-43)
mistralrs-core/src/dummy_paged_attention/scheduler.rs (3)
mistralrs-core/src/scheduler/mod.rs (1)
  • schedule (55-55)
mistralrs-core/src/scheduler/default_scheduler.rs (2)
  • schedule (206-301)
  • schedule (311-315)
mistralrs-core/src/paged_attention/scheduler.rs (2)
  • schedule (67-239)
  • schedule (369-373)
mistralrs-core/src/pipeline/macros.rs (1)
mistralrs-core/src/pipeline/paths.rs (1)
  • get_xlora_paths (55-307)
mistralrs-core/src/sampler.rs (6)
mistralrs-core/src/pipeline/mod.rs (7)
  • logits (433-436)
  • logits (443-455)
  • logits (465-480)
  • logits (581-584)
  • logits (591-601)
  • logits (611-626)
  • tokenizer (153-153)
mistralrs-core/src/sequence.rs (1)
  • return_logprobs (785-787)
mistralrs-core/src/pipeline/ggml.rs (1)
  • tokenizer (500-502)
mistralrs-core/src/pipeline/gguf.rs (1)
  • tokenizer (661-663)
mistralrs-core/src/pipeline/normal.rs (1)
  • tokenizer (960-962)
mistralrs-core/src/pipeline/vision.rs (1)
  • tokenizer (819-821)
mistralrs-core/src/kv_cache/single_cache.rs (1)
mistralrs-core/src/kv_cache/rotating_cache.rs (1)
  • all_data (47-49)
mistralrs-core/src/pipeline/paths.rs (2)
mistralrs-core/src/model_loader.rs (3)
  • quantized_filename (337-340)
  • quantized_filename (362-365)
  • quantized_filename (395-398)
mistralrs-core/src/toml_selector.rs (3)
  • quantized_filename (726-729)
  • quantized_filename (753-756)
  • quantized_filename (786-789)
🪛 Ruff (0.11.9)
scripts/convert_awq_marlin.py

13-13: Undefined name List

(F821)


16-16: Undefined name List

(F821)

⏰ Context from checks skipped due to timeout of 90000ms (5)
  • GitHub Check: Test Suite (macOS-latest, stable)
  • GitHub Check: Test Suite (windows-latest, stable)
  • GitHub Check: Docs
  • GitHub Check: Test Suite (ubuntu-latest, stable)
  • GitHub Check: Check (windows-latest, stable)
🔇 Additional comments (74)
mistralrs-core/src/kv_cache/rotating_cache.rs (1)

47-48: Improved API ergonomics with better return type

The change from returning &Option<Tensor> to Option<&Tensor> is a significant improvement. Using self.all_data.as_ref() is the idiomatic Rust approach for this conversion, making the API more intuitive for callers who can now directly use the option without additional dereferencing.

mistralrs-core/src/kv_cache/single_cache.rs (1)

41-42: Improved API ergonomics with better return type

This change correctly transforms the return type from &Option<Tensor> to Option<&Tensor>, making the API more ergonomic. The implementation using self.all_data.as_ref() follows Rust best practices. This change maintains consistency with the parallel modification in rotating_cache.rs.
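As a minimal illustration of the pattern being praised here (the types are stand-ins, not the real cache):

struct Cache {
    all_data: Option<Vec<f32>>, // stand-in for Option<Tensor>
}

impl Cache {
    // `as_ref()` turns `&Option<T>` into `Option<&T>`, so callers can chain
    // `map`, `unwrap_or`, etc. without first matching through a reference.
    fn all_data(&self) -> Option<&Vec<f32>> {
        self.all_data.as_ref()
    }
}

fn main() {
    let c = Cache { all_data: Some(vec![1.0, 2.0]) };
    assert_eq!(c.all_data().map(|d| d.len()).unwrap_or(0), 2);
}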

scripts/convert_awq_marlin.py (5)

79-81: Improved function signature format.

The reformatted function signature for marlin_zero_points is now more readable with parameters on separate lines.


102-104: Improved function signature format.

The reformatted function signature for awq_to_marlin_zero_points is now more readable with parameters on separate lines.


56-60: Improved assertion message formatting.

The assertion message is now properly enclosed in parentheses, which aligns with Python style best practices for multi-line statements.


198-206: Consistent assertion message formatting.

All assertion messages are now properly enclosed in parentheses, which improves code consistency and readability.


177-191: Improved argument parser formatting.

The argument parser configuration is now more consistently formatted with clear indentation, making it easier to read.

scripts/bench.py (2)

18-40: Consider simplifying the log_response function

The function is comprehensive but verbose. For benchmarking purposes, a more concise logging approach might be sufficient.


104-105: 🛠️ Refactor suggestion

Update script execution to use the new argument parsing

Update the main entry point to use the new command-line argument parsing.

 if __name__ == "__main__":
     asyncio.run(main())

Likely an incorrect or invalid review comment.

mistralrs-core/src/vision_models/qwen2vl/text.rs (1)

395-396: Good update to use the model's quantization configuration

Updating the lm_head initialization to use the model's quantization configuration ensures consistent quantization handling across all model components. This change properly propagates the configuration instead of using a hardcoded None.

mistralrs-core/src/vision_models/phi4/mod.rs (1)

412-413: Good update to use the model's quantization configuration

Updating the lm_head initialization to use the model's quantization configuration ensures consistent quantization handling across all model components. This alignment with other model parts improves consistency in the quantization behavior.

mistralrs-core/src/vision_models/mllama/text.rs (1)

580-581: Good update to use the model's quantization configuration

Updating the lm_head initialization to use the model's quantization configuration ensures consistent quantization handling across all model components. This change correctly propagates the quantization settings to all relevant layers.

mistralrs-core/src/vision_models/gemma3/text.rs (1)

494-494: Improved quantization support for language model head.

This change properly propagates the quantization configuration to the language model head, ensuring consistent quantization behavior across all model components. Previously, the language model head was always initialized with None for quantization, preventing it from being quantized even when other layers were.

mistralrs-core/src/vision_models/llama4/text.rs (1)

614-614: Enabled quantization for language model head.

This change properly propagates the quantization configuration to the language model head, ensuring consistent quantization behavior across all model components. This aligns with similar changes across multiple model implementations in the codebase.

mistralrs-core/src/vision_models/qwen2_5_vl/text.rs (1)

399-399: Applied consistent quantization to language model head.

This change ensures the language model head properly receives the model's quantization configuration rather than hardcoding None. This allows the head layer to be quantized consistently with the rest of the model.

mistralrs-core/src/models/phi3_5_moe.rs (2)

49-50: Added serde alias for improved config compatibility.

The serde alias "quantization" allows the model to be loaded from JSON configurations that use either "quantization_config" or "quantization" as the field name, providing backward compatibility with different model serialization formats.
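For reference, a minimal sketch of the attribute in question, assuming serde (with the derive feature) and serde_json are available; the field and value types are illustrative:

use serde::Deserialize;

#[derive(Deserialize)]
struct Config {
    // Accepts either "quantization_config" or "quantization" as the JSON key.
    #[serde(alias = "quantization")]
    quantization_config: Option<String>, // stand-in for the real config type
}

fn main() -> Result<(), serde_json::Error> {
    let a: Config = serde_json::from_str(r#"{ "quantization_config": "awq" }"#)?;
    let b: Config = serde_json::from_str(r#"{ "quantization": "awq" }"#)?;
    assert_eq!(a.quantization_config, b.quantization_config);
    Ok(())
}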


661-661: Propagated quantization config to language model head.

This change enables proper quantization of the language model head using the same configuration as the rest of the model, ensuring consistent quantization behavior.

mistralrs-core/src/models/mixtral.rs (2)

48-49: LGTM: Added serde alias "quantization" for the quantization_config field.

This enhances deserialization flexibility by allowing model configs to use either "quantization_config" or "quantization" key.


584-584: LGTM: Now propagating quantization config to language model head.

This change enables consistent quantization by passing the configuration to the LM head instead of hardcoded None.

mistralrs-core/src/models/qwen3.rs (2)

59-60: LGTM: Added serde alias "quantization" for the quantization_config field.

This enhances deserialization flexibility by allowing model configs to use either "quantization_config" or "quantization" key.


492-492: LGTM: Now propagating quantization config to language model head.

This change enables consistent quantization by passing the configuration to the LM head instead of hardcoded None.

mistralrs-core/src/models/llama.rs (2)

47-48: LGTM: Added serde alias "quantization" for the quantization_config field.

This enhances deserialization flexibility by allowing model configs to use either "quantization_config" or "quantization" key.


385-385: LGTM: Now propagating quantization config to language model head.

This change enables consistent quantization by passing the configuration to the LM head instead of hardcoded None.

mistralrs-core/src/models/deepseek2.rs (2)

96-97: LGTM: Added serde alias "quantization" for the quantization_config field.

This enhances deserialization flexibility by allowing model configs to use either "quantization_config" or "quantization" key.


792-792: LGTM: Now propagating quantization config to language model head.

This change enables consistent quantization by passing the configuration to the LM head instead of hardcoded None.

mistralrs-quant/src/safetensors.rs (1)

209-209: Code simplification improves readability.

Simplified the tensor loading logic by directly delegating to the underlying load implementation with the provided dtype. This removes an unnecessary intermediate step where a tensor was first loaded and then potentially converted to a requested dtype.

mistralrs-core/src/vision_models/phi3/mod.rs (2)

78-79: Adds serialization flexibility with serde alias.

Adding the alias = "quantization" attribute allows the model to deserialize configuration that uses either quantization_config or quantization as the field name. This improves compatibility with different model file formats.


1030-1031: Propagates quantization configuration to language model head.

The language model head now receives the proper quantization configuration instead of None. This ensures consistent quantization behavior across all model components, including the output layer.

mistralrs-core/src/engine/mod.rs (2)

17-17: Makes IntervalLogger publicly accessible.

Changed the IntervalLogger import to be public, enabling its use outside the engine module. This supports the integration of logging into other components like the scheduler.


175-175: Passes logger to scheduler for enhanced metrics tracking.

Updated to pass the engine's logger to the scheduler, enabling the scheduler to update metrics like number of running and waiting sequences. This change aligns with the updated scheduler interface.

mistralrs-core/src/scheduler/mod.rs (2)

9-9: Adds import for IntervalLogger.

Imported IntervalLogger from the engine module to support the updated scheduler interface.


55-55: Updates Scheduler trait to include logging capability.

The schedule method now requires an IntervalLogger reference, enabling all scheduler implementations to update metrics about running and waiting sequences. This enhances observability of the scheduling process.
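A hedged sketch of the new shape of the trait (simplified signatures and types, not the exact mistralrs-core definitions):

struct IntervalLogger; // internals sketched in the logger example further down

impl IntervalLogger {
    fn set_num_running(&self, _n: usize) {}
    fn set_num_waiting(&self, _n: usize) {}
}

struct SchedulerOutput;

trait Scheduler {
    // The logger reference lets every implementation report its queue sizes
    // at the points where scheduling decisions are made.
    fn schedule(&mut self, logger: &IntervalLogger) -> SchedulerOutput;
}

struct DefaultScheduler {
    running: Vec<u64>, // sequence ids, illustrative
    waiting: Vec<u64>,
}

impl Scheduler for DefaultScheduler {
    fn schedule(&mut self, logger: &IntervalLogger) -> SchedulerOutput {
        // ...scheduling decisions elided...
        logger.set_num_running(self.running.len());
        logger.set_num_waiting(self.waiting.len());
        SchedulerOutput
    }
}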

mistralrs-core/src/models/qwen3_moe.rs (2)

57-58: Added serde alias for backward compatibility.

The #[serde(alias = "quantization")] attribute allows the field to be deserialized from JSON/YAML with either "quantization_config" or "quantization" as the key, improving backward compatibility with existing model configurations.


695-696: Propagating quantization config to language model head.

The change ensures that quantization settings are consistently applied to the language model head (lm_head), where previously a hardcoded None was used. This allows for proper quantization across all model components.

mistralrs-core/src/models/qwen2.rs (2)

43-44: Added serde alias for backward compatibility.

The #[serde(alias = "quantization")] attribute allows deserializing the quantization configuration from either "quantization_config" or "quantization" field names, enhancing compatibility with various model configuration formats.


423-424: Propagating quantization config to language model head.

This change replaces a hardcoded None with &cfg.quantization_config, ensuring that the language model head properly respects the quantization settings provided to the model.

mistralrs-core/src/models/deepseek3.rs (2)

96-97: Added serde alias for backward compatibility.

The #[serde(alias = "quantization")] attribute enables deserialization from either "quantization_config" or "quantization" field names, improving compatibility with different model configuration formats.


845-846: Propagating quantization config to language model head.

This change ensures that the language model head (lm_head) consistently uses the same quantization configuration as the rest of the model components, replacing a hardcoded None reference.

mistralrs-core/src/engine/logger.rs (7)

15-16: Added tracking for running and waiting sequences.

These new atomic counters enable monitoring the number of sequences in different states in the scheduler, which is valuable for debugging and performance monitoring.


26-28: Initialized atomic counters for sequence tracking.

Proper initialization of the atomic counters for tracking running and waiting sequences, consistent with the initialization pattern used for other counters.


33-34: Added thread-local clones of atomic counters.

These thread-local clones enable the background logging thread to safely access the sequence state counters without risking thread safety issues.


46-47: Reading scheduler statistics in the logging thread.

This code loads the current values of running and waiting sequences to be displayed in periodic log messages, providing runtime visibility into scheduler state.


51-52: Enhanced logging output with scheduler information.

The log format now includes the count of running and waiting sequences, providing runtime visibility into the scheduler's workload and potential bottlenecks.


64-65: Added fields to struct for tracking sequence counts.

These fields complete the implementation of sequence tracking by storing the atomic counters in the IntervalLogger struct.


86-92: Added public setters for scheduler metrics.

These methods allow the scheduler to update the number of running and waiting sequences, completing the integration between the scheduler and logging components.
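Putting the logger pieces together, a simplified sketch of how the counters and setters might be wired (an illustration only; the real logger also tracks throughput and more):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

pub struct IntervalLogger {
    num_running: Arc<AtomicUsize>,
    num_waiting: Arc<AtomicUsize>,
}

impl IntervalLogger {
    pub fn new(interval: Duration) -> Self {
        let num_running = Arc::new(AtomicUsize::new(0));
        let num_waiting = Arc::new(AtomicUsize::new(0));

        // Clones handed to the background reporting thread.
        let running = num_running.clone();
        let waiting = num_waiting.clone();
        thread::spawn(move || loop {
            thread::sleep(interval);
            println!(
                "{} running, {} waiting",
                running.load(Ordering::Relaxed),
                waiting.load(Ordering::Relaxed)
            );
        });

        Self { num_running, num_waiting }
    }

    // Setters the scheduler calls at its decision points.
    pub fn set_num_running(&self, n: usize) {
        self.num_running.store(n, Ordering::Relaxed);
    }

    pub fn set_num_waiting(&self, n: usize) {
        self.num_waiting.store(n, Ordering::Relaxed);
    }
}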

mistralrs-core/src/models/mistral.rs (2)

45-45: Adds serde deserialization compatibility for legacy field name.

Adding the alias = "quantization" attribute allows the field to be deserialized from JSON that uses either "quantization_config" or the legacy "quantization" name, improving backward compatibility.


451-452: Enables quantization for the language model head layer.

Now the lm_head initialization properly uses the model's quantization configuration instead of passing None. This change ensures consistent quantization behavior across all model components.

mistralrs-core/src/models/phi2.rs (2)

54-55: Adds serde deserialization compatibility for legacy field name.

Adding the alias = "quantization" attribute allows the field to be deserialized from JSON that uses either "quantization_config" or the legacy "quantization" name, improving backward compatibility.


530-531: Enables quantization for the language model head layer.

Now the lm_head initialization properly uses the model's quantization configuration instead of passing None. This change ensures consistent quantization behavior across all model components.

mistralrs-core/src/pipeline/macros.rs (4)

99-101: Updates argument passing style to match function signatures.

Changed from passing references to Options (&Option<T>) to passing optional references (Option<&T>) using .as_ref(), aligning with updated function signatures in paths.rs.


107-112: Updates xlora-related parameters to use optional references.

Changed from passing references to Options to passing optional references using .as_ref(), creating a more idiomatic API design and matching updated function signatures.


286-288: Updates argument passing style for model paths in GGUF context.

Changed from passing Option<T> directly to wrapping values in Some(&T), ensuring type compatibility with updated function signatures expecting Option<&T> parameters.


295-300: Updates xlora-related parameter passing in GGUF context.

Changed from passing Options directly to using .as_ref() for consistent argument passing style that aligns with the updated function signatures in paths.rs.

mistralrs-core/src/paged_attention/scheduler.rs (5)

20-21: Adds IntervalLogger import for enhanced monitoring.

Added import for the IntervalLogger that will be used to track and report scheduler metrics.


67-67: Updates scheduler interface to integrate logging.

The schedule method now accepts a reference to an IntervalLogger parameter to enable reporting of scheduler metrics.


123-125: Adds logging for scheduler status after promotion phase.

After sequences are either scheduled or ignored during promotion from waiting to running, the logger is updated with current counts of running and waiting sequences, enhancing observability.


230-232: Adds logging for scheduler status after main scheduling logic.

The logger is updated with the final counts of running and waiting sequences after all scheduling decisions have been made, providing visibility into the scheduler's state.


369-372: Updates Scheduler trait implementation to propagate logger.

The trait implementation now passes the logger through to the concrete implementation, maintaining consistent logging behavior.

mistralrs-core/src/dummy_paged_attention/scheduler.rs (1)

67-67: LGTM! Consistent logger integration

The IntervalLogger integration is well-implemented, with appropriate updates at key points in the scheduling pipeline. This change aligns perfectly with the pattern established in other scheduler implementations.

Also applies to: 123-124, 230-231, 369-371

mistralrs-core/src/scheduler/default_scheduler.rs (1)

206-206: LGTM! Thorough logger integration

The IntervalLogger integration is well-implemented with updates at all key decision points in the scheduling pipeline. The logger calls are strategically placed to ensure accurate reporting of running and waiting sequence counts throughout the execution path.

Also applies to: 218-219, 233-234, 248-249, 284-286, 311-313

mistralrs-core/src/sampler.rs (1)

794-802: LGTM! Feature-gated fast sampler

The conditional return based on the metal feature flag is a good approach. This properly preserves the original sampling path for non-metal builds while enabling the optimized path for metal, avoiding the unreachable code issue that would occur with an unconditional return.
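The pattern being endorsed is roughly the following hedged sketch (function names are placeholders, not the real sampler API):

fn sample(logits: &[f32]) -> usize {
    #[cfg(feature = "metal")]
    {
        // Optimized GPU-assisted path; returning here means the code below is
        // compiled out rather than left behind as unreachable.
        return fast_sample_metal(logits);
    }

    #[cfg(not(feature = "metal"))]
    {
        slow_sample(logits)
    }
}

#[cfg(feature = "metal")]
fn fast_sample_metal(logits: &[f32]) -> usize {
    // The GPU top-k/top-p pipeline would live here; argmax as a stand-in.
    slow_sample(logits)
}

fn slow_sample(logits: &[f32]) -> usize {
    // Plain argmax as a stand-in for the full sampling pipeline.
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .unwrap_or(0)
}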

mistralrs-quant/src/metal_kernels/utils.rs (1)

95-137: LGTM! Well-implemented thread-group sizing function

The get_block_dims function provides efficient 3D thread-group sizing with power-of-two thread counts while respecting dimension extents. This is a good port of MLX's approach and will help optimize kernel launch configurations.
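For readers unfamiliar with the MLX heuristic, a hedged restatement of the idea in plain Rust (not the exact utils.rs code): grant another factor of two to whichever axis can still use it, one bit at a time, until the 1024-thread (2^10) budget is spent.

/// Sketch of an MLX-style block-dimension heuristic.
fn get_block_dims(dim0: usize, dim1: usize, dim2: usize) -> (usize, usize, usize) {
    let dims = [dim0, dim1, dim2];
    let mut pows = [0u32; 3];
    let mut total_bits = 0u32;

    loop {
        let before = total_bits;
        for i in 0..3 {
            // Give axis `i` another factor of two only if its extent allows it.
            if total_bits < 10 && dims[i] >= (1usize << (pows[i] + 1)) {
                pows[i] += 1;
                total_bits += 1;
            }
        }
        // Stop when no axis accepted another bit or the 1024-thread budget is spent.
        if before == total_bits || total_bits == 10 {
            break;
        }
    }

    (1 << pows[0], 1 << pows[1], 1 << pows[2])
}

fn main() {
    // e.g. a (4096, 7, 1) problem gets a wide x dimension and a small y dimension.
    assert_eq!(get_block_dims(4096, 7, 1), (256, 4, 1));
}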

mistralrs-quant/src/metal_kernels/copy.metal (1)

52-63: LGTM – exhaustive instantiation achieved

The final block instantiates the kernels for every primitive/bfloat16/half type across the three common group sizes.
No gaps spotted, order is consistent with the rest of the code-base.

mistralrs-quant/src/metal_kernels/quantized.metal (1)

1598-1644: Pointer-increment logic: verify overflow on partial K blocks

The new cache-friendly loop assumes that in_vec_size is an exact multiple of block_size (values_per_thread * SIMD_SIZE).
If this is not guaranteed, the final iteration will read past the end of x, w, scales, and biases.

The slow path (qmv_impl) handles the “tail” with load_vector_safe, but qmv_fast_impl does not.

Please double-check callers to ensure the fast kernel is only launched with aligned sizes, or add a guarded tail loop similar to the reference implementation.
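A small, hypothetical host-side guard of the kind being asked for (the kernel names and dispatch shape are assumptions, not the real launcher code):

// Only take the fast kernel when the reduction length divides evenly into
// full blocks; otherwise fall back to the kernel with a guarded tail.
fn select_qmv_kernel(in_vec_size: usize, values_per_thread: usize, simd_size: usize) -> &'static str {
    let block_size = values_per_thread * simd_size;
    if in_vec_size % block_size == 0 {
        "qmv_fast" // assumes every SIMD group reads whole blocks only
    } else {
        "qmv" // slow path that loads the tail safely
    }
}

fn main() {
    assert_eq!(select_qmv_kernel(4096, 8, 32), "qmv_fast"); // 4096 % 256 == 0
    assert_eq!(select_qmv_kernel(4100, 8, 32), "qmv");      // tail of 4 elements
}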

mistralrs-quant/src/metal_kernels/utils.metal (10)

6-7: Good addition of #pragma once directive.

This is a standard practice to prevent multiple inclusions of this header file, which is especially important for utility headers that might be included in multiple translation units.


616-622: Clean template struct implementation for work distribution.

The WorkPerThread struct elegantly scales work based on data type size, which is a good practice for optimizing GPU workloads.


627-672: Well-structured type limits implementation.

The Limits struct and its specializations provide a clean abstraction for accessing numeric limits of different types, with proper handling for both integral and floating-point types.


678-767: Comprehensive indexing utilities for array traversal.

The indexing utilities provide efficient functions for converting linear indices to memory locations for arrays of different dimensions, which is crucial for GPU kernel performance.
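A hedged restatement of the core idea behind such elem_to_loc-style helpers, in plain Rust rather than MSL: map a linear element index to a strided offset by peeling the innermost (last) dimension first, matching a row-major linear index.

fn elem_to_loc(mut elem: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut loc = 0;
    for (dim, stride) in shape.iter().zip(strides).rev() {
        loc += (elem % dim) * stride; // position along this axis times its stride
        elem /= dim;                  // move on to the next-outer axis
    }
    loc
}

fn main() {
    // A 2x3 matrix stored transposed (column-major): shape [2, 3], strides [1, 2].
    // Linear element 4 is (row 1, col 1), which lives at offset 1*1 + 1*2 = 3.
    assert_eq!(elem_to_loc(4, &[2, 3], &[1, 2]), 3);
}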


773-847: Well-designed recursive template for multi-dimensional iteration.

The LoopedElemToLoc template provides an elegant solution for traversing multi-dimensional arrays, with specialized implementations for different dimensionalities.


869-897: Numerically stable calculation utilities.

The ceildiv and log1p implementations follow best practices for numerical stability, including proper handling of edge cases.
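For illustration, one common formulation of those two helpers in Rust (an assumption about the general approach; the Metal code may differ in details):

/// Ceiling division that avoids the overflow in `(n + d - 1) / d` when `n + d` would wrap.
fn ceildiv(n: u64, d: u64) -> u64 {
    n / d + u64::from(n % d != 0)
}

/// Numerically stable log(1 + x) for small |x|: rescale by the rounding error
/// committed when forming `1 + x`, instead of calling ln(1 + x) directly.
fn log1p_stable(x: f32) -> f32 {
    let xp1 = 1.0 + x;
    if xp1 == 1.0 {
        x // 1 + x rounded back to 1, so log(1 + x) is approximately x
    } else if xp1.is_infinite() {
        xp1 // overflow for very large x
    } else {
        x * xp1.ln() / (xp1 - 1.0)
    }
}

fn main() {
    assert_eq!(ceildiv(10, 3), 4);
    let tiny = 1e-10_f32;
    // A naive (1.0 + tiny).ln() rounds to 0.0; the stable version keeps ~tiny.
    assert!((log1p_stable(tiny) - tiny).abs() < 1e-16);
}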


903-956: Effective SIMD shuffle operations for 64-bit types and booleans.

These implementations properly wrap Metal's SIMD intrinsics, extending them to handle 64-bit integers and boolean types.


957-1087: Comprehensive set of templated operator functors.

These functors provide a clean, uniform interface for arithmetic, comparison, and logical operations, with appropriate specializations for different types.


1090-1107: Efficient power implementation for integral types.

The Power functor's specialization for integral types uses the efficient binary exponentiation algorithm (also known as exponentiation by squaring), which is optimal for computing integer powers.
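Binary exponentiation in a sentence: square the base while halving the exponent, multiplying into the result whenever the current low bit is set, for O(log n) multiplies instead of O(n). A small Rust illustration of the algorithm (not the Metal functor itself):

fn ipow(mut base: u64, mut exp: u32) -> u64 {
    let mut result = 1u64;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result.wrapping_mul(base); // multiply in when the low bit is set
        }
        base = base.wrapping_mul(base); // square the base
        exp >>= 1;                      // advance to the next bit
    }
    result
}

fn main() {
    assert_eq!(ipow(3, 5), 243);   // 3^5
    assert_eq!(ipow(2, 10), 1024); // 2^10 in ~4 multiplies
}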


1153-1159: Useful kernel instantiation macro.

The instantiate_kernel macro provides a convenient way to instantiate templated Metal kernels with specified host names and template parameters.

Comment on lines 13 to 16
scale_perm: List[int] = []
for i in range(8):
    scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []

⚠️ Potential issue

Missing import for List from the typing module.

The code uses List type annotations, but the typing module that provides List isn't imported.

Add the following import at the top of the file:

import os
import shutil
+from typing import List
🧰 Tools
🪛 Ruff (0.11.9)

13-13: Undefined name List

(F821)


16-16: Undefined name List

(F821)

🤖 Prompt for AI Agents
In scripts/convert_awq_marlin.py around lines 13 to 16, the code uses the List
type annotation but does not import List from the typing module. To fix this,
add the import statement "from typing import List" at the top of the file to
ensure the List type is recognized.

Comment on lines +46 to +62
async def timed_chat(client: AsyncOpenAI, messages):
    """
    Send one chat completion request and return (completion, elapsed_seconds, completion_tokens).
    """
    start = time.perf_counter()
    completion = await client.chat.completions.create(
        model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        messages=messages,
        max_tokens=256,
        frequency_penalty=1.0,
        top_p=0.1,
        temperature=0,
    )
    elapsed = time.perf_counter() - start
    # Safely get number of completion tokens, default to 0 if missing
    completion_tokens = getattr(completion.usage, "completion_tokens", 0)
    return completion, elapsed, completion_tokens

🛠️ Refactor suggestion

Parameterize the timed_chat function

Model name and generation parameters are hardcoded, limiting the benchmark's flexibility for testing different configurations.

-async def timed_chat(client: AsyncOpenAI, messages):
+async def timed_chat(
+    client: AsyncOpenAI, 
+    messages, 
+    model_name="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+    max_tokens=256,
+    frequency_penalty=1.0,
+    top_p=0.1,
+    temperature=0
+):
     """
     Send one chat completion request and return (completion, elapsed_seconds, completion_tokens).
     """
     start = time.perf_counter()
     completion = await client.chat.completions.create(
-        model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+        model=model_name,
         messages=messages,
-        max_tokens=256,
-        frequency_penalty=1.0,
-        top_p=0.1,
-        temperature=0,
+        max_tokens=max_tokens,
+        frequency_penalty=frequency_penalty,
+        top_p=top_p,
+        temperature=temperature,
     )
     elapsed = time.perf_counter() - start
     # Safely get number of completion tokens, default to 0 if missing
     completion_tokens = getattr(completion.usage, "completion_tokens", 0)
     return completion, elapsed, completion_tokens
🤖 Prompt for AI Agents
In scripts/bench.py around lines 46 to 62, the timed_chat function has hardcoded
model name and generation parameters, reducing flexibility. Modify the function
signature to accept model name and generation parameters as arguments with
default values, then use these parameters inside the function instead of
hardcoded values. This will allow benchmarking with different configurations
without changing the function code.

Comment on lines +65 to +78
async def user_task(client: AsyncOpenAI, system_prompt: str, user_message: str):
    """
    Returns list of (completion, elapsed_seconds, completion_tokens).
    """
    results = []
    base_messages = []
    if system_prompt:
        base_messages.append({"role": "system", "content": system_prompt})

    for _ in range(REQUESTS_PER_USER):
        messages = base_messages + [{"role": "user", "content": user_message}]
        completion, elapsed, completion_tokens = await timed_chat(client, messages)
        results.append((completion, elapsed, completion_tokens))
    return results

🛠️ Refactor suggestion

Add error handling to user_task

The function lacks error handling, which could lead to unhandled exceptions if requests fail.

 async def user_task(client: AsyncOpenAI, system_prompt: str, user_message: str):
     """
     Returns list of (completion, elapsed_seconds, completion_tokens).
     """
     results = []
     base_messages = []
     if system_prompt:
         base_messages.append({"role": "system", "content": system_prompt})
 
     for _ in range(REQUESTS_PER_USER):
-        messages = base_messages + [{"role": "user", "content": user_message}]
-        completion, elapsed, completion_tokens = await timed_chat(client, messages)
-        results.append((completion, elapsed, completion_tokens))
+        try:
+            messages = base_messages + [{"role": "user", "content": user_message}]
+            completion, elapsed, completion_tokens = await timed_chat(client, messages)
+            results.append((completion, elapsed, completion_tokens))
+        except Exception as e:
+            print(f"Error during request: {e}")
+            # Add a failed result with zero tokens and high elapsed time to indicate failure
+            results.append((None, 10.0, 0))
     return results
🤖 Prompt for AI Agents
In scripts/bench.py around lines 65 to 78, the user_task function does not
handle errors, risking unhandled exceptions during requests. Add try-except
blocks around the await timed_chat call to catch exceptions, log or handle them
appropriately, and ensure the function continues or fails gracefully without
crashing.

Comment on lines +1578 to 1580
uint simd_gid [[simdgroup_index_in_simdgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;

⚠️ Potential issue

Attribute typo breaks compilation

simdgroup_index_in_simdgroup is not a valid Metal shading language attribute – the correct attribute for “SIMD-group index inside the current thread-group” is simdgroup_index_in_threadgroup.

This line will make the entire Metal library fail to compile on Apple tools ≥ Xcode 9.

- uint simd_gid [[simdgroup_index_in_simdgroup]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],

Please adjust here and in any other newly-added kernels to restore successful compilation.

🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/quantized.metal at lines 1578 to 1580, the
attribute simdgroup_index_in_simdgroup is incorrect and causes compilation
failure. Replace simdgroup_index_in_simdgroup with the correct attribute
simdgroup_index_in_threadgroup to fix the compilation error. Also, check any
other new kernel code for the same typo and correct it similarly.

Comment on lines +318 to 323
match quantized_filename {
    Some(names) => {
-       let id = quantized_model_id.as_ref().unwrap();
+       let id = quantized_model_id.unwrap();
        let mut files = Vec::new();

        for name in names {

⚠️ Potential issue

unwrap() may panic when only filenames are supplied

Inside get_model_paths the Some(names) arm unconditionally unwraps quantized_model_id:

let id = quantized_model_id.unwrap();

If the caller provides an explicit quantized_filename list but forgets the corresponding quantized_model_id, the process will panic instead of returning an error.

Consider returning an anyhow::Error instead:

- let id = quantized_model_id.unwrap();
+ let id = quantized_model_id
+     .ok_or_else(|| anyhow::anyhow!("`quantized_model_id` must be supplied when `quantized_filename` is provided"))?;

This keeps the function panic-free and surfaces configuration errors gracefully.

🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/paths.rs around lines 318 to 323, the code
unconditionally unwraps quantized_model_id inside the Some(names) match arm,
which can cause a panic if quantized_model_id is None. Modify the code to check
if quantized_model_id is Some before unwrapping; if it is None, return an
anyhow::Error indicating the missing quantized_model_id instead of panicking.
This change will make the function panic-free and handle configuration errors
gracefully.

Comment on lines +1091 to +1104
impl SortOp for Tensor {
    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
        if self.device().is_cpu() || self.device().is_cuda() {
            return self.arg_sort_last_dim(true);
        }
        self.apply_op1_no_bwd(&ArgSort {
            axis: axis.to_index(self.shape(), "argsort")?,
        })
    }

    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
        if self.device().is_cpu() || self.device().is_cuda() {
            return Ok(self.sort_last_dim(true)?.0);
        }

🛠️ Refactor suggestion

⚠️ Potential issue

fast_argsort_asc / fast_sort_asc ignore the supplied axis when running on CPU/CUDA

Both helpers immediately fall back to arg_sort_last_dim / sort_last_dim, which always act on the last dimension, even if the caller asked for another axis.
This yields silently-wrong results for any multi-dimensional tensor whose last dimension is not the sort axis.

-        if self.device().is_cpu() || self.device().is_cuda() {
-            return self.arg_sort_last_dim(true);
-        }
+        if self.device().is_cpu() || self.device().is_cuda() {
+            // Candle already provides arg_sort()/sort() that respect an axis.
+            // Preserve the public behaviour by forwarding the user-requested axis.
+            let axis = axis.to_index(self.shape(), "argsort")?;
+            return self.arg_sort(axis, true);          // or `self.sort(axis, true)?`
+        }

(The same change is required for fast_sort_asc a few lines below).

Please ensure the accompanying unit-tests cover non-trailing axes so the issue is caught automatically in the future.

🤖 Prompt for AI Agents
In mistralrs-quant/src/utils/ops.rs around lines 1091 to 1104, the methods
fast_argsort_asc and fast_sort_asc ignore the supplied axis parameter when
running on CPU or CUDA by always sorting along the last dimension. To fix this,
modify these methods to use the provided axis for sorting instead of defaulting
to the last dimension, ensuring the axis is converted properly using
axis.to_index with the tensor shape. Additionally, update or add unit tests to
cover sorting on non-last axes to prevent this issue from recurring.

Comment on lines +143 to +156
  auto idx = elem_to_loc_2_nd<IdxT>({N * index.x, index.y, index.z}, src_shape,
                                    src_strides, dst_strides, ndim);
  if (N == 1) {
    dst[idx.y] = static_cast<U>(src[idx.x]);
    return;
  }
  IdxT src_xstride = src_strides[ndim - 1];
  IdxT dst_xstride = dst_strides[ndim - 1];
  auto xshape = src_shape[ndim - 1];
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    dst[idx.y] = static_cast<U>(src[idx.x]);
    idx.x += src_xstride;
    idx.y += dst_xstride;
  }

⚠️ Potential issue

Missing bounds check in strided gather→gather copies

copy_gg writes

dst[idx.y] = static_cast<U>(src[idx.x]);

inside a loop where both idx.x and idx.y are manually incremented.
Neither value is compared against the logical src/dst sizes, so a malformed shape/stride pair can lead to OOB writes and kernel crashes.

Unless you have stronger invariants on the Rust side, protect the loop:

if (idx.x < src_len && idx.y < dst_len) { … }

Even a single device_assert would greatly improve debuggability.

🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal around lines 143 to 156,
the loop writing to dst and reading from src increments idx.x and idx.y without
bounds checks, risking out-of-bounds memory access. To fix this, add conditional
checks inside the loop to ensure idx.x is less than the source length and idx.y
is less than the destination length before performing the copy. Optionally,
include device_assert statements to catch violations during debugging. This will
prevent OOB writes and improve kernel stability.

Comment on lines +8 to +10
  for (int i = 0; i < N && (index + i) < size; ++i) {
    dst[index + i] = static_cast<U>(src[0]);
  }

🛠️ Refactor suggestion

Avoid signed/unsigned comparisons to prevent subtle out-of-bounds bugs

All inner loops use int i while the limit expressions (index, size, xshape) are uint or int64_t.
Because of the usual arithmetic conversions the compiler silently casts operands back and forth, which can bite once size ≥ 2³¹ or when negative intermediates are introduced.

A minimal change keeps the code safe and helps the optimiser:

-for (int i = 0; i < N && (index + i) < size; ++i) {
+for (uint i = 0u; i < uint(N) && (index + i) < size; ++i) {

Please repeat the same change in every loop highlighted above (copy_s*, copy_v*, copy_g, copy_gg, copy_gg_dynamic).
(The template parameter N is still int; only the loop counter needs to be unsigned.)

Also applies to: 18-20, 29-31, 40-42, 96-99, 152-156, 222-226

🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal at lines 8 to 10 and
similarly at lines 18-20, 29-31, 40-42, 96-99, 152-156, and 222-226, the loop
counter variable is declared as a signed int while the loop limit variables are
unsigned or int64_t, causing unsafe signed/unsigned comparisons. Change the loop
counter type from int to an appropriate unsigned integer type to match the limit
variables, ensuring safe comparisons and preventing potential out-of-bounds
errors. Apply this change consistently to all loops mentioned in the comment.

Comment on lines +85 to +88
  auto src_idx = elem_to_loc<IdxT>({N * index.x, index.y, index.z}, src_shape,
                                   src_strides, ndim);
  if (N == 1) {
    IdxT dst_idx =

🛠️ Refactor suggestion

Brace-initialiser may fail to compile on some Metal toolchains

elem_to_loc is called with {N * index.x, index.y, index.z} which is not guaranteed to deduce to a uint3/int3 in MSL.
Several versions of Xcode/metal-clang reject this syntax.

Be explicit to avoid portability issues:

-auto src_idx = elem_to_loc<IdxT>({N * index.x, index.y, index.z},
+auto src_idx = elem_to_loc<IdxT>(uint3(N * index.x, index.y, index.z),
                                  src_shape, src_strides, ndim);

Also applies to: 95-97

🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal around lines 85 to 88 and
also lines 95 to 97, the call to elem_to_loc uses a brace-initializer {N *
index.x, index.y, index.z} which may not compile on some Metal toolchains due to
type deduction issues. To fix this, explicitly construct a uint3 or int3
(matching the expected type) with these values instead of using a
brace-initializer, ensuring compatibility across Metal versions.

Comment on lines +27 to +32
                      uint2 grid_dim [[threads_per_grid]]) {
  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
  for (int i = 0; i < N && (offset + i) < size; ++i) {
    dst[offset + i] = static_cast<U>(src[0]);
  }
}

⚠️ Potential issue

Potential 32-bit overflow in 2-D kernels

offset is computed as

auto offset = N * (index.x + grid_dim.x * int64_t(index.y));

index.x and grid_dim.x are uint, so the addition is performed in 32-bit before it is promoted to int64_t by auto.
With large grids (≥ 65 536 × 65 536) this silently wraps and corrupts memory.

Safer version:

-auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+auto offset = IdxT(N) * (IdxT(index.x) + IdxT(grid_dim.x) * IdxT(index.y));

where IdxT is the same 64-bit type used by your helpers.

Also applies to: 39-42

🤖 Prompt for AI Agents
In mistralrs-quant/src/metal_kernels/copy_impl.metal around lines 27 to 32, the
calculation of offset uses 32-bit arithmetic for the addition of index.x and
grid_dim.x * index.y before converting to 64-bit, which can cause overflow with
large grid sizes. To fix this, cast index.x and grid_dim.x to the 64-bit integer
type (IdxT) before performing the addition and multiplication to ensure all
arithmetic is done in 64-bit and prevent overflow. Apply the same fix to the
similar code block at lines 39 to 42.

@EricLBuehler EricLBuehler merged commit a97de2b into master May 21, 2025
13 checks passed
@EricLBuehler EricLBuehler deleted the fast_sampler branch May 21, 2025 17:12
@EricLBuehler EricLBuehler restored the fast_sampler branch May 22, 2025 01:06
@EricLBuehler EricLBuehler deleted the fast_sampler branch May 22, 2025 01:06
Jeadie added a commit to spiceai/mistral.rs that referenced this pull request Jul 14, 2025
* Fix handling of Metal fused attn head dims (EricLBuehler#1234)

* Fix handling of metal attn head dims

* Fix handling of gemma3 1b when images

* Tweak default for paged attn builder

* Support paged attn for vision model rust api (EricLBuehler#1235)

* [Breaking] Support setting HF cache path (EricLBuehler#1237)

* Add it internally

* Add the apis

* Support tool calling for DeepSeek models (EricLBuehler#1239)

* Support tool calling for deepseek models

* Format

* Fix deepseek

* Server image processing refactor and fixes (EricLBuehler#1244)

* Fix strict gemma3 case

* Accept multiple images in the content array

* Fix multiple images in one array ct

* Add it to the python api

* Typos

* Optimized CUDA RoPE kernels (EricLBuehler#1247)

* Add the kernels

* It works

* Works

* Buulds

* Typo fix (add_speial_tokens to add_special_tokens) (EricLBuehler#1246)

* Fix typo

* Update mistralrs.pyi

* Fixes for UQFF + distributed layers (EricLBuehler#1250)

* Fixes for uqff + distributed layers

* Typo

* Automatic agentic search integration (`web_search_options`) (EricLBuehler#1243)

* Add the tool

* Actually search

* Clippy

* Sort of works

* Remove some debuggers

* tweak

* Add some rules

* Works great

* Tweak 'system' prompt

* Update mistralrs-core/src/search/mod.rs

Co-authored-by: Copilot <[email protected]>

* Typo

* Add it to all the apis

* Add bert model for similarity reranking

* Typos

* Early detection of tools

* Alias max_tokens -> max_completion_tokens too

* Customizable bert model

* Flip the enabler around

* Add docs

* Update readme

* Typo

---------

Co-authored-by: Copilot <[email protected]>

* Format kernels (EricLBuehler#1251)

* Update readme

* Update readme

* Remove test

* Add quantize guards for uqff deserialize (EricLBuehler#1252)

* Refactor cuBLASlt-related code (EricLBuehler#1253)

* Centralize cublaslt into mistralrs-quant

* Use cublaslt in unquant layer

* Use beautiful trait constants for simpler code

* Move tests

* Dispatch to unquant for cublaslt

* Dispatch to unquant for cublaslt

* Fix feature

* Add convert_to_gptq script

* Update deps, bump pyo3 version (EricLBuehler#1259)

* Faster cuda FP8 performance (EricLBuehler#1257)

* Avoid fp8 sync

* Fix dtype

* Rust 1.86 clippy (EricLBuehler#1260)

* Rust 1.86 clippy

* Clippy

* Refactor engine arch (EricLBuehler#1262)

* Refactor engine add_request

* Don't recompile regex

* Clippy

* Revamped LoRA support - removing the Ordering system! (EricLBuehler#1263)

* Play with varbuilder lifetimes

* Merge lora weights

* Clippy

* Lora works

* Support multiple loras

* Cleanup, remove adapter activation

* Complete merge

* Fast Metal-specific quantization method: AFQ (EricLBuehler#1264)

* Add mlx quantized kernels

* Add mlx quantized kernels

* Kernel launcher

* Add AFQ isq quant and dequant

* Some quantmethod things

* Begin to implement the qmm caller

* Clippy

* Much faster

* Cache kernels

* Docs

* Clippy

* Add it to uqff

* Support prequantized models from MLX (EricLBuehler#1265)

* Refactor quantizedconfig

* Support AFQ prequantized

* Update docs

* Update docs

* Automatic ISQ to select fastest & most accurate method (EricLBuehler#1266)

* Automatic isq

* typo

* Doc

* Improved usage metrics (EricLBuehler#1267)

* Fix cuda

* Bump tokio from 1.44.1 to 1.44.2 (EricLBuehler#1270)

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.44.1 to 1.44.2.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](tokio-rs/tokio@tokio-1.44.1...tokio-1.44.2)

---
updated-dependencies:
- dependency-name: tokio
  dependency-version: 1.44.2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Gather MM ops in mistralrs-quant (EricLBuehler#1272)

* Update the caller

* Wire things up

* Broadcase for afq gathermm

* Broadcase for afq gathermm

* Clippy

* Improve performance of deepseek models

* Typo fix

* BincountOp not used

* Implement Llama 4! (EricLBuehler#1268)

* Implement Llama 4

* Implement the main changes for the text model

* Make chunked mask

* Wire things up

* Add some EP

* Initial sketch of inputs processor

* Runs

* Progress

* all reduce moes

* It works!

* Some cleanup

* Faster moe block

* Add device map

* Make chunked matrix

* Fully working now!

* Reactivate cublaslt

* Fix shared mlp cublaslt

* Refactor to packed experts

* Complete merge

* It is a normal model now

* Fixes

* Set device for moe

* ISQ fixes

* Much faster sort kernel

* Faster loading!

* Faster loading!

* Fp8 cpu copy ops in candle backend

* Add the vision model

* Add mmproj layer

* Actually merge the inputs

* Sketch most of the image processor

* Add the rest of the image processor

* Implement the whole processor

* Add the loader

* Some fixes

* A batch of fixes

* Some fixes

* tmp

* Actually support isq

* Ok it works a bit

* Fix norm device

* It works

* A bit cleaner

* Support residul tensors

* Remove text loader

* Implement the device mapping system

* Fix auto device map

* Add examples

* Add model card

* Typo

* Remove superflous logging

* Fixes for Llama 4 UQFF loading (EricLBuehler#1275)

* Support sharding for UQFF (EricLBuehler#1276)

* Serialize sharded uqff files

* Loading

* Fix base64

* Fix bug for group-topk (group_limited_greedy) in deepseek models (EricLBuehler#1278)

* Support the DeepCoder model (EricLBuehler#1279)

* Add faq for metal not found

* Improved PagedAttn scheduling accuracy (EricLBuehler#1282)

* Scheduler ops by reference

* Ensure scheduler gets correct prompts

* Fix cuda build for copy_blocks

* Fixes for scheduling image seqs with pagedattn (EricLBuehler#1283)

* update to llguidance 0.7.16 (EricLBuehler#1284)

* update llguidance to 0.7.16 from crates.io; use ParserFactory

* add lark_llg.py example

* use new llguidance::Matcher APIs

* rework spec-decoding with llg

* more work on spec sampling

* check for parser stop

* fix clippy

* remove unneeded rollback

* update build_llg_factory to return Result

* Update dependencies (EricLBuehler#1286)

* Much faster image inputs processing (EricLBuehler#1289)

* Add more SDPA head dims for much faster SigLIP (EricLBuehler#1290)

* More sdpa head dims, faster vision models

* Move nonzero to above for faster metal synch

* Doc

* Update valid head dims

* Show throughput in interactive mode (EricLBuehler#1291)

* Update interactive mode throughput stats

* Accurate prompt t/s

* Accurate prompt t/s for usage

* Unify bitwise operations (EricLBuehler#1288)

* Unify bitwise ops

* Tests pass

* Fix cuda build

* Clippy

* Multimodal prefix caching support! (EricLBuehler#1298)

* Initial progress

* Support vision prefix caching

* Update docs

* Add multimodal data abstraction

* Interactive mode improvements (EricLBuehler#1299)

* More ergonomic image url parsing

* Add option to clear

* Add the Qwen 3 and Qwen 3 MoE models! (EricLBuehler#1285)

* Add qwen3 model

* Add enable_thinking

* Add initial qwen3 moe

* Add the moe model

* Format

* Fix order of norm

* Fix expert shapes

* Fix reverse

* Fix norm device for isq

* Fix nonzero when no nonzero

* Moe model runs

* Working qwen3 moe

* Add metal fp8 blockwise dequant

* Clean

* Typo

* Enable tool calling

* Streamlined ux

* Add some examples

* Add docs

* Fix dead link

* Remove interactive mode max_len

* Update QWEN3.md

* Hotfix for vision mode clear

* Revamped and streaming web search support (EricLBuehler#1301)

* Streaming web search

* Refactor a bit

* More refactoring

* Add some logging, parallelize some things

* Allow url

* Suppress warning, allow multi-turn searching

* Batch compute_similarities

* Cap content len

* Typos

* Doc

* Handle vision messages or different tool call prefixes (EricLBuehler#1302)

* Fix cuda

* Tune web search budget

* Simplify prefix cacher (EricLBuehler#1305)

* Use rustyline to handle non-ascii in interactive mode (EricLBuehler#1306)

The io::stdin().read_line() cannot handle non-ascii input, which caused
crash when use backspace to delete non-ascii characters.

Introduce rustyline to the interactive mode to solve the problem. Plus
it can bring more editing features in the future.

Close EricLBuehler#1140

* Add more tools for automatic search (EricLBuehler#1307)

* Add interactive mode history

* Add a website extraction tool

* Pass toks by reference

* Optimize prompt chunking

* Fix CPU hogging in interactive mode (EricLBuehler#1309)

The log enabler should be checked after the sleep instead of a busy
loop checking.

Since the interactive mode always disables the token speed logger, 100%
CPU was taken by this loop always.

* Add Metal precompilation support  (EricLBuehler#1311)

* Add metal precompilation for paged attn

* Add for mistralrs-quant

* Better constructor

* Dont always build

* Fix name for paged attn rebuild

* Reduce thrashing of Metal autorelease (EricLBuehler#1313)

* Reduce calls to autorelease

* Optimize clone_in_cache

* Refactor float8

* make `AdapterPaths` and `LoraAdapterPaths` public (EricLBuehler#1314)

Make `AdapterPaths` and `LoraAdapterPaths` public so `LocalModelPaths`
can be constructed outside of `mistralrs-core`.

* Refactor KV cache manager (EricLBuehler#1315)

* Refactor kv cache

* Refactor caches

* Fix some overflows

* Add `Audio` and `Speech` model categories (EricLBuehler#1317)

* add `Audio` to `ModelCategory`

* add `Speech` to `ModelCategory`

* fix to go back to PartialEq having an exhaustiveness check

* Remove has_conv2d from vision model API (EricLBuehler#1318)

* Unified/automatic flash attention enabler (EricLBuehler#1319)

* Remove from sdpa params

* Fix errors

* No warnings

* Log

* Clippy

* Fix cublaslt 4d mask (EricLBuehler#1320)

* Fix cublaslt 4d mask

* Clippy

* Keep caches on gpu

* Qwen VL models fixes (EricLBuehler#1322)

* Add some defaults

* Fix

* Fix one thing

* 2.5 vl works

* Use caching again

* Fix v2

* Move index inside loop

* Offset in ropeidx

* Default support for vision prefix caching is false

* Fixes for all vision models (EricLBuehler#1323)

* Fix phi input processor?

* Fix phi input processor

* Handle no_prefix_cache from pipeline

* Phi models confirmed 👍

* Fixed for phi inputs processors

* Fixed for phi4

* Llama 3 confirmed 😀

* Mistral 3 confirmed 😃

* Idefics 2/3 fixes

* Some fixes

* Remove unsafety

* Improved+faster LRU prefix cacher (EricLBuehler#1321)

* Show TTFT

* Use LRU prefix cacher

* Faster prefix cacher

* Inplace ISQ support and default to mmap (EricLBuehler#1277)

* Initial impl of immediate isq

* Immediate isq -> !loading_isq

* Varbuilder utils always using mmap!

* Log

* Add for packed experts

* Afq without copy

* Clarify

* Clippy

* Apply immediate isq

* Better logic for loading_isq

* Support showing ttft

* Rename

* Shared quantize guard

* Parallel progress bar

* Parallel loading for progress bars

* Actual ISQ support

* Conditional parallelism for NiceProgressBar

* Use conditional iterator

* Warn once

* Predicate for applying immediate isq

* Allow parallel

* Remove debug print

* Remove debug print

* Remove debug print

* Fix typos (EricLBuehler#1329)

* Fix Idefics 3 arch chat templating (EricLBuehler#1330)

* Update inputs merger

* Fix

* Better warning

* Better warning

* Better warning

* Nonzero ahead of time

* No f32

* Clippy

* Optimize get_logprobs

* Fix packed experts

* Update masking

* Use Sdpa in idefics3

* QuantMethod in idefics3 vision

* Remove a .contiguous

* Remove two space from PR comment (EricLBuehler#1331)

* Add automatic vision loader type (EricLBuehler#1332)

* Add automatic vision loader

* Remove references to --arch

* Update examples

* Add the Dia 1.6b TTS model! (EricLBuehler#1304)

* Add loading

* Add rope, mlp, most of attn

* Add encoder + encoder layer, decoder layer forwards

* Add decoder forwards

* Add prepare_audio_prompt

* prepare_generation mostly done

* Add a proper dia kvcache

* Add most of decoder_step

* Add the sampler

* Add the generation loop

* Wire things up

* Add speech pipeline

* Fixes

* Loads

* Some fixes

* f32

* Some progress

* Ok it runs upto dac decoding

* Add dac part loading

* Loads and runs at least

* Remove encodec

* Debugging

* Debugging

* Huh

* Complete merge

* Interactive

* Confirmed dac works at least

* Looks like encoder works

* Much progress

* Hmm

* Sampling

* Almost there

* Sampler

* Sampler

* Bf16 support

* Response

* Use it in interactive mode

* Fix oneshot

* Add openai api

* Add openai api

* Refactor loading

* Use naive sdpa for inplace

* Factor out

* Clippy

* Clippy

* Config

* Refactor config

* Metal clippy

* Fix t/s

* ISQ support

* Some fixes, nits

* Fix cuda

* Clippy

* Inhibit cublaslt for cuda

* Add server example

* Add python example

* Add rust api

* Add docs

* Update config.toml

* Fix .pyi

* Update readme

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* update `llguidance` to `0.7.20` (EricLBuehler#1334)

Update `llguidance` from `0.7.16` to `0.7.20` so that it includes guidance-ai/llguidance#172, which fixes building on GCC 15.

* Add model category <> messages check (EricLBuehler#1335)

* Verify model category matches the messages

* Add vision chat

* Fixes

* Add element-wise normalization check (EricLBuehler#1340)

* Fix streaming example print statement (EricLBuehler#1339)

* Fix normalization formula in comment (EricLBuehler#1338)

* Fix image_to_pixels to handle non-RGB images (EricLBuehler#1337)

* Fix typo in expect messages (EricLBuehler#1342)

* Don't use mmap on cuda (EricLBuehler#1336)

* No mmap on cuda

* Simplify streaming tool call logic

* Remove debug

* Support AWQ format models (EricLBuehler#1350)

* Support AWQ format models

* Clippy fix

* Fix uqff dummy layer ISQ application (EricLBuehler#1351)

* Disable immediate isq if write_uqff (EricLBuehler#1352)

* Fixes for UQFF loading on CUDA, ISQ pack factor (EricLBuehler#1354)

* Fix logic for uqff on cuda

* Updated pack_factor

* Refactor Option references for model paths (EricLBuehler#1347)

* refactor: use Option refs in model path helpers

* Format

* Add a script for server benchmarking (EricLBuehler#1355)

* Serde alias

* Fix

* Update for tie_word_embeddings

* Print running/waiting

* 30 users

* Update num_users

* Update dummy paged attn

* Optimized Metal qmv_fast path (EricLBuehler#1356)

* Compile with lto

* Tweak profiles

* New, fast sampler for Metal! (EricLBuehler#1327)

* Show TTFT

* Use LRU prefix cacher

* Faster prefix cacher

* A bit of gpu sampling

* Minp but cpu for now

* Metal fast cumsum impl

* Sampling with fast topp kernel

* Hmm not perfect

* Add metal sort kernels

* Tmp

* Add single block sort

* Add most of multi block sort, just need copy op

* Add copy kernels

* Expose kernels

* Add a test

* Ok it works

* Structure things

* Add caching

* Rename

* Cpu is default

* CUDA case

* Topk
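
To make the sort/cumsum commits above concrete, here is an illustrative CPU version of the top-p path those Metal kernels accelerate (argsort descending, running cumulative sum, cut once the mass exceeds `p`); this is a sketch of the technique, not the kernel code:

```rust
// Illustrative CPU version of the top-p path the Metal argsort/cumsum kernels
// accelerate: argsort descending, cumulative sum, cut off once mass > p.
fn top_p_filter(probs: &[f32], p: f32) -> Vec<usize> {
    // argsort indices by descending probability
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());

    // inclusive cumulative sum over the sorted probabilities
    let mut kept = Vec::new();
    let mut cum = 0.0f32;
    for &i in &idx {
        kept.push(i);
        cum += probs[i];
        if cum > p {
            break; // nucleus found: smallest set with mass > p
        }
    }
    kept
}

fn main() {
    let probs = [0.4f32, 0.3, 0.2, 0.05, 0.05];
    println!("kept token ids: {:?}", top_p_filter(&probs, 0.9));
}
```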

* Fix topk

* Penalties

* Add logits processor, clippy fixes

* Fix chat port

* Remove warning

* Fix chat port

* Fix metal parallel sampling (EricLBuehler#1357)

* Cpu if parallel for now

* Tweak bench script

* Add immediate isq predicates for qwen3 (EricLBuehler#1358)

* Add immediate isq predicates for qwen3

* Fix parsing of "parse_isq_value" depedent of device

* Typo

* Fix gemma3 logging

* Regressions fixes (EricLBuehler#1359)

* Fix regression for mmap

* Revert EricLBuehler#1321

* Refactored matching_cache impl

* Clippy

* Revamped and smaller readme (EricLBuehler#1360)

* Expandable detail sections

* Refactor using derivative model

* Tweak quick examples

* Update llama

* Update llama

* Supported accelerators is a table

* Update installation guides

* Tweak apis

* Remove --port in quick examples

* Add demo gif

* Add gif in readme

* Update demo gif

* Update demo gif

* Update demo gif

* Add gif in readme

* Add gif in readme

* Add a web chat app! (EricLBuehler#1362)

* Initial

* Markdown

* Copy code

* Add model loading sidebar

* Support vision models

* Tweak isq

* Links go to another page

* Clear when switch model

* Fix html tags

* Add image support!

* More then one images

* Fix

* Improved textarea

* Tab for switching between vision and text

* No paged attn for now

* Prettier format

* Multiple models at once

* Better switching, clearing ability

* Mobile support

* Inline markdown parser

* Update examples

* Typos

* Support specifying isq

* Fix mobile

* Fixes

* Fix button on mobile

* Image height is capped

* Thumbnail

* Fix rotating kv cache edge case

* Add drag and drop for images

* Small things

* Sidebar is frozen now

* Better listener

* Add readme

* Tweak readme

* Add chat history support to web chat app (EricLBuehler#1363)

* Add chat history

* Support renaming

* Start immediately with new chat

* Add timestamp

* Prettier chat list

* Style

* Delete chat

* Fix copy button

* Fix markdown rendering

* Store things in cache

* Store things in cache

* Refactor web chat, fix multichat image restore (EricLBuehler#1364)

* Fix multichat image restoration.

* Clippy

* Refactor

* Refactor frontent

* Fix repeated immediate isq init (EricLBuehler#1365)

* Add images_ref

* Add debug impl

* Fix the bug

* Tweak style of buttons

* Add a spinner

* Move spinner

* Tweak emoji

* Add gif

* Tweak initial gif

* Include vision tower tensors in Mistral3 UQFF (EricLBuehler#1366)

* Fix mistral 3 uqff residual tensors for vision

* Rolling shard creation for uqff files (EricLBuehler#1367)

* Fix occasional unstability during isq of afq (EricLBuehler#1368)

* Fix unstability during isq of afq

* Clippy

* Fix web chat installation

* Support web chat file uploading (EricLBuehler#1370)

* Web chat fixes

* Fix thumbnail in message, reuse blank chat

* Add file uploading support

* Fix scroll

* Allowed extensions

* Preserve files as literals

* Support multiple clients

* Add a stop button

* New cache dir

* New cache dir

* Fix

* Refactor

* Update readme

* Tweak drag-and-drop css

* Add speech generation support to the web chat! (EricLBuehler#1373)

* Initial speech gen support for web chat

* Tweak ui

* Update docs

* Prefix caching for PagedAttention! (EricLBuehler#1369)

* Exposing some things for logical token blocks

* Prefix cache manager has the scheduler

* Refactor

* Get logical and physical blocks into the prefix cacher

* Hash and cache

* Pass physical block prefill

* Allocation of prefilled block tables

* Temp

* Dont always use 2

* Hmm

* Hmm

* It mostly works

* Increment refcount

* Support images!

* Add to dummy paged attn

* Fix some clippy

* Clippy

* More checks

* Include EricLBuehler#1371, closes EricLBuehler#1371

* Typos

* Update docs
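
A rough standalone sketch of the prefix-caching idea in the commits above, with illustrative types rather than the mistral.rs implementation: hash each full logical block of token ids (chained with the previous block's hash) and reuse the cached physical KV block on a hit, bumping its refcount.

```rust
// Illustrative only: hash full logical blocks of token ids, chained with the
// previous block's hash, and reuse the physical KV block when already cached.
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

const BLOCK_SIZE: usize = 16;

fn block_hash(prev: u64, tokens: &[u32]) -> u64 {
    let mut h = DefaultHasher::new();
    prev.hash(&mut h);
    tokens.hash(&mut h);
    h.finish()
}

fn main() {
    // hash -> (physical block id, refcount)
    let mut cache: HashMap<u64, (usize, usize)> = HashMap::new();
    let prompt: Vec<u32> = (0..64).collect();

    let mut prev = 0u64;
    let mut next_free_block = 0usize;
    for chunk in prompt.chunks(BLOCK_SIZE).filter(|c| c.len() == BLOCK_SIZE) {
        prev = block_hash(prev, chunk);
        let entry = cache.entry(prev).or_insert_with(|| {
            let id = next_free_block;
            next_free_block += 1;
            (id, 0) // newly allocated physical block
        });
        entry.1 += 1; // increment refcount on (re)use
        println!("logical block -> physical {} (refs {})", entry.0, entry.1);
    }
}
```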

* Metal PagedAttention accuracy improvements (EricLBuehler#1374)

* Fix subtle bug

* Fix half sum bug

* Format metal paged attention

* Handle images in paged attn scheduler (EricLBuehler#1375)

* Include schemas needed for chatcompletions endpoint (EricLBuehler#1353)

* EricLBuehler#1326: WIP include schemas needed for chat completions endpoint

 Conflicts:
	Cargo.lock
	mistralrs-server/src/main.rs

* EricLBuehler#1326: WIP define utoipa as a workspace dep since core and server both need it

* EricLBuehler#1326: first draft of handling schemas that use Either

* EricLBuehler#1326: first draft of handling schema for Grammar

* EricLBuehler#1326: Add in other endpoints to API docs.

* EricLBuehler#1326: Adjust code comments

* EricLBuehler#1326: Implement coderabbitai suggestions

- EricLBuehler#1353 (review)
- EricLBuehler#1353 (comment)

* Fix constraints with metal sampler

* Revert EricLBuehler#1375

* Fix case where prefix cacher returns no toks (EricLBuehler#1377)

* Fix AFQ UQFF serialization

* Faster UQFF serialization (EricLBuehler#1379)

* Faster UQFF serialization

* Fix uqff gemma3

* Improve gemma3 auto loader names

* UQFF creation for AFQ on CPU support (EricLBuehler#1380)

* Add afq cpu quantize/dequantize

* Clippy

* Improved device for afq quantize

* Improved dtype handling for cpu afq (de)quantize

* Improved generate_uqff_card

* Add fused CPU attention kernel! (EricLBuehler#1382)

* Working

* Fix warnings

* Allow mask

* Support bf16, f16

* Handle striding

* Parallelized

* Add initial vector flash attn

* Avoid repeated allocations

* Tiled kv

* Apply some clippy

* Some small fixes

* Chunked vec_dot

* Clippy

* Use T::zero
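
For orientation, a naive single-query attention in f32 showing what the fused kernel computes, softmax(q·Kᵀ/√d)·V; the actual kernel adds masking, bf16/f16 support, striding, tiling, chunked dot products, and parallelism:

```rust
// Naive single-query attention in f32, shown only to illustrate the math the
// fused CPU kernel implements.
fn attend(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let d = q.len() as f32;
    let scale = 1.0 / d.sqrt();

    // scores = q · k_i / sqrt(d)
    let mut scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
        .collect();

    // numerically stable softmax
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut denom = 0.0;
    for s in scores.iter_mut() {
        *s = (*s - max).exp();
        denom += *s;
    }

    // weighted sum of the value vectors
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in scores.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += (w / denom) * x;
        }
    }
    out
}

fn main() {
    let q = vec![0.1, 0.2, 0.3, 0.4];
    let keys = vec![vec![0.1; 4], vec![0.5; 4]];
    let values = vec![vec![1.0; 4], vec![2.0; 4]];
    println!("{:?}", attend(&q, &keys, &values));
}
```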

* Refactor attention backends (EricLBuehler#1384)

* Refactor attention code

* Refactor attention code

* Move into backends

* Set macOS thread affinity for CPU attn (EricLBuehler#1385)

* Use lazylock

* Format

* Fix metal warn build

* Faster Qwen 3 MoE support on Metal (EricLBuehler#1387)

* Fix load

* Use afq gather qmm

* Well it runs

* It works

* Polish

* Fast and slow options

* Remove quantized.rs

* Polish some more

* Refactor

* Add isq

* Update load in parallel

* Support fp8

* Refactor for FusedExperts

* Clippy

* Handle pack factor when loading prequantized models

* Use f32 only in moe

* Avoid using f32 so much

* Avoid using f32 so much

* Fix PagedAttention block leaks (EricLBuehler#1388)

* Warn and ignore if ignored

* Fix a block allocation leak

* Update bench.py

* Fix double free in block engine

* Do not apply ISQ if loading a prequantized model

* Fix cuda build again (EricLBuehler#1389)

* Fix cuda build

* Fix

* Format

* Fixes for cuda docker

* Update dockerfiles

* Bump version to 0.6.0 (EricLBuehler#1390)

* Bump version to 0.6.0

* Remove lower_level api

* Make a static dir

* Update deps

* Fix routing for static handler in web chat

* Fewer .contiguous calls for qwen3 moe (EricLBuehler#1391)

* Allow speech models to accept batched inputs (EricLBuehler#1393)

* Allow speech models to accept batched inputs

* Clippy

* Ring distributed backend for heterogeneous TP (EricLBuehler#1238)

* Begin work on ring distributed backend for Metal

* Add the actual ring functionality

* It loads and kind of runs

* It works

* Optimize buffer allocation

* Avoid copy

* It works

* Add allgather

* Fix load

* Ping-pong

* Small things

* Add config json

* Allow different ip address

* Read config once

* Read config when appropriate

* Replicate requests

* Small fix

* Fix small compat with openai

* Clippy

* Update docs
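
A local simulation of the ring allgather schedule such a backend relies on (illustrative only; the real backend moves chunks between devices over the network): with N peers in a ring, each peer forwards one chunk per step, and after N−1 steps every peer holds all N chunks.

```rust
// Simulated ring allgather: at step s, rank r forwards the chunk originally
// owned by (r - s) mod n to its neighbour (r + 1) mod n.
fn ring_allgather(chunks: &[Vec<f32>]) -> Vec<Vec<Vec<f32>>> {
    let n = chunks.len();
    // gathered[rank][src] is Some(chunk) once `rank` has received chunk `src`.
    let mut gathered: Vec<Vec<Option<Vec<f32>>>> = (0..n)
        .map(|rank| {
            let mut row = vec![None; n];
            row[rank] = Some(chunks[rank].clone());
            row
        })
        .collect();

    for step in 0..n - 1 {
        for rank in 0..n {
            let src = (rank + n - step) % n;
            let dst = (rank + 1) % n;
            let chunk = gathered[rank][src].clone().expect("chunk present");
            gathered[dst][src] = Some(chunk);
        }
    }

    gathered
        .into_iter()
        .map(|row| row.into_iter().map(Option::unwrap).collect())
        .collect()
}

fn main() {
    let chunks = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
    let all = ring_allgather(&chunks);
    println!("peer 2 now holds: {:?}", all[2]);
}
```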

* Add deepseek tool calling chat template

* Add auto loader for vision/text detection! (EricLBuehler#1402)

* Add auto loader for vision/text detection

* Build fixes

* Add model loader

* Update docs

* Format

* Create Mistral.rs Server Core Lib: `mistralrs-server-core` (EricLBuehler#1346)

* First draft of exposing mistral server routes as lib

* make arg struct fields pub

* Take base path so utoipa swagger route can properly redirect

* Expose swagger routes and make it configurable

* Add base path option for swagger docs

* More work on modularizing mistralrs server

* Sync fork (+1 squashed commit)
Squashed commits:
[169ae9e] Sync fork

* Adjust fn params to use refs / individual params instead of args

* Start breaking down controller actions into smaller pieces

* Continue refactoring

* Make mods pub so they can be used outside crate

* Allow chat completion streamer to take a callback so that you can get the complete response when finished

WIP (+3 squashed commits)
Squashed commits:
[0061d87] WIP
[c484d56] WIP
[16f8a60] WIP
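
As a purely illustrative shape for the callback-based streamer described a few commits above (hypothetical names, not the mistralrs-server-core API): forward each chunk as it arrives and hand the accumulated text to a completion callback at the end.

```rust
// Hypothetical sketch of "streamer takes a callback": chunks are forwarded as
// they arrive and the full response is handed to `on_done` when finished.
struct Streamer<F: FnMut(&str), G: FnOnce(String)> {
    on_chunk: F,
    on_done: Option<G>,
    buffer: String,
}

impl<F: FnMut(&str), G: FnOnce(String)> Streamer<F, G> {
    fn new(on_chunk: F, on_done: G) -> Self {
        Self { on_chunk, on_done: Some(on_done), buffer: String::new() }
    }
    fn push(&mut self, chunk: &str) {
        (self.on_chunk)(chunk);       // stream the delta to the client
        self.buffer.push_str(chunk);  // and keep it for the final callback
    }
    fn finish(mut self) {
        if let Some(done) = self.on_done.take() {
            done(std::mem::take(&mut self.buffer)); // complete response
        }
    }
}

fn main() {
    let mut s = Streamer::new(
        |c| print!("{c}"),
        |full| println!("\n[done: {} chars]", full.len()),
    );
    s.push("Hello, ");
    s.push("world!");
    s.finish();
}
```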

* Sync fork

* Adjust callback type

* Remove throughput_log arg that was removed in 26afcc3

* Implement defaults for Args (and use for Clap)

* Small code formatting tweaks

* Rename callback to match SSE event and code clean up

* Sync fork

* WIP: first very rough draft of server core builder. Doesn't meet parity with old functional approach yet (slower / unstable?).

* Clean up (+4 squashed commits)
Squashed commits:
[e1cff387] Sync fork
[d8301025] WIP debugging
[1ea9f8c8] Sync fork
[4fe28cf5] WIP: debug function

* WIP server core builders

* Code clean up

* Add on_chunk callback

* Code clean up

* First draft of creating version of mistral-server that uses server-core

Code clean up (+1 squashed commit)
Squashed commits:
[adea1693]

* Sync fork

* Add helper methods to builder to make optional args more ergonomic (since .build validates params)

* Start adding docs

* Start cleaning up crates deps

* Example commit of mistral-server with implementing server-core

* Start addressing CodeRabbit feedback

* Fix comment typo

* Tweak doc blocks

* - Update type alias naming for clarity (MistralRs instead of Mistral)
- CodeRabbit, don't use eprintln for lib (use trace)
- Allow buffer size to be passed in and default to Constant
- Allow router body limit to be passed in and default to Constant
- Update doc examples

* Typo

* Address CodeRabbitAI feedback

* Support linear rope for llama3 (EricLBuehler#1408)

* Hotfix for loading

* Fix vllama4 uqff loading (EricLBuehler#1409)

* Fix vllama4 uqff loading

* Fix regex

* Fix regex

* Maybe a fix

* Gracefully handle receiver disconnects (EricLBuehler#1410)

* Handle receiver disconnects

* Format

* Fix Qwen3 MoE device mapping irregularities (EricLBuehler#1411)

* Fix bias

* Fix lm_head packing case

* Account for gate

* Fix head dim

* Fix interactive mode URL parsing (EricLBuehler#1412)

* fix url regex in vision interactive mode

* Fix regex

* Clippy

* Refactor auto device map (EricLBuehler#1413)

* Refactor auto device map

* Refactor a bit more

* Clippy

* Enable runtime sampling tweaks in interactive mode (EricLBuehler#1414)

* Document runtime sampling commands

* Fix readme

* Tweak

* Bounds checking

* Tweak temp bounds

* Send streaming tokens every time

* Gumbel sampling for fast sampler (EricLBuehler#1416)
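
Gumbel-max sampling draws a categorical sample without building a cumulative distribution: add independent Gumbel noise −ln(−ln u) to each logit and take the argmax, which maps well onto a GPU sampler. A minimal CPU sketch (assuming the `rand` crate for uniform draws):

```rust
// Gumbel-max trick: argmax(logit_i + g_i), g_i = -ln(-ln(u_i)), u_i ~ U(0,1),
// is an exact sample from softmax(logits). Assumes the `rand` crate.
use rand::Rng;

fn gumbel_sample(logits: &[f32], rng: &mut impl Rng) -> usize {
    logits
        .iter()
        .enumerate()
        .map(|(i, &l)| {
            let u: f32 = rng.gen_range(f32::EPSILON..1.0);
            (i, l - (-u.ln()).ln())
        })
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let mut rng = rand::thread_rng();
    let logits = [2.0f32, 1.0, 0.1];
    let counts = (0..10_000).fold([0usize; 3], |mut c, _| {
        c[gumbel_sample(&logits, &mut rng)] += 1;
        c
    });
    println!("empirical counts: {counts:?}"); // roughly proportional to softmax
}
```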

* Improved handling for initialize_logging

* Improved CPU flash attention accuracy & performance (EricLBuehler#1417)

* Downcast correctly

* Operate internally in f32

* Avoid some casts and striding

* Prefetch
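
A small illustration of why operating internally in f32 matters for half-precision inputs (assuming the `half` crate only for the f16 type and conversions): an accumulator rounded back to f16 after every add stalls, while an f32 accumulator stays accurate.

```rust
// An accumulator rounded to f16 after every add stalls once the increments
// fall below half a ulp; accumulating in f32 and downcasting once stays
// accurate. Assumes the `half` crate only for the f16 type/conversions.
use half::f16;

fn main() {
    let xs: Vec<f16> = std::iter::repeat(f16::from_f32(0.01)).take(10_000).collect();

    // f16 accumulator: round to f16 after every addition.
    let mut acc_f16 = f16::from_f32(0.0);
    for x in &xs {
        acc_f16 = f16::from_f32(acc_f16.to_f32() + x.to_f32());
    }

    // f32 accumulator: convert once per element, downcast only at the end.
    let acc_f32: f32 = xs.iter().map(|x| x.to_f32()).sum();

    println!("f16 accumulator: {}", acc_f16.to_f32()); // stalls far below 100
    println!("f32 accumulator: {acc_f32}");            // close to 100 (0.01 is inexact in f16)
}
```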

* Provide chat_templates to container users (EricLBuehler#1419)

Models often come without chat templates, requiring them to be mapped
from the source repository into a container for access by the
mistralrs-server.

Copy the templates from the build tree into the root of the image
to permit use via `--chat-template /chat_templates/something.json`

TODO:
  With the increase in quantized models and support for other
formats, the initial benchmark run during model load can be used
to qualify/select existing chat templates embedded into the binary
for models that do not come with any (including the output of the
functional failures in each test, allowing users to adapt the
templates already provided to suit the model being loaded).

Co-authored-by: RageLtMan <rageltman [at] sempervictus>

* Faster cpu flash attn (EricLBuehler#1418)

* Faster cpu flash attn

* Prefetch

* Clippy

* Add some tests

* Add softcap tests

* Fix test_parse_image_url test

* Update tests

* Update tests

* Web search improvements (bm25, web chat) (EricLBuehler#1420)

* Fix web search blocking case

* Web search support in web chat

* Tweak ui

* Support fallback to bm25

* Clippy

* Reinject descriptions

* Propely handle consecutive searches (EricLBuehler#1421)

* Update extraction tool reinjection

* Looped

* Update docs (EricLBuehler#1422)

- lib.rs: clean up example var names and match logging change from EricLBuehler@201d6be
- server_builder: fix typo
- READMEs: link to crate docs

* Better tool call detection logic (EricLBuehler#1424)

* Add web search hook callbacks (EricLBuehler#1426)

* feat: add customizable search hook

* Move to builder

* Update docs

* Fix CUDA context switching, bind thread on CudaStorage drop (EricLBuehler#1428)

* Add CUDA context helper and use in Llama forward

* No flashparams?

* working

* Tweak

* Update to use dep

* conditionally build flash attention inputs (EricLBuehler#1429)

* Add AGENTS.md (EricLBuehler#1430)

* Support Qwen3 GGUF model (EricLBuehler#1432)

* Support QWen3 GGUF model

* Clippy fix

* cargo fmt

* Improved paged attn prefix caching (EricLBuehler#1434)

* Improved paged attn prefix caching

* Disable

* Clippy

* Temporary fix for qwen3 gguf tokenizer (EricLBuehler#1433)

* Temporary fix for qwen3 gguf tokenizer

* Typo fix

* Add tool callback support (EricLBuehler#1427)

* Add tool callback support

* Fixes

* Support named tool callbacks

* Update examples

* Update docs

* Clippy

* Centralize crate dependencies (EricLBuehler#1438)

* chore: centralize dependencies

* Format

* Fix bug in tokenizer created with gguf metadata (EricLBuehler#1440)

* Fix bug in tokenizer created with gguf metadata

* Clippy fix

* Update deps (EricLBuehler#1441)

* Small things

* Update deps

* Update deps

* Update breaking changes

* Doc fixes (EricLBuehler#1442)

* Mention uqff_maker

* Downgrade rustyline 16.0.0 -> 15.0.0 (EricLBuehler#1444)

* Add max_completion_tokens alias for server (EricLBuehler#1451)

* Audio input support (Phi 4 multimodal) (EricLBuehler#1448)

* Deps

* Add conformer

* Nemo loading

* Position embeds

* Load t5 attn bias

* Attn and feed forward

* Add conv module and glu pointwise

* Implement relative attn bias

* Add the forward methods

* Add encoder embedding

* Fix oproj

* Some loading

* Conformer loads!

* Fully loading speech stack

* Merger

* Dont need that

* First pass at audio processing

* Read samples

* Optional

* Small loading fix

* Runs but not correct yet

* Improved audio processing?

* Works with this

* Fix t5 attn bias

* It works!

* Comment

* Use some other crates

* Clippy

* Allow bf16 on metal

* Add prefix_audio

* Remove unused

* Typo

* User specified

* Add audio url parsing

* AudioProjectionMode -> InputMode

* Audio prefix caching

* Fix bug in audio prefix caching

* Support both at the same time!

* Tweak logging

* Support stereo

* Add mistralrs-audio

* Support batching

* Add server and rust api example

* Add python api

* Fix add_multimodal_message

* Fix unfold for conformer

* Streaming example

* Add web chat support

* Add modalities registry

* Fix offline cache issue for gguf models (EricLBuehler#1452)

* Add MCP server endpoints (EricLBuehler#1453)

* feat(server): add MCP server support

* Add mcp docs

* Add handle_list_tools_request

* Better launch, tool handling

* Tmp state

* Ok works

* Handle modalities

* Update docs

* Add ping

* Tweak temperature bounds, args

* MCP documentation pass (EricLBuehler#1455)

* Fix table

* Update mcp docs

* Improve readme header

* Improve readme header

* Integrate an MCP client (EricLBuehler#1456)

* Add builtin mcp client

* Use async loader

* Add headers

* Handle sse

* More flexible search request

* Add tool callbacks with tools, for mcp

* Add bearer token support

* Add websocket support

* Update docs

* Add python api

* Clippy

* Add http api, docs

* Tests pass

* Make these configs actually work

* Add docs

* Make mistralrs-mcp

* Refactor examples

* Update examples

* Add defaults

* Add defaults

* Add defaults

* Update docs

* Improved docs

* Add -y to npx usages

* Even better examples

* Update generate_wheels

* Update generate_wheels

* Update generate_wheels

* Fix Dockerfile.cuda-all

* Improve automatic tool call (EricLBuehler#1460)

* Improved auto tool call

* Add logging

* chore: `Dockerfile.cuda-all` configurable threads (EricLBuehler#1458)

* chore: `Dockerfile.cuda-all` - Merge `RUN` for `apt-get install` (EricLBuehler#1459)

* Add fallback definition for isnan (EricLBuehler#1463)

* chore: `Dockerfile` - Drop runtime rayon thread ENV (EricLBuehler#1465)

* chore: Dockerfile - Remove rayon threads env

* chore: Dockerfile - Improve formatting for `apt-get`

* Remove duplicate calls for api_dir_list (EricLBuehler#1474)

* Remove duplicate calls for api_dir_list

* Support local cache for api_dir_list

* Fix home folder for metal

* Capitalized

* Fix transient pyo3 dep (EricLBuehler#1478)

Co-authored-by: Eric Buehler <[email protected]>

* Fix objc dep with non macos (EricLBuehler#1480)

* Fix phi 3/4 + nccl issue (EricLBuehler#1481)

* Fix log

* Fix n kv heads

* Fix phi3.5 moe (EricLBuehler#1482)

* Fix phi3.5 moe accum device

* Fix again

* Fix again

* Support GLM4 model! (EricLBuehler#1437)

* Support GLM4 model

* Mention GLM4 model in ReadMe

* glm4 type hint

* Typo fix

* Fix unsupported chat_template function

* Clippy fix

* Refactor distributed backend (EricLBuehler#1484)

* Refactor distributed backend, check power of 2

* Fix compilation

* Cap metal paged attn kv allocation (EricLBuehler#1485)

* Better paged attn metal cap (EricLBuehler#1486)

* Better paged attn metal cap

* Small fix

* Comment

* Small fix

* Refactor

* Server core: consolidate and unify route handlers and API surface (EricLBuehler#1423)

* Start working on consolidating completion and chat_completion underlying implementations

* Move response channel to util mod for now (since it's used with streaming and non streaming)

* More work on consolidating completions and chat completions

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* Update docs and restrict completion core visibility

* CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this

* Use consistent var name for completions mod

* Make route handler modules' public API consistent (same fn names, etc.) and provide proxy fns that wrap core fns so the core mod doesn't have to be pub
Make the lib.rs example compile-checked and update the example

* Code formatting

* Typo

* Sync fork

* Sync fork

* Docs example fix

* Support qwen3 gguf (EricLBuehler#1488)

* Add qwen3 gguf

* Template fixup

* Make bos/eos token IDs optional (EricLBuehler#1493)

* Remove python deps from CUDA dockerfiles (EricLBuehler#1487)

* Handle noncontiguous v in naive_sdpa (EricLBuehler#1499)

Co-authored-by: Eric Buehler <[email protected]>

* Server Core: refactor Paged Attention configuration (EricLBuehler#1500)

* Use StorageModePrivate for Metal PA kv cache (EricLBuehler#1506)

* Fix OpenAI stream: emit field in tool-call deltas for schema compliance (EricLBuehler#1507)

* FP8 KV-cache quantization for PagedAttention (EricLBuehler#1400)

* Add most of paged attn kv quant

* It builds a bit

* All the functionality at least

* Small fix

* Add a scale

* Fix bf16 usage

* Make k_v_scale optional

* Collector

* Tweak collection

* Refactor

* Add to apis

* Add cuda impl

* Fix compilation

* Fixes

* Handle ENABLE_FP8

* Format

* Tweak

* Fix scaled_convert usage

* Fix cache_t size

* Fixed scale collection

* Actual fix

* Fix fp8 for CC<8

* Fix the usual String != &str bit (EricLBuehler#1483)

Co-authored-by: RageLtMan <rageltman [at] sempervictus>

* Handle USE_FP8 for cuda

* Fix cuda warn

* Add readme

* Saturating sub in sequence state
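
For context on the saturating-sub changes: plain `-` on unsigned sequence counters panics in debug builds (and wraps in release) on underflow, whereas `saturating_sub` clamps at zero, which is what sequence-state bookkeeping wants. A tiny illustration:

```rust
// Why saturating_sub: unsigned subtraction underflows when the subtrahend is
// larger, but sequence-state counters should just clamp at zero.
fn main() {
    let prompt_len: usize = 3;
    let cached_tokens: usize = 5; // e.g. more tokens already cached than the prompt length

    // `prompt_len - cached_tokens` would panic in debug builds (underflow).
    let remaining = prompt_len.saturating_sub(cached_tokens);
    assert_eq!(remaining, 0);
    println!("tokens left to prefill: {remaining}");
}
```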

---------

Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>

* Validate model name in OpenAI API (EricLBuehler#1509)

* Validate model name in openai api

* Add docs, allow 'ignore'

* Updated examples for EricLBuehler#1509

* Fix mcp import in doc string (EricLBuehler#1510)

* Add multi-model support! (EricLBuehler#1512)

* Refactor MistralRs

* Working multi-model!

* Add mutli-model docs initially

* Update mistralrs-pyo3, mistralrs-bench, mistralrs

* Update apis for consistency

* API tweaks

* Logging tweaks

* Add examples, tweak cli

* Clearer pipeline id

* Fix config key semantics

* Format and clippy

* Tweak logging, fix example

* Clippy refactor

* Update examples

* Remove unused multi model docs

* Replace 'ignore' with 'default'

* Update docs

* Add stars label to readme (EricLBuehler#1513)

* Add CLAUDE.md

* Handle base_model.model case in lora (EricLBuehler#1514)

* Add thread_local! for engine-specific const/static (EricLBuehler#1517)

* Fix MCP doc test (EricLBuehler#1511)

* Allow disabling metal precompilation (EricLBuehler#1518)

* Allow disabling metal precompilation

* Simple preprocessor

* Simple docs

---------

Co-authored-by: Eric Buehler <[email protected]>

* Rust 1.88 clippy (EricLBuehler#1522)

* Rust 1.88 clippy

* Format

* Fix cuda warnings (EricLBuehler#1526)

* Avoid panic decoding tokens on error (EricLBuehler#1527)

* Split Marlin and Paged Attention kernels for faster build (EricLBuehler#1525)

* Split Marlin and Paged Attention kernels for faster build

* Typo fix

* chore: update llguidance (EricLBuehler#1535)

* chore: update llguidance

* chore: remove unused import

* Add the SmolLM3 model! (EricLBuehler#1501)

* Add model

* Update loader

* Fix llama config usage

* Docs

* Fix config no_rope_layers

* Fix tie_word_embeddings default

* Add chat template

* Embed the chat templates

* Fix embedding template

* enable_thinking default true

* Update examples

* XML tools for smollm3

* Add smollm3 docs

* Fix openai examples

* Clippy

---------

Co-authored-by: Eric Buehler <[email protected]>

* Add full Gemma 3n support! (EricLBuehler#1519)

* Add initial

* Loading for text model

* Add ple embeddings

* Add altup, laurel block

* Update rmsnorm

* Add mlp

* Update attn norm application

* Currently no kv shared

* Wire it up

* It runs

* Fix bf16

* Fix scaled embd

* Fixes for mean

* tmp

* Attn confirmed

* Fix target_magnitude

* Add shared kv

* Ok it works

* Remove npy

* Fix streaming

* Remove warnings

* Remove paged attn

* Refactor rope

* Add immediate isq

* Add vision & mproj

* Update image processor

* Vision merge runs, not correct

* Remove

* Add mobilenet v5

* Add multimodal vision embedding

* Fix load

* runs

* Fix gamma

* Works but just not vision tower

* It works!!

* Tweak

* Fix warnings

* Move vision tower

* Fix warn

* Update cache manager things

* Refactor

* Add audio model, it loads

* Add audio processing

* It runs at least

* tmp

* A bit better

* Audio works!!!!

* Fused attn in vision

* Clippy

* Update audio runner

* Optimized audio model

* Remove unused things

* Fix inputs processor bug

* Remove comments

* Clippy

* Small optimizations

* Format

* Correctly register modalities

* Add docs

* Update readme

* Runs there

* Fixed padding from Blaizzy/mlx-vlm#410

* Add better checks

* Fix sdpa n_kv_groups

* Vision encoder works!

* Rotate image

* Clippy

* Fix cuda loading

* Updated device mapper

* Fix overflow

* Fix dtype errors

* Refactor image/audio embeddings

* Fix metal

* Fix dtype mismatch

* Audio processing fixes

* Audio processing fixes

* Works

* Audio is good

* Fix boi/eoi too

* Embed the chat templates

* Better embedding accuracy in non f32

* More f32

* Support bf16 on metal

* Add more ISQ

* Fixed device map

* Clippy

* Gemma3n no paged attn

* Fix saturating sub

* Faster rmsnorm

* Use sdpa for vision model

* Fix ple bug

* Fix name

* Fix multiaudio

* Add matformer config loading

* Add docs

* Add support for matformer in auto device mapper

* Update docs

* Typos

* Tweak

* Tweak

* Fix multidevice

* Fix gemma3n text model auto device map

* Fix dims3

* Fix auto device map vision

* Non-metal keeps PLE on cpu

* Complete merge

* Vision dtype f16 -> f32

* Fix metal nm device

* Fix uqff

* Typos

* Reference uqff

* Fix tests

* Fix sequence length check (EricLBuehler#1546)

* update candle version (EricLBuehler#1545)

Co-authored-by: AlpineVibrations <[email protected]>

* add ios target to metal deps (EricLBuehler#1548)

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: edwko <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Michał Moskal <[email protected]>
Co-authored-by: Chen Mulong <[email protected]>
Co-authored-by: Steph Wolski <[email protected]>
Co-authored-by: omahs <[email protected]>
Co-authored-by: Viktor Szépe <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: Sbargaoui <[email protected]>
Co-authored-by: Gaétan Lepage <[email protected]>
Co-authored-by: Ammar Elsabe <[email protected]>
Co-authored-by: luke <[email protected]>
Co-authored-by: AlpineVibrations <[email protected]>
Co-authored-by: Michael Tissen <[email protected]>