Skip to content

Commit 3907feb

Browse files
committed
Fully loading speech stack
1 parent c9ac339 commit 3907feb

File tree

4 files changed

+106
-11
lines changed

4 files changed

+106
-11
lines changed

mistralrs-core/src/vision_models/conformer/encoder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,15 +536,15 @@ impl EncoderEmbedding {
536536
}
537537
}
538538

539-
pub struct Encoder {
539+
pub struct ConformerEncoder {
540540
encoder_embedding: EncoderEmbedding,
541541
embed: NemoConvSubsampling,
542542
pos_embed: AbsolutePositionalEncoding,
543543
relative_attention_bias_layer: T5RelativeAttentionLogitBias,
544544
encoders: Vec<EncoderLayer>,
545545
}
546546

547-
impl Encoder {
547+
impl ConformerEncoder {
548548
pub fn new(mut cfg: ConformerEncoderConfig, vb: ShardedVarBuilder) -> Result<Self> {
549549
assert_eq!(cfg.input_layer, "nemo_conv");
550550

mistralrs-core/src/vision_models/conformer/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,3 @@ pub mod config;
22
pub mod encoder;
33
pub mod nemo;
44
pub mod pos_embed;
5-
6-
pub use encoder::Encoder as ConformerEncoder;

mistralrs-core/src/vision_models/conformer/nemo.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use candle_core::{Result, Tensor};
24
use candle_nn::{Conv2dConfig, Linear, Module};
35
use mistralrs_quant::ShardedVarBuilder;
@@ -7,7 +9,7 @@ use crate::layers;
79
use super::config::NemoConvConfig;
810

911
pub struct NemoConvSubsampling {
10-
conv: Vec<Box<dyn Module>>,
12+
conv: Vec<Arc<dyn Module + Send + Sync>>,
1113
conv2d_subsampling: bool,
1214
out: Linear,
1315
subsampling_causal_cond: bool,
@@ -25,7 +27,7 @@ impl NemoConvSubsampling {
2527
["dw_striding", "striding", "striding_conv1d"].contains(&cfg.subsampling.as_str());
2628

2729
let mut in_channels = 1;
28-
let mut layers: Vec<Box<dyn Module>> = Vec::new();
30+
let mut layers: Vec<Arc<dyn Module + Send + Sync>> = Vec::new();
2931

3032
let stride = 2;
3133
let kernel_size = 3;
@@ -46,7 +48,7 @@ impl NemoConvSubsampling {
4648
let vb_layers = vb.pp("conv");
4749

4850
let mut idx = 0;
49-
layers.push(Box::new(layers::conv2d(
51+
layers.push(Arc::new(layers::conv2d(
5052
in_channels,
5153
cfg.conv_channels,
5254
kernel_size,
@@ -61,11 +63,11 @@ impl NemoConvSubsampling {
6163

6264
in_channels = cfg.conv_channels;
6365
idx += 1;
64-
layers.push(Box::new(cfg.activation));
66+
layers.push(Arc::new(cfg.activation));
6567

6668
for _ in 0..(sampling_num - 1) {
6769
idx += 1;
68-
layers.push(Box::new(layers::conv2d(
70+
layers.push(Arc::new(layers::conv2d(
6971
in_channels,
7072
in_channels,
7173
kernel_size,
@@ -79,7 +81,7 @@ impl NemoConvSubsampling {
7981
)?));
8082

8183
idx += 1;
82-
layers.push(Box::new(layers::conv2d(
84+
layers.push(Arc::new(layers::conv2d(
8385
in_channels,
8486
cfg.conv_channels,
8587
1,
@@ -93,7 +95,7 @@ impl NemoConvSubsampling {
9395
)?));
9496

9597
idx += 1;
96-
layers.push(Box::new(cfg.activation));
98+
layers.push(Arc::new(cfg.activation));
9799
}
98100
}
99101

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use std::{collections::HashMap, sync::Arc};
2+
3+
use candle_core::Result;
4+
use candle_nn::Module;
5+
use mistralrs_quant::ShardedVarBuilder;
6+
7+
use crate::{
8+
layers::{self, Activation},
9+
vision_models::{
10+
conformer::encoder::ConformerEncoder,
11+
phi4::config::{Phi4MMAudioConfig, Phi4MMAudioEmbedConfig},
12+
},
13+
};
14+
15+
use super::Phi4MMConfig;
16+
17+
pub(super) const AUDIO_SPECIAL_TOKEN_ID: f64 = 200011.;
18+
19+
#[derive(Eq, Hash, PartialEq)]
20+
pub enum AudioProjectionMode {
21+
/// If only speech
22+
Speech,
23+
/// If vision + speech or only vision (not sure why that is necessary though)
24+
Vision,
25+
}
26+
27+
pub struct AudioEmbedding {
28+
proj: HashMap<AudioProjectionMode, Vec<Arc<dyn Module + Send + Sync>>>,
29+
encoder: ConformerEncoder,
30+
}
31+
32+
impl AudioEmbedding {
33+
pub fn new(
34+
cfg: &Phi4MMConfig,
35+
audio_embd_config: &Phi4MMAudioEmbedConfig,
36+
vb: ShardedVarBuilder,
37+
) -> Result<Self> {
38+
let hidden_size = audio_embd_config.n_embd.unwrap_or(cfg.hidden_size);
39+
40+
let conformer_config = match &cfg.audio_processor {
41+
Some(Phi4MMAudioConfig { config, name }) if name == "cascades" => config,
42+
_ => candle_core::bail!("Must have audio processor (`cascades`)"),
43+
};
44+
let encoder = ConformerEncoder::new(conformer_config.clone(), vb.pp("encoder"))?;
45+
46+
// let audio_dim_in = conformer_config.input_size;
47+
let audio_dim_out = conformer_config.attention_dim;
48+
49+
let mut proj = HashMap::new();
50+
{
51+
assert_eq!(audio_embd_config.projection_cls, "mlp");
52+
53+
let dim_projection = hidden_size;
54+
let depth = 2;
55+
let linear_downsample_rate = audio_embd_config.downsample_rate;
56+
57+
let embedding_cls_vb = vb.pp("audio_projection");
58+
59+
let mut layers_for_speech: Vec<Arc<dyn Module + Send + Sync>> =
60+
vec![Arc::new(layers::linear(
61+
audio_dim_out * linear_downsample_rate,
62+
dim_projection,
63+
embedding_cls_vb.pp("speech").pp(0),
64+
)?)];
65+
for i in 1..depth {
66+
layers_for_speech.push(Arc::new(Activation::Gelu));
67+
layers_for_speech.push(Arc::new(layers::linear(
68+
dim_projection,
69+
dim_projection,
70+
embedding_cls_vb.pp("speech").pp(i + 1),
71+
)?));
72+
}
73+
74+
let mut layers_for_vision: Vec<Arc<dyn Module + Send + Sync>> =
75+
vec![Arc::new(layers::linear(
76+
audio_dim_out * linear_downsample_rate,
77+
dim_projection,
78+
embedding_cls_vb.pp("vision").pp(0),
79+
)?)];
80+
for i in 1..depth {
81+
layers_for_vision.push(Arc::new(Activation::Gelu));
82+
layers_for_vision.push(Arc::new(layers::linear(
83+
dim_projection,
84+
dim_projection,
85+
embedding_cls_vb.pp("vision").pp(i + 1),
86+
)?));
87+
}
88+
89+
proj.insert(AudioProjectionMode::Speech, layers_for_speech);
90+
proj.insert(AudioProjectionMode::Vision, layers_for_vision);
91+
}
92+
93+
Ok(Self { proj, encoder })
94+
}
95+
}

0 commit comments

Comments
 (0)