Commit 8e10e52: Merger

1 parent 3907feb

File tree: 6 files changed (+215, -35 lines)

mistralrs-core/src/vision_models/conformer/encoder.rs

Lines changed: 2 additions & 2 deletions

@@ -348,7 +348,7 @@ impl ConvModule {
             None
         };
 
-        let mut fix_len1 = false;
+        let fix_len1;
         let ext_pw_conv_1d = if cfg.causal {
             if cfg.ext_pw_kernel_size > 1 {
                 fix_len1 = true;

@@ -642,7 +642,7 @@ impl ConformerEncoder {
     }
 
     fn unfold_tensor(xs_pad: &Tensor, max_seq_len: usize) -> Result<Tensor> {
-        let (n, t, d) = xs_pad.dims3()?;
+        let (_n, t, _d) = xs_pad.dims3()?;
 
         // If sequence length is already <= max_seq_len, no need to unfold
         if t <= max_seq_len {
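The fix_len1 change above swaps a `let mut` with a dummy default for Rust's deferred initialization: the binding is declared without a value, and the compiler verifies that every path assigns it exactly once before use. A minimal standalone sketch of the pattern (the names mirror the config flags in the hunk; the function itself is hypothetical):

fn pick_fix_len(causal: bool, ext_pw_kernel_size: usize) -> bool {
    // Deferred initialization: no value and no `mut`, because each
    // branch below assigns exactly once before the binding is read.
    let fix_len1;
    if causal && ext_pw_kernel_size > 1 {
        fix_len1 = true;
    } else {
        fix_len1 = false;
    }
    fix_len1
}

fn main() {
    assert!(pick_fix_len(true, 3));
    assert!(!pick_fix_len(false, 3));
}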

mistralrs-core/src/vision_models/conformer/nemo.rs

Lines changed: 1 addition & 8 deletions

@@ -10,9 +10,7 @@ use super::config::NemoConvConfig;
 
 pub struct NemoConvSubsampling {
     conv: Vec<Arc<dyn Module + Send + Sync>>,
-    conv2d_subsampling: bool,
     out: Linear,
-    subsampling_causal_cond: bool,
     subsampling_factor: usize,
 }
 
@@ -23,8 +21,6 @@ impl NemoConvSubsampling {
         }
 
        let sampling_num = (cfg.subsampling_factor as f32).log2() as usize;
-        let subsampling_causal_cond =
-            ["dw_striding", "striding", "striding_conv1d"].contains(&cfg.subsampling.as_str());
 
        let mut in_channels = 1;
        let mut layers: Vec<Arc<dyn Module + Send + Sync>> = Vec::new();

@@ -114,13 +110,10 @@ impl NemoConvSubsampling {
             true,
             vb.pp("out"),
         )?;
-        let conv2d_subsampling = false;
 
         Ok(Self {
             conv: layers,
-            conv2d_subsampling,
             out,
-            subsampling_causal_cond,
             subsampling_factor: cfg.subsampling_factor,
         })
     }

@@ -135,7 +128,7 @@ impl NemoConvSubsampling {
     ) -> usize {
         let add_pad = all_paddings as f32 - kernel_size as f32;
         let one = 1f32;
-        for i in 0..repeat_num {
+        for _ in 0..repeat_num {
             length = (length + add_pad) / (stride as f32) + one;
             if ceil_mode {
                 length = length.ceil();
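The loop in the last hunk applies the standard strided-convolution output-length formula once per subsampling stage, L_out = (L_in + total_padding - kernel_size) / stride + 1; renaming `i` to `_` only silences the unused-variable lint. A self-contained sketch of the arithmetic as read from the hunk (the free function and its exact signature are assumptions, not the repo's API):

/// Length after applying a strided conv `repeat_num` times:
/// per stage, L_out = (L_in + total_padding - kernel_size) / stride + 1.
fn calc_length(
    mut length: f32,
    all_paddings: usize,
    kernel_size: usize,
    stride: usize,
    ceil_mode: bool,
    repeat_num: usize,
) -> usize {
    let add_pad = all_paddings as f32 - kernel_size as f32;
    let one = 1f32;
    for _ in 0..repeat_num {
        length = (length + add_pad) / (stride as f32) + one;
        length = if ceil_mode { length.ceil() } else { length.floor() };
    }
    length as usize
}

fn main() {
    // 100 frames through two stride-2, kernel-3, padding-2 stages -> 25 frames.
    assert_eq!(calc_length(100.0, 2, 3, 2, false, 2), 25);
}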

mistralrs-core/src/vision_models/phi4/audio_embedding.rs

Lines changed: 124 additions & 4 deletions
@@ -1,8 +1,8 @@
 use std::{collections::HashMap, sync::Arc};
 
-use candle_core::Result;
+use candle_core::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::Module;
-use mistralrs_quant::ShardedVarBuilder;
+use mistralrs_quant::{NonZeroOp, ShardedVarBuilder};
 
 use crate::{
     layers::{self, Activation},
@@ -16,7 +16,7 @@ use super::Phi4MMConfig;
 
 pub(super) const AUDIO_SPECIAL_TOKEN_ID: f64 = 200011.;
 
-#[derive(Eq, Hash, PartialEq)]
+#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)]
 pub enum AudioProjectionMode {
     /// If only speech
     Speech,
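The widened derive list is what lets `AudioProjectionMode` key the `proj` map and appear in error messages: `Eq + Hash` for `HashMap` lookup, `Debug` for the `{:?}` in the new error path, and `Clone`/`Copy` so the mode can be passed by value. A toy sketch of those requirements:

use std::collections::HashMap;

#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)]
enum AudioProjectionMode {
    Speech,
    Vision,
}

fn main() {
    let mut proj: HashMap<AudioProjectionMode, &str> = HashMap::new();
    proj.insert(AudioProjectionMode::Speech, "speech projection");
    let mode = AudioProjectionMode::Vision; // Copy: no clone needed
    match proj.get(&mode) {
        Some(p) => println!("{p}"),
        // Debug makes the {:?} formatting possible
        None => eprintln!("Projection mode {:?} not found", mode),
    }
}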
@@ -25,13 +25,16 @@ pub enum AudioProjectionMode {
 }
 
 pub struct AudioEmbedding {
+    wte: candle_nn::Embedding,
     proj: HashMap<AudioProjectionMode, Vec<Arc<dyn Module + Send + Sync>>>,
     encoder: ConformerEncoder,
+    target_device_dtype: (Device, DType),
 }
 
 impl AudioEmbedding {
     pub fn new(
         cfg: &Phi4MMConfig,
+        wte: candle_nn::Embedding,
         audio_embd_config: &Phi4MMAudioEmbedConfig,
         vb: ShardedVarBuilder,
     ) -> Result<Self> {
@@ -90,6 +93,123 @@ impl AudioEmbedding {
             proj.insert(AudioProjectionMode::Vision, layers_for_vision);
         }
 
-        Ok(Self { proj, encoder })
+        Ok(Self {
+            wte,
+            proj,
+            encoder,
+            target_device_dtype: (vb.device().clone(), vb.dtype()),
+        })
+    }
+
+    fn get_audio_features(
+        &self,
+        input_embeds: &Tensor,
+        audio_attention_mask: &Tensor,
+        audio_projection_mode: &AudioProjectionMode,
+    ) -> Result<Tensor> {
+        // Get audio features from the conformer encoder
+        let (audio_features, _masks) = self
+            .encoder
+            .forward(input_embeds, Some(audio_attention_mask))?;
+
+        // Apply the projection stack that matches the requested mode
+        let projection_layers = self.proj.get(audio_projection_mode).ok_or_else(|| {
+            candle_core::Error::Msg(format!(
+                "Projection mode {:?} not found",
+                audio_projection_mode
+            ))
+        })?;
+
+        let mut audio_set_tensor = audio_features;
+        for layer in projection_layers {
+            audio_set_tensor = layer.forward(&audio_set_tensor)?;
+        }
+
+        Ok(audio_set_tensor)
+    }
+
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        input_embeds: &Tensor,
+        audio_embed_sizes: Vec<usize>,
+        audio_attention_mask: &Tensor,
+        audio_projection_mode: &AudioProjectionMode,
+    ) -> Result<Tensor> {
+        // Reshape input_ids to 2D
+        let input_shape = input_ids.shape();
+        let input_ids = if input_shape.rank() > 2 {
+            input_ids.reshape((
+                input_shape.elem_count() / input_shape.dims()[input_shape.rank() - 1],
+                input_shape.dims()[input_shape.rank() - 1],
+            ))?
+        } else {
+            input_ids.clone()
+        };
+
+        // (row, col) coordinates of every audio special token
+        let positions = input_ids.eq(AUDIO_SPECIAL_TOKEN_ID)?.nonzero()?;
+
+        // Target device and dtype captured from the VarBuilder at construction
+        let (target_device, target_dtype) = self.target_device_dtype.clone();
+
+        let audio_set_tensor = if positions.dim(0)? > 0 {
+            // Convert to the target device/dtype if needed
+            let input_embeds = if !input_embeds.device().same_device(&target_device)
+                || input_embeds.dtype() != target_dtype
+            {
+                input_embeds
+                    .to_device(&target_device)?
+                    .to_dtype(target_dtype)?
+            } else {
+                input_embeds.clone()
+            };
+
+            self.get_audio_features(&input_embeds, audio_attention_mask, audio_projection_mode)?
+        } else {
+            // No audio tokens: plain word embeddings suffice
+            return self.wte.forward(&input_ids);
+        };
+
+        // Initial hidden states from the word embeddings
+        let mut hidden_states = self.wte.forward(&input_ids)?;
+
+        // The audio_embed_sizes must account for every audio token position
+        let total_audio_tokens = audio_embed_sizes.iter().sum::<usize>();
+        if total_audio_tokens != positions.dim(0)? {
+            return Err(candle_core::Error::Msg(format!(
+                "Audio embed sizes sum ({}) doesn't match positions count ({})",
+                total_audio_tokens,
+                positions.dim(0)?
+            )));
+        }
+
+        // Take the first `size` projected embeddings for each audio clip
+        let mut audio_sets = Vec::new();
+        for (i, size) in audio_embed_sizes.into_iter().enumerate() {
+            audio_sets.push(audio_set_tensor.i((i, ..size, ..))?);
+        }
+        let merged_audio_set_tensor = Tensor::cat(&audio_sets, 0)?;
+
+        let original_shape = hidden_states.shape().clone();
+        let (_hs_b, hs_l, hs_d) = hidden_states.dims3()?;
+        let mut hidden_states_flat = hidden_states.reshape(((), hs_d))?;
+
+        // Split the position pairs into their row and column components
+        let positions_transposed = positions.to_dtype(DType::F32)?;
+        let positions_transposed_0 = positions_transposed.i((.., 0))?;
+        let positions_transposed_1 = positions_transposed.i((.., 1))?;
+
+        // Linear index into the flattened (batch * seq_len, hidden) table: row * seq_len + col
+        let mut linear_index = ((positions_transposed_0 * hs_l as f64)? + positions_transposed_1)?;
+        linear_index = linear_index.to_dtype(DType::U32)?;
+        linear_index = linear_index.unsqueeze(1)?.repeat((1, hs_d))?;
+
+        // scatter_add of (new - current) overwrites the gathered rows with the audio embeddings
+        let current_vals = hidden_states_flat.gather(&linear_index, 0)?;
+        let delta = merged_audio_set_tensor.broadcast_sub(&current_vals)?;
+
+        hidden_states_flat = hidden_states_flat.scatter_add(&linear_index, &delta, 0)?;
+
+        hidden_states = hidden_states_flat.reshape(original_shape)?;
+
+        Ok(hidden_states)
     }
 }
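The tail of `forward` above writes the projected audio embeddings into the token grid without a dedicated scatter-write op: it builds a linear row index (row * seq_len + col) into the flattened (batch * seq_len, hidden) table, gathers the rows currently there, and scatter-adds the difference, since current + (new - current) = new. A minimal sketch of the same trick on plain candle tensors (`overwrite_rows` and all names here are mine, not the model's):

use candle_core::{DType, Device, Result, Tensor};

fn overwrite_rows(table: &Tensor, rows: &[u32], new_rows: &Tensor) -> Result<Tensor> {
    let (_n, d) = table.dims2()?;
    // Broadcast each target row index across all D columns.
    let index = Tensor::from_slice(rows, (rows.len(),), table.device())?
        .unsqueeze(1)?
        .repeat((1, d))?;
    let current = table.gather(&index, 0)?;
    // current + (new - current) == new, so scatter_add acts as a scatter-write.
    let delta = (new_rows - &current)?;
    table.scatter_add(&index, &delta, 0)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let table = Tensor::zeros((4, 3), DType::F32, &dev)?;
    let new_rows = Tensor::ones((2, 3), DType::F32, &dev)?;
    // Rows 1 and 3 become ones; rows 0 and 2 stay zero.
    let out = overwrite_rows(&table, &[1, 3], &new_rows)?;
    println!("{out}");
    Ok(())
}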

mistralrs-core/src/vision_models/phi4/inputs_processor.rs

Lines changed: 10 additions & 4 deletions

@@ -203,9 +203,12 @@ impl InputsProcessor for Phi4MMInputsProcessor {
             position_ids,
             pixel_values: None,
             model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
-                image_sizes: None,
-                image_attention_mask: None,
                 input_image_embeds: None,
+                image_attention_mask: None,
+                image_sizes: None,
+                input_audio_embeds: None, // TODO!
+                audio_embed_sizes: None, // TODO!
+                audio_attention_mask: None, // TODO!
             }),
             paged_attn_meta,
             flash_meta,

@@ -326,9 +329,12 @@ impl InputsProcessor for Phi4MMInputsProcessor {
             position_ids,
             pixel_values: pixel_values.clone(),
             model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
-                image_sizes: image_sizes.clone(),
-                image_attention_mask: pixel_attention_mask,
                 input_image_embeds: pixel_values,
+                image_attention_mask: pixel_attention_mask,
+                image_sizes: image_sizes.clone(),
+                input_audio_embeds: None, // TODO!
+                audio_embed_sizes: None, // TODO!
+                audio_attention_mask: None, // TODO!
             }),
             paged_attn_meta,
             flash_meta,
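Both hunks stub the three new audio slots to None, since audio inputs are not wired through the processor yet (the TODOs). From the field names here and the forward signature in mm_embedding.rs below, the widened args struct presumably looks roughly like this; a sketch inferred from the diff, not the repo's actual definition:

use candle_core::Tensor;

// Hypothetical reconstruction of the args struct; the real one lives in the phi4 module.
pub struct Phi4MMVisionSpecificArgs {
    pub input_image_embeds: Option<Tensor>,
    pub image_attention_mask: Option<Tensor>,
    pub image_sizes: Option<Vec<(u32, u32)>>,
    pub input_audio_embeds: Option<Tensor>,
    pub audio_embed_sizes: Option<Vec<usize>>,
    pub audio_attention_mask: Option<Tensor>,
}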

mistralrs-core/src/vision_models/phi4/mm_embedding.rs

Lines changed: 48 additions & 11 deletions

@@ -2,7 +2,13 @@ use candle_core::{Result, Tensor, D};
 use candle_nn::Module;
 use mistralrs_quant::ShardedVarBuilder;
 
-use crate::utils::unvarbuilder::UnVarBuilder;
+use crate::{
+    utils::unvarbuilder::UnVarBuilder,
+    vision_models::phi4::{
+        audio_embedding::{AudioProjectionMode, AUDIO_SPECIAL_TOKEN_ID},
+        image_embedding::IMAGE_SPECIAL_TOKEN_ID,
+    },
+};
 
 use super::{audio_embedding::AudioEmbedding, image_embedding::ImageEmbedding, Phi4MMConfig};

@@ -34,6 +40,7 @@ impl Phi4MMImageAudioEmbedding {
         let audio_embed = if let Some(audio_embd_config) = &cfg.embd_layer.audio_embd_layer {
             Some(AudioEmbedding::new(
                 cfg,
+                wte.clone(),
                 audio_embd_config,
                 vb.pp("audio_embed"),
             )?)

@@ -52,29 +59,59 @@ impl Phi4MMImageAudioEmbedding {
     pub fn forward(
         &self,
         input_ids: &Tensor,
-        input_image_embeds: &Tensor,
+        input_image_embeds: Option<&Tensor>,
         image_attention_mask: Option<&Tensor>,
         image_sizes: Option<Vec<(u32, u32)>>,
+        input_audio_embeds: Option<&Tensor>,
+        audio_embed_sizes: Option<Vec<usize>>,
+        audio_attention_mask: Option<&Tensor>,
+        audio_projection_mode: AudioProjectionMode,
     ) -> Result<Tensor> {
         assert!(-MAX_INPUT_ID < self.image_input_id);
 
         let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
 
-        let image_hidden_states = if let Some(image_embed) = &self.image_embed {
-            Some(image_embed.forward(
+        let image_hidden_states = match &self.image_embed {
+            Some(image_embed) if input_image_embeds.is_some() => Some(image_embed.forward(
                 &input_ids,
-                input_image_embeds,
+                input_image_embeds.expect("input_image_embeds"),
                 image_attention_mask,
                 image_sizes,
-            )?)
-        } else {
-            None
+            )?),
+            _ => None,
         };
 
-        match image_hidden_states {
-            Some(image_hidden_states) => Ok(image_hidden_states),
+        let audio_hidden_states = match &self.audio_embed {
+            Some(audio_embed) if input_audio_embeds.is_some() => Some(audio_embed.forward(
+                &input_ids,
+                input_audio_embeds.expect("input_audio_embeds"),
+                audio_embed_sizes.expect("audio_embed_sizes"),
+                audio_attention_mask.expect("audio_attention_mask"),
+                &audio_projection_mode,
+            )?),
+            _ => None,
+        };
+
+        let image_position_mask = input_ids.eq(IMAGE_SPECIAL_TOKEN_ID)?;
+        let audio_position_mask = input_ids.eq(AUDIO_SPECIAL_TOKEN_ID)?;
+
+        match (image_hidden_states, audio_hidden_states) {
+            (Some(image_hidden_states), Some(audio_hidden_states)) => {
+                // Merge: each modality contributes only at its own special-token positions
+                image_hidden_states.broadcast_mul(
+                    &image_position_mask
+                        .to_dtype(image_hidden_states.dtype())?
+                        .unsqueeze(D::Minus1)?,
+                )? + audio_hidden_states.broadcast_mul(
+                    &audio_position_mask
+                        .to_dtype(audio_hidden_states.dtype())?
+                        .unsqueeze(D::Minus1)?,
+                )?
+            }
+            (Some(image_hidden_states), None) => Ok(image_hidden_states),
+            (None, Some(audio_hidden_states)) => Ok(audio_hidden_states),
 
-            None => self.wte.forward(&input_ids),
+            (None, None) => self.wte.forward(&input_ids),
         }
     }
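When both modalities are present, the final match arm merges them with position masks: each hidden-state tensor is zeroed everywhere except at its own special-token positions, then the two are summed; since a token id matches at most one special token, the masks never overlap. A standalone sketch of that masking arithmetic with toy token ids (all values below are made up):

use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    const IMAGE_TOKEN: f64 = 1.0;
    const AUDIO_TOKEN: f64 = 2.0;

    // One sequence of four tokens: [text, image, audio, text].
    let input_ids = Tensor::new(&[[0f32, 1., 2., 0.]], &dev)?;
    let image_hs = Tensor::ones((1, 4, 2), DType::F32, &dev)?; // stand-in image states
    let audio_hs = (Tensor::ones((1, 4, 2), DType::F32, &dev)? * 2.0)?; // stand-in audio states

    let image_mask = input_ids.eq(IMAGE_TOKEN)?.to_dtype(DType::F32)?.unsqueeze(D::Minus1)?;
    let audio_mask = input_ids.eq(AUDIO_TOKEN)?.to_dtype(DType::F32)?.unsqueeze(D::Minus1)?;

    // Each tensor contributes only at its own special-token positions.
    let merged = (image_hs.broadcast_mul(&image_mask)? + audio_hs.broadcast_mul(&audio_mask)?)?;
    println!("{merged}"); // rows: [0,0], [1,1], [2,2], [0,0]
    Ok(())
}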