Skip to content

Commit dd9031b

Browse files
committed
Add modalities registry
1 parent 29b146c commit dd9031b

File tree

9 files changed

+151
-7
lines changed

9 files changed

+151
-7
lines changed

mistralrs-core/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub use pipeline::{
9090
DiffusionLoaderBuilder, DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig,
9191
GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader,
9292
IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
93-
LoraAdapterPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths,
93+
LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind, ModelPaths,
9494
MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
9595
NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
9696
SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
@@ -134,6 +134,7 @@ pub struct MistralRsConfig {
134134
pub kind: ModelKind,
135135
pub device: Device,
136136
pub category: ModelCategory,
137+
pub modalities: Modalities,
137138
}
138139

139140
/// The MistralRs struct handles sending requests to the engine.
@@ -340,10 +341,19 @@ impl MistralRs {
340341

341342
let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
342343
let device = pipeline.try_lock().unwrap().device();
344+
let modalities = pipeline
345+
.try_lock()
346+
.unwrap()
347+
.get_metadata()
348+
.modalities
349+
.clone();
350+
info!("Pipeline input modalities are {:?}", &modalities.input);
351+
info!("Pipeline output modalities are {:?}", &modalities.output);
343352
let config = MistralRsConfig {
344353
kind,
345354
device,
346355
category: category.clone(),
356+
modalities,
347357
};
348358

349359
let engine_handler = thread::spawn(move || {

mistralrs-core/src/pipeline/diffusion.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use super::{
88
use crate::device_map::DeviceMapper;
99
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
1010
use crate::paged_attention::AttentionImplementation;
11-
use crate::pipeline::ChatTemplate;
11+
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
1212
use crate::prefix_cacher::PrefixCacheManagerV2;
1313
use crate::sequence::Sequence;
1414
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
@@ -222,6 +222,10 @@ impl Loader for DiffusionLoader {
222222
cache_engine: None,
223223
prompt_chunksize: None,
224224
model_metadata: None,
225+
modalities: Modalities {
226+
input: vec![SupportedModality::Text],
227+
output: vec![SupportedModality::Vision],
228+
},
225229
}),
226230
dummy_cache: EitherCache::Full(Cache::new(0, false)),
227231
})))

mistralrs-core/src/pipeline/ggml.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ use crate::device_map::DeviceMapper;
1111
use crate::kv_cache::FullCacheManager;
1212
use crate::lora::Ordering;
1313
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
14-
use crate::pipeline::get_chat_template;
1514
use crate::pipeline::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
1615
use crate::pipeline::sampling::sample_and_add_toks;
16+
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
1717
use crate::pipeline::{ChatTemplate, LocalModelPaths};
1818
use crate::prefix_cacher::PrefixCacheManagerV2;
1919
use crate::sequence::Sequence;
@@ -396,6 +396,10 @@ impl Loader for GGMLLoader {
396396
cache_engine: None,
397397
prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
398398
model_metadata: None,
399+
modalities: Modalities {
400+
input: vec![SupportedModality::Text],
401+
output: vec![SupportedModality::Text],
402+
},
399403
}),
400404
})))
401405
}

mistralrs-core/src/pipeline/gguf.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ use crate::paged_attention::{
1919
calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
2020
};
2121
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
22-
use crate::pipeline::get_chat_template;
2322
use crate::pipeline::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2423
use crate::pipeline::loaders::DeviceMappedModelLoader;
2524
use crate::pipeline::sampling::sample_and_add_toks;
2625
use crate::pipeline::ChatTemplate;
26+
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
2727
use crate::prefix_cacher::PrefixCacheManagerV2;
2828
use crate::sequence::Sequence;
2929
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
@@ -565,6 +565,10 @@ impl Loader for GGUFLoader {
565565
cache_engine,
566566
prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
567567
model_metadata: Some(Arc::new(model_config_metadata)),
568+
modalities: Modalities {
569+
input: vec![SupportedModality::Text],
570+
output: vec![SupportedModality::Text],
571+
},
568572
}),
569573
mapper: pipeline_mapper,
570574
})))

mistralrs-core/src/pipeline/loaders/vision_loaders.rs

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ use crate::pipeline::isq::IsqModelLoader;
2727
use crate::pipeline::loaders::AutoDeviceMapParams;
2828
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
2929
use crate::pipeline::{
30-
EitherCache, IsqModel, MultimodalPromptPrefixer, Processor, ProcessorCreator,
30+
EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
31+
SupportedModality,
3132
};
3233
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
3334
use crate::vision_models::clip::ClipConfig;
@@ -104,6 +105,7 @@ pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoa
104105
// Default is false, specific model must override.
105106
false
106107
}
108+
fn modalities(&self, config: &str) -> Result<Modalities>;
107109
fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
108110
fn get_device_for_tensor(
109111
&self,
@@ -311,6 +313,10 @@ impl VisionModelLoader for AutoVisionLoader {
311313
.supports_paged_attention(config)
312314
}
313315

316+
fn modalities(&self, config: &str) -> Result<Modalities> {
317+
Self::get_loader(config)?.modalities(config)
318+
}
319+
314320
fn supports_prefix_cacher(&self, config: &str) -> bool {
315321
Self::get_loader(config)
316322
.expect("AutoVisionLoader")
@@ -499,6 +505,12 @@ impl VisionModelLoader for Phi3VLoader {
499505
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
500506
Arc::new(Phi3VPrefixer)
501507
}
508+
fn modalities(&self, _config: &str) -> Result<Modalities> {
509+
Ok(Modalities {
510+
input: vec![SupportedModality::Text, SupportedModality::Vision],
511+
output: vec![SupportedModality::Text],
512+
})
513+
}
502514
}
503515

504516
impl IsqModelLoader for Phi3VLoader {
@@ -771,6 +783,12 @@ impl VisionModelLoader for Idefics2Loader {
771783
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
772784
Arc::new(Idefics2Prefixer)
773785
}
786+
fn modalities(&self, _config: &str) -> Result<Modalities> {
787+
Ok(Modalities {
788+
input: vec![SupportedModality::Text, SupportedModality::Vision],
789+
output: vec![SupportedModality::Text],
790+
})
791+
}
774792
}
775793

776794
impl IsqModelLoader for Idefics2Loader {
@@ -1109,6 +1127,12 @@ impl VisionModelLoader for LLaVANextLoader {
11091127
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
11101128
Arc::new(LLaVANextPrefixer)
11111129
}
1130+
fn modalities(&self, _config: &str) -> Result<Modalities> {
1131+
Ok(Modalities {
1132+
input: vec![SupportedModality::Text, SupportedModality::Vision],
1133+
output: vec![SupportedModality::Text],
1134+
})
1135+
}
11121136
}
11131137

11141138
impl IsqModelLoader for LLaVANextLoader {
@@ -1371,6 +1395,12 @@ impl VisionModelLoader for LLaVALoader {
13711395
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
13721396
Arc::new(LLaVAPrefixer)
13731397
}
1398+
fn modalities(&self, _config: &str) -> Result<Modalities> {
1399+
Ok(Modalities {
1400+
input: vec![SupportedModality::Text, SupportedModality::Vision],
1401+
output: vec![SupportedModality::Text],
1402+
})
1403+
}
13741404
}
13751405

13761406
impl IsqModelLoader for LLaVALoader {
@@ -1625,6 +1655,12 @@ impl VisionModelLoader for VLlamaLoader {
16251655
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
16261656
Arc::new(VLlamaPrefixer)
16271657
}
1658+
fn modalities(&self, _config: &str) -> Result<Modalities> {
1659+
Ok(Modalities {
1660+
input: vec![SupportedModality::Text, SupportedModality::Vision],
1661+
output: vec![SupportedModality::Text],
1662+
})
1663+
}
16281664
}
16291665

16301666
impl IsqModelLoader for VLlamaLoader {
@@ -2009,6 +2045,12 @@ impl VisionModelLoader for Qwen2VLLoader {
20092045
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
20102046
Arc::new(Qwen2VLPrefixer)
20112047
}
2048+
fn modalities(&self, _config: &str) -> Result<Modalities> {
2049+
Ok(Modalities {
2050+
input: vec![SupportedModality::Text, SupportedModality::Vision],
2051+
output: vec![SupportedModality::Text],
2052+
})
2053+
}
20122054
}
20132055

20142056
impl IsqModelLoader for Qwen2VLLoader {
@@ -2297,6 +2339,12 @@ impl VisionModelLoader for Idefics3Loader {
22972339
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
22982340
Arc::new(Idefics3Prefixer)
22992341
}
2342+
fn modalities(&self, _config: &str) -> Result<Modalities> {
2343+
Ok(Modalities {
2344+
input: vec![SupportedModality::Text, SupportedModality::Vision],
2345+
output: vec![SupportedModality::Text],
2346+
})
2347+
}
23002348
}
23012349

23022350
impl IsqModelLoader for Idefics3Loader {
@@ -2606,6 +2654,12 @@ impl VisionModelLoader for MiniCpmOLoader {
26062654
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
26072655
Arc::new(MiniCpmOPrefixer)
26082656
}
2657+
fn modalities(&self, _config: &str) -> Result<Modalities> {
2658+
Ok(Modalities {
2659+
input: vec![SupportedModality::Text, SupportedModality::Vision],
2660+
output: vec![SupportedModality::Text],
2661+
})
2662+
}
26092663
}
26102664

26112665
impl IsqModelLoader for MiniCpmOLoader {
@@ -2892,6 +2946,16 @@ impl VisionModelLoader for Phi4MMLoader {
28922946
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
28932947
Arc::new(Phi4MMPrefixer)
28942948
}
2949+
fn modalities(&self, _config: &str) -> Result<Modalities> {
2950+
Ok(Modalities {
2951+
input: vec![
2952+
SupportedModality::Text,
2953+
SupportedModality::Vision,
2954+
SupportedModality::Audio,
2955+
],
2956+
output: vec![SupportedModality::Text],
2957+
})
2958+
}
28952959
}
28962960

28972961
impl IsqModelLoader for Phi4MMLoader {
@@ -3213,6 +3277,12 @@ impl VisionModelLoader for Qwen2_5VLLoader {
32133277
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
32143278
Arc::new(Qwen2_5VLPrefixer)
32153279
}
3280+
fn modalities(&self, _config: &str) -> Result<Modalities> {
3281+
Ok(Modalities {
3282+
input: vec![SupportedModality::Text, SupportedModality::Vision],
3283+
output: vec![SupportedModality::Text],
3284+
})
3285+
}
32163286
}
32173287

32183288
impl IsqModelLoader for Qwen2_5VLLoader {
@@ -3500,6 +3570,12 @@ impl VisionModelLoader for Gemma3Loader {
35003570
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
35013571
Arc::new(Gemma3Prefixer)
35023572
}
3573+
fn modalities(&self, _config: &str) -> Result<Modalities> {
3574+
Ok(Modalities {
3575+
input: vec![SupportedModality::Text, SupportedModality::Vision],
3576+
output: vec![SupportedModality::Text],
3577+
})
3578+
}
35033579
}
35043580

35053581
impl IsqModelLoader for Gemma3Loader {
@@ -3827,6 +3903,12 @@ impl VisionModelLoader for Mistral3Loader {
38273903
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
38283904
Arc::new(Mistral3Prefixer)
38293905
}
3906+
fn modalities(&self, _config: &str) -> Result<Modalities> {
3907+
Ok(Modalities {
3908+
input: vec![SupportedModality::Text, SupportedModality::Vision],
3909+
output: vec![SupportedModality::Text],
3910+
})
3911+
}
38303912
}
38313913

38323914
impl IsqModelLoader for Mistral3Loader {
@@ -4143,6 +4225,12 @@ impl VisionModelLoader for VLlama4Loader {
41434225
fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
41444226
Arc::new(VLlama4Prefixer)
41454227
}
4228+
fn modalities(&self, _config: &str) -> Result<Modalities> {
4229+
Ok(Modalities {
4230+
input: vec![SupportedModality::Text, SupportedModality::Vision],
4231+
output: vec![SupportedModality::Text],
4232+
})
4233+
}
41464234
}
41474235

41484236
impl IsqModelLoader for VLlama4Loader {

mistralrs-core/src/pipeline/mod.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline}
5757
pub use speech::{SpeechLoader, SpeechPipeline};
5858
use std::any::Any;
5959
use std::collections::HashMap;
60+
use std::fmt::Debug;
6061
use std::num::NonZeroUsize;
6162
use std::sync::Arc;
6263
use std::time::{Duration, Instant};
@@ -76,6 +77,29 @@ pub use crate::kv_cache::{
7677
Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
7778
};
7879

80+
/// A kind of data that a pipeline can accept as input or emit as output.
///
/// `PartialEq`, `Eq` and `Hash` are derived so that a [`Modalities`] list can actually be
/// queried (for example `modalities.input.contains(&SupportedModality::Vision)`); `Copy` is
/// derived because the variants carry no data.
///
/// The `Debug` representation is a human-readable, emoji-prefixed label rather than the
/// bare variant name.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum SupportedModality {
    /// Textual data.
    Text,
    /// Audio data.
    Audio,
    /// Visual data.
    Vision,
}

impl Debug for SupportedModality {
    /// Writes the emoji-prefixed label for this modality.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The labels are fixed strings, so `write_str` avoids the formatting machinery
        // while producing exactly the same output.
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
        };
        f.write_str(label)
    }
}

/// The input and output modalities advertised by a pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Modalities {
    /// Modalities accepted as input.
    pub input: Vec<SupportedModality>,
    /// Modalities produced as output.
    pub output: Vec<SupportedModality>,
}
102+
79103
pub struct GeneralMetadata {
80104
pub max_seq_len: usize,
81105
/// Only None if it doesn't make sense for the model
@@ -94,6 +118,7 @@ pub struct GeneralMetadata {
94118
pub cache_engine: Option<CacheEngine>,
95119
pub prompt_chunksize: Option<NonZeroUsize>,
96120
pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
121+
pub modalities: Modalities,
97122
}
98123

99124
impl GeneralMetadata {

mistralrs-core/src/pipeline/normal.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ use crate::kv_cache::{FullCacheManager, NormalCacheManager};
2222
use crate::lora::Ordering;
2323
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
2424
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25-
use crate::pipeline::get_chat_template;
2625
use crate::pipeline::isq::UqffFullSer;
2726
use crate::pipeline::loaders::auto_device_map;
2827
use crate::pipeline::loaders::QuantizationConfigShim;
2928
use crate::pipeline::sampling::sample_and_add_toks;
3029
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30+
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
3131
use crate::pipeline::{ChatTemplate, LocalModelPaths};
3232
use crate::prefix_cacher::PrefixCacheManagerV2;
3333
use crate::sequence::Sequence;
@@ -874,6 +874,10 @@ impl Loader for NormalLoader {
874874
cache_engine,
875875
prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
876876
model_metadata: Some(model_metadata),
877+
modalities: Modalities {
878+
input: vec![SupportedModality::Text],
879+
output: vec![SupportedModality::Text],
880+
},
877881
}),
878882
topology: self.config.topology.clone(),
879883
silent,

mistralrs-core/src/pipeline/speech.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use super::{
66
PreProcessingMixin, Processor, TokenSource,
77
};
88
use crate::device_map::DeviceMapper;
9-
use crate::pipeline::ChatTemplate;
9+
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
1010
use crate::prefix_cacher::PrefixCacheManagerV2;
1111
use crate::sequence::Sequence;
1212
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
@@ -306,6 +306,10 @@ impl Loader for SpeechLoader {
306306
cache_engine: None,
307307
prompt_chunksize: None,
308308
model_metadata: None,
309+
modalities: Modalities {
310+
input: vec![SupportedModality::Text],
311+
output: vec![SupportedModality::Audio],
312+
},
309313
}),
310314
dummy_cache: EitherCache::Full(Cache::new(0, false)),
311315
cfg: self

0 commit comments

Comments
 (0)