Skip to content

Commit 558eb99

Browse files
committed
User specified
1 parent 8250b33 commit 558eb99

File tree

10 files changed

+326
-24
lines changed

10 files changed

+326
-24
lines changed

mistralrs-core/src/engine/add_request.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,21 @@ impl Engine {
122122
ref images,
123123
messages: _,
124124
enable_thinking: _,
125+
audios: _,
125126
} => Some(images.clone()),
126127
_ => None,
127128
};
128129

130+
let audios = match request.messages {
131+
RequestMessage::VisionChat {
132+
images: _,
133+
messages: _,
134+
enable_thinking: _,
135+
ref audios,
136+
} => Some(audios.clone()),
137+
_ => None,
138+
};
139+
129140
let matcher = Arc::new(handle_seq_error!(
130141
ToolCallingMatcher::new(request.tool_choice.unwrap_or(ToolChoice::Auto),),
131142
request.response
@@ -157,6 +168,7 @@ impl Engine {
157168
}
158169
| RequestMessage::VisionChat {
159170
images: _,
171+
audios: _,
160172
messages,
161173
enable_thinking,
162174
} => {
@@ -497,6 +509,7 @@ impl Engine {
497509
None
498510
},
499511
images.clone(),
512+
audios.clone(),
500513
block_size,
501514
Some(matcher.clone()),
502515
image_generation_format,

mistralrs-core/src/lib.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ pub use pipeline::{
9797
UQFF_MULTI_FILE_DELIMITER,
9898
};
9999
pub use request::{
100-
ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
101-
LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
102-
TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
100+
ApproximateUserLocation, AudioInput, Constraint, DetokenizationRequest,
101+
ImageGenerationResponseFormat, LlguidanceGrammar, MessageContent, NormalRequest, Request,
102+
RequestMessage, SearchContextSize, TokenizationRequest, WebSearchOptions,
103+
WebSearchUserLocation,
103104
};
104105
pub use response::*;
105106
pub use sampler::{

mistralrs-core/src/pipeline/amoe.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ fn new_dummy_seq(
584584
None,
585585
None,
586586
images,
587+
None,
587588
None, // TODO incorrect for PagedAttention
588589
None,
589590
None,

mistralrs-core/src/request.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ pub enum RequestMessage {
5050
},
5151
CompletionTokens(Vec<u32>),
5252
VisionChat {
53-
#[serde(skip)] // TODO!!!!
53+
#[serde(skip)] // TODO
5454
images: Vec<image::DynamicImage>,
55+
#[serde(skip)] // TODO
56+
audios: Vec<AudioInput>,
5557
messages: Vec<IndexMap<String, MessageContent>>,
5658
enable_thinking: Option<bool>,
5759
},
@@ -116,6 +118,42 @@ pub struct WebSearchOptions {
116118
pub extract_description: Option<String>,
117119
}
118120

121+
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
122+
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
123+
/// Raw audio input consisting of PCM samples and a sample rate.
124+
pub struct AudioInput {
125+
pub samples: Vec<f32>,
126+
pub sample_rate: u32,
127+
}
128+
129+
impl AudioInput {
130+
pub fn read_wav(wav_path: &str) -> anyhow::Result<Self> {
131+
let mut reader = hound::WavReader::open(wav_path)
132+
.map_err(|e| anyhow::Error::msg(format!("Failed to load audio: {}", e)))?;
133+
let spec = reader.spec();
134+
135+
let samples: Vec<f32> = match spec.sample_format {
136+
hound::SampleFormat::Float => reader
137+
.samples::<f32>()
138+
.map(|s| s.map_err(|e| anyhow::Error::msg(e.to_string())))
139+
.collect::<std::result::Result<_, _>>()?,
140+
141+
hound::SampleFormat::Int => reader
142+
.samples::<i16>() // read as integers
143+
.map(|s| {
144+
s.map(|v| v as f32 / i16::MAX as f32) // scale to –1.0…1.0
145+
.map_err(|e| candle_core::Error::Msg(e.to_string()))
146+
})
147+
.collect::<std::result::Result<_, _>>()?,
148+
};
149+
150+
Ok(Self {
151+
samples,
152+
sample_rate: spec.sample_rate,
153+
})
154+
}
155+
}
156+
119157
#[derive(Clone, Serialize, Deserialize)]
120158
/// A normal request to the `MistralRs`.
121159
/// - `messages`: Messages for the request

mistralrs-core/src/sequence.rs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::{
44
pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
55
response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
66
sampler::{Logprobs, Sampler},
7-
ChatCompletionResponse, Usage,
7+
AudioInput, ChatCompletionResponse, Usage,
88
};
99
use crate::{
1010
paged_attention::{BlockEngineSequence, LogicalTokenBlock},
@@ -171,6 +171,53 @@ pub struct SequenceImages {
171171
hashes: Vec<u64>,
172172
}
173173

174+
#[derive(Clone)]
175+
pub struct SequenceAudios {
176+
audios: Vec<AudioInput>,
177+
hashes: Vec<u64>,
178+
}
179+
180+
impl SequenceAudios {
181+
fn new(input_audios: Vec<AudioInput>) -> Self {
182+
let hashes = input_audios.iter().map(|a| {
183+
let mut hasher = DefaultHasher::new();
184+
for s in &a.samples {
185+
s.to_bits().hash(&mut hasher);
186+
}
187+
a.sample_rate.hash(&mut hasher);
188+
hasher.finish()
189+
});
190+
Self {
191+
hashes: hashes.collect(),
192+
audios: input_audios,
193+
}
194+
}
195+
196+
fn clone_audios(&self) -> Vec<AudioInput> {
197+
self.audios.clone()
198+
}
199+
200+
fn audios(&self) -> &[AudioInput] {
201+
&self.audios
202+
}
203+
204+
fn audios_mut(&mut self) -> &mut Vec<AudioInput> {
205+
&mut self.audios
206+
}
207+
208+
fn hashes(&self) -> &[u64] {
209+
&self.hashes
210+
}
211+
212+
fn keep_num_audios(&mut self, audios_to_keep: usize) {
213+
if self.audios.len() > audios_to_keep {
214+
let start = self.audios.len() - audios_to_keep;
215+
self.audios = self.audios[start..].to_vec();
216+
self.hashes = self.hashes[start..].to_vec();
217+
}
218+
}
219+
}
220+
174221
impl SequenceImages {
175222
fn new(input_images: Vec<image::DynamicImage>) -> Self {
176223
let hashes = input_images.iter().map(|x| {
@@ -211,6 +258,7 @@ impl SequenceImages {
211258
// Holds all multimodal (vision/diffusion) data for a Sequence.
212259
pub struct MultimodalData {
213260
pub input_images: Option<SequenceImages>,
261+
pub input_audios: Option<SequenceAudios>,
214262
pub cached_pixel_values: Option<Tensor>,
215263
pub cached_img_thw: Option<Tensor>,
216264
pub cached_vid_thw: Option<Tensor>,
@@ -222,11 +270,13 @@ pub struct MultimodalData {
222270
impl MultimodalData {
223271
pub fn new(
224272
input_images: Option<Vec<image::DynamicImage>>,
273+
input_audios: Option<Vec<AudioInput>>,
225274
image_gen_response_format: Option<ImageGenerationResponseFormat>,
226275
diffusion_params: Option<DiffusionGenerationParams>,
227276
) -> Self {
228277
MultimodalData {
229278
input_images: input_images.map(SequenceImages::new),
279+
input_audios: input_audios.map(SequenceAudios::new),
230280
cached_pixel_values: None,
231281
cached_img_thw: None,
232282
cached_vid_thw: None,
@@ -268,6 +318,40 @@ impl MultimodalData {
268318
.is_some_and(|imgs| !imgs.images().is_empty())
269319
}
270320

321+
pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
322+
if let Some(input_audios) = self.input_audios.as_mut() {
323+
let mut audios = Vec::new();
324+
std::mem::swap(&mut audios, input_audios.audios_mut());
325+
Some(audios)
326+
} else {
327+
None
328+
}
329+
}
330+
331+
pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
332+
self.input_audios.as_ref().map(|a| a.clone_audios())
333+
}
334+
335+
pub fn audios(&self) -> Option<&[AudioInput]> {
336+
self.input_audios.as_ref().map(|a| a.audios())
337+
}
338+
339+
pub fn audio_hashes(&self) -> Option<&[u64]> {
340+
self.input_audios.as_ref().map(|a| a.hashes())
341+
}
342+
343+
pub fn has_audios(&self) -> bool {
344+
self.input_audios
345+
.as_ref()
346+
.is_some_and(|a| !a.audios().is_empty())
347+
}
348+
349+
pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
350+
if let Some(auds) = self.input_audios.as_mut() {
351+
auds.keep_num_audios(audios_to_keep)
352+
}
353+
}
354+
271355
pub fn keep_num_images(&mut self, images_to_keep: usize) {
272356
if let Some(imgs) = self.input_images.as_mut() {
273357
imgs.keep_num_images(images_to_keep)
@@ -422,6 +506,7 @@ impl Sequence {
422506
suffix: Option<String>,
423507
prefix: Option<String>,
424508
input_images: Option<Vec<image::DynamicImage>>,
509+
input_audios: Option<Vec<AudioInput>>,
425510
// Paged attention
426511
block_size: Option<usize>,
427512
//
@@ -492,6 +577,7 @@ impl Sequence {
492577
// Multimodal data
493578
multimodal: MultimodalData::new(
494579
input_images,
580+
input_audios,
495581
image_gen_response_format,
496582
diffusion_params,
497583
),
@@ -967,6 +1053,30 @@ impl Sequence {
9671053
self.multimodal.has_images()
9681054
}
9691055

1056+
pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
1057+
self.multimodal.take_audios()
1058+
}
1059+
1060+
pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
1061+
self.multimodal.clone_audios()
1062+
}
1063+
1064+
pub fn audios(&self) -> Option<&[AudioInput]> {
1065+
self.multimodal.audios()
1066+
}
1067+
1068+
pub fn audio_hashes(&self) -> Option<&[u64]> {
1069+
self.multimodal.audio_hashes()
1070+
}
1071+
1072+
pub fn has_audios(&self) -> bool {
1073+
self.multimodal.has_audios()
1074+
}
1075+
1076+
pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
1077+
self.multimodal.keep_num_audios(audios_to_keep)
1078+
}
1079+
9701080
/// Keep these last n images
9711081
pub fn keep_num_images(&mut self, images_to_keep: usize) {
9721082
self.multimodal.keep_num_images(images_to_keep)

mistralrs-core/src/vision_models/phi4/inputs_processor.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ impl Phi4MMInputsProcessor {
720720

721721
fn process_audio_for_sequences(
722722
&self,
723-
input_seqs: &[&mut Sequence],
723+
input_seqs: &mut [&mut Sequence],
724724
device: &Device,
725725
) -> AudioProcessingResult {
726726
// Check if any sequence has audio tokens
@@ -737,12 +737,19 @@ impl Phi4MMInputsProcessor {
737737
let mut audio_frames_list = Vec::new();
738738

739739
// Process audio for each sequence that needs it
740-
for seq in input_seqs.iter() {
740+
for seq in input_seqs.iter_mut() {
741741
let has_audio = seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32));
742742

743743
if has_audio {
744-
// Load dummy audio (TODO: make this per-sequence)
745-
let (audio_data, sample_rate) = self.load_dummy_audio()?;
744+
let (audio_data, sample_rate) = if let Some(mut audios) = seq.take_audios() {
745+
if let Some(audio) = audios.pop() {
746+
(audio.samples, audio.sample_rate)
747+
} else {
748+
self.load_dummy_audio()?
749+
}
750+
} else {
751+
self.load_dummy_audio()?
752+
};
746753

747754
// Extract features
748755
let features = self.extract_audio_features(&audio_data, sample_rate)?;

mistralrs-pyo3/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,7 @@ impl Runner {
10691069
RequestMessage::VisionChat {
10701070
messages: messages_vec,
10711071
images,
1072+
audios: Vec::new(),
10721073
enable_thinking: request.enable_thinking,
10731074
}
10741075
} else {

mistralrs-server-core/src/chat_completion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ pub async fn parse_request(
442442
RequestMessage::VisionChat {
443443
messages,
444444
images,
445+
audios: Vec::new(),
445446
enable_thinking: oairequest.enable_thinking,
446447
}
447448
} else {

0 commit comments

Comments
 (0)