Skip to content

Commit 167f1b5

Browse files
committed
Support batching
1 parent ff67bf8 commit 167f1b5

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed

mistralrs-core/src/vision_models/phi4/inputs_processor.rs

Lines changed: 25 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -719,38 +719,37 @@ impl Phi4MMInputsProcessor {
719719
let has_audio = seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32));
720720

721721
if has_audio {
722-
// Convert multi-channel audio to mono by averaging channels
723-
let (audio_data, sample_rate) = if let Some(mut audios) = seq.take_audios() {
724-
if let Some(audio) = audios.pop() {
722+
if let Some(audios) = seq.take_audios() {
723+
for audio in audios.into_iter() {
724+
// Convert multi-channel audio to mono by averaging channels
725725
let samples = audio.to_mono();
726726

727-
(samples, audio.sample_rate)
728-
} else {
729-
candle_core::bail!("No audios in `process_audio_for_sequences`");
727+
// Extract features
728+
let features = self.extract_audio_features(&samples, audio.sample_rate)?;
729+
let audio_frames = features.len() * self.audio_feat_stride;
730+
731+
let embed_size = self.compute_audio_embed_size(
732+
audio_frames,
733+
self.audio_compression_rate,
734+
self.audio_downsample_rate,
735+
);
736+
737+
// Convert to tensor
738+
let features_len = features.len();
739+
let features_flat: Vec<f32> = features.into_iter().flatten().collect();
740+
let features_tensor = Tensor::from_slice(
741+
&features_flat,
742+
(features_len, AUDIO_FEATURE_SIZE),
743+
device,
744+
)?;
745+
746+
audio_features_list.push(features_tensor);
747+
audio_embed_sizes_list.push(embed_size);
748+
audio_frames_list.push(audio_frames);
730749
}
731750
} else {
732751
candle_core::bail!("No audios in `process_audio_for_sequences`");
733752
};
734-
735-
// Extract features
736-
let features = self.extract_audio_features(&audio_data, sample_rate)?;
737-
let audio_frames = features.len() * self.audio_feat_stride;
738-
739-
let embed_size = self.compute_audio_embed_size(
740-
audio_frames,
741-
self.audio_compression_rate,
742-
self.audio_downsample_rate,
743-
);
744-
745-
// Convert to tensor
746-
let features_len = features.len();
747-
let features_flat: Vec<f32> = features.into_iter().flatten().collect();
748-
let features_tensor =
749-
Tensor::from_slice(&features_flat, (features_len, AUDIO_FEATURE_SIZE), device)?;
750-
751-
audio_features_list.push(features_tensor);
752-
audio_embed_sizes_list.push(embed_size);
753-
audio_frames_list.push(audio_frames);
754753
}
755754
}
756755

0 commit comments

Comments
 (0)