Skip to content

Commit 07dafc0

Browse files
authored
Fast Metal-specific quantization method: AFQ (#1264)
* Add mlx quantized kernels * Add mlx quantized kernels * Kernel launcher * Add AFQ isq quant and dequant * Some quantmethod things * Begin to implement the qmm caller * Clippy * Much faster * Cache kernels * Docs * Clippy * Add it to uqff
1 parent b286f3e commit 07dafc0

File tree

25 files changed

+5188
-192
lines changed

25 files changed

+5188
-192
lines changed

.typos.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ extend-ignore-identifiers-re = [
66
"Nd",
77
"nin",
88
"cudaDevAttrMaxSharedMemoryPerBlockOptin",
9-
"_thw"
9+
"_thw",
10+
"thr",
11+
"nd",
12+
"uneeded"
1013
]
1114

1215
[files]

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
3131
- Check out UQFF for prequantized models of various methods!
3232
- Models can be found [here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c).
3333

34+
- 🔥 Try out AFQ for blazingly fast Metal performance!
35+
36+
```
37+
./mistralrs-server -i --isq afq8 plain -m meta-llama/Llama-3.2-3B-Instruct
38+
```
39+
3440
- 🔍🌐 Easily add web search capabilities to your models! Compatible with OpenAI's `web_search_options` parameter: [documentation](docs/WEB_SEARCH.md)
3541
3642
```

docs/ISQ.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@ An API is exposed on the Python and Rust APIs which provide the ability to dynam
66

77
To set the ISQ type for individual layers, use a model [`topology`](TOPOLOGY.md).
88

9+
> Note: 🔥 AFQ (affine quantization) is fast on **Metal**.
10+
911
## ISQ quantization types
12+
- AFQ2
13+
- AFQ3
14+
- AFQ4
15+
- AFQ6
16+
- AFQ8
1017
- Q4_0
1118
- Q4_1
1219
- Q5_0

docs/UQFF.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ The following quantization formats are supported in UQFF. One can, of course, be
5454
- FP8:
5555
- FP8 E4M3 (4-bit exponent, 3-bit mantissa)
5656

57+
- AFQ quantized (🔥 AFQ is fast on **Metal**):
58+
- AFQ2
59+
- AFQ3
60+
- AFQ4
61+
- AFQ6
62+
- AFQ8
63+
5764
## Loading a UQFF model
5865

5966
To load a UQFF model, one should specify the filename. This will be located based on the model ID, and can

mistralrs-core/src/pipeline/isq.rs

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ use candle_core::{quantized, Context, Device, Tensor};
1414
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
1515
use itertools::Itertools;
1616
use mistralrs_quant::{
17-
CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul, HqqLayer,
18-
IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType, ReplicatedLayer,
19-
RowParallelLayer, UnquantLinear,
17+
AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
18+
HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
19+
ReplicatedLayer, RowParallelLayer, UnquantLinear,
2020
};
2121
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
2222
use regex::Regex;
@@ -63,10 +63,15 @@ pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
6363
"hqq8" => IsqType::HQQ8,
6464
"hqq4" => IsqType::HQQ4,
6565
"fp8" => IsqType::F8E4M3,
66+
"afq8" => IsqType::AFQ8,
67+
"afq6" => IsqType::AFQ6,
68+
"afq4" => IsqType::AFQ4,
69+
"afq3" => IsqType::AFQ3,
70+
"afq2" => IsqType::AFQ2,
6671
// "hqq3" => IsqType::HQQ3,
6772
// "hqq2" => IsqType::HQQ2,
6873
// "hqq1" => IsqType::HQQ1,
69-
_ => return Err(format!("ISQ type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`.")),
74+
_ => return Err(format!("ISQ type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
7075
};
7176
#[cfg(feature = "cuda")]
7277
{
@@ -442,19 +447,14 @@ pub trait IsqModel {
442447
// Cap the ISQ thread count at the maximum the target quantization type allows (falling back to the current rayon thread count)
443448
let mut minimum_max_threads = {
444449
let current_rayon_threads = rayon::current_num_threads();
445-
tensors
446-
.iter()
447-
.map(|(q, _)| {
448-
if let Some(dtype) = dtype {
449-
q.get_max_isq_cpu_threads(dtype)
450-
.map(usize::from)
451-
.unwrap_or(current_rayon_threads)
452-
} else {
453-
current_rayon_threads
454-
}
455-
})
456-
.min()
457-
.unwrap_or(current_rayon_threads)
450+
if let Some(dtype) = dtype {
451+
dtype
452+
.get_max_isq_cpu_threads()
453+
.map(usize::from)
454+
.unwrap_or(current_rayon_threads)
455+
} else {
456+
current_rayon_threads
457+
}
458458
};
459459
if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
460460
minimum_max_threads = 1;
@@ -807,6 +807,12 @@ pub trait IsqModel {
807807
&comm,
808808
guard.clone(),
809809
)?,
810+
QuantizedSerdeType::Afq => AfqLayer::deserialize(
811+
Cow::from(artifact),
812+
&devices[i],
813+
&comm,
814+
guard.clone(),
815+
)?,
810816
}
811817
}
812818
};
@@ -874,6 +880,12 @@ pub trait IsqModel {
874880
&comm,
875881
guard.clone(),
876882
)?,
883+
QuantizedSerdeType::Afq => AfqLayer::deserialize(
884+
Cow::from(artifact),
885+
&devices[i],
886+
&comm,
887+
guard.clone(),
888+
)?,
877889
}
878890
}
879891
};

mistralrs-core/src/pipeline/normal.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use candle_core::{Device, Tensor, Var};
4343
use hf_hub::Cache;
4444
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
4545
use indicatif::MultiProgress;
46-
use mistralrs_quant::{GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
46+
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
4747
use rand_isaac::Isaac64Rng;
4848
use regex_automata::meta::Regex;
4949
use std::any::Any;
@@ -365,6 +365,10 @@ impl Loader for NormalLoader {
365365
}
366366
QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
367367
QuantizedSerdeType::Unquant => 1,
368+
QuantizedSerdeType::Afq => {
369+
AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
370+
.pack_factor(dtype)
371+
}
368372
};
369373
total_pack_factors += pack_factor;
370374
}

mistralrs-core/src/pipeline/vision.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use candle_core::{Device, Tensor, Var};
3838
use hf_hub::Cache;
3939
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
4040
use indicatif::MultiProgress;
41-
use mistralrs_quant::{GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
41+
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
4242
use rand_isaac::Isaac64Rng;
4343
use regex_automata::meta::Regex;
4444
use std::any::Any;
@@ -305,6 +305,10 @@ impl Loader for VisionLoader {
305305
}
306306
QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
307307
QuantizedSerdeType::Unquant => 1,
308+
QuantizedSerdeType::Afq => {
309+
AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
310+
.pack_factor(dtype)
311+
}
308312
};
309313
total_pack_factors += pack_factor;
310314
}

0 commit comments

Comments
 (0)