
Commit 3853cfe

Add the forward methods
1 parent 8244e07 commit 3853cfe

4 files changed: +338 -14 lines changed

mistralrs-core/src/vision_models/conformer/config.rs

Lines changed: 4 additions & 4 deletions
@@ -33,15 +33,15 @@ serde_default_fn!(Activation, default_nemo_activation, Activation::Relu);
 serde_default_fn!(bool, default_nemo_is_causal, false);
 serde_default_fn!(usize, fake_default_sentinel, usize::MAX);
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct RelativeAttentionBiasArgs {
     pub t5_bias_max_distance: Option<usize>,
     pub t5_bias_symmetric: Option<bool>,
     #[serde(rename = "type")]
     pub tp: String,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct NemoConvConfig {
     #[serde(default = "default_subsampling")]
     pub subsampling: String,
@@ -61,12 +61,12 @@ pub struct NemoConvConfig {
     pub is_causal: bool,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct EncoderEmbeddingConfig {
     pub input_size: usize,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct ConformerEncoderConfig {
     pub input_size: usize,
     pub chunk_size: i32,
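Note: the added Clone derives support the new forward methods in encoder.rs below, which store a copy of the shared config in each submodule (cfg: cfg.clone()). A minimal sketch of that pattern, using a hypothetical DemoConfig in place of ConformerEncoderConfig:

    use serde::{Deserialize, Serialize};

    // Hypothetical stand-in for ConformerEncoderConfig.
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct DemoConfig {
        kernel_size: usize,
        causal: bool,
    }

    // A module that owns its config copy, as GLUPointWiseConv and ConvModule
    // now do, so forward() can read kernel_size/causal without re-threading it.
    struct DemoModule {
        cfg: DemoConfig,
    }

    impl DemoModule {
        fn new(cfg: &DemoConfig) -> Self {
            Self { cfg: cfg.clone() } // requires the new #[derive(Clone)]
        }
    }

    fn main() {
        let cfg: DemoConfig =
            serde_json::from_str(r#"{"kernel_size": 3, "causal": true}"#).unwrap();
        let module = DemoModule::new(&cfg);
        println!("{:?}", module.cfg);
    }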

mistralrs-core/src/vision_models/conformer/encoder.rs

Lines changed: 251 additions & 6 deletions
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
-use candle_core::{Result, Tensor};
-use candle_nn::{BatchNorm, Conv1d, Conv1dConfig, LayerNorm, Linear};
+use candle_core::{IndexOp, Result, Tensor, D};
+use candle_nn::{BatchNorm, Conv1d, Conv1dConfig, LayerNorm, Linear, ModuleT};
 use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
 
 use crate::{
@@ -148,10 +148,16 @@ impl FeedForward {
     }
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.layer_norm)?
-            .apply(&self.up)?
-            .apply(&self.act)?
-            .apply(&self.down)
+        let normed = xs.apply(&self.layer_norm)?;
+        let projected = normed.apply(&self.up)?;
+
+        // GLU: split in half and gate
+        let chunks = projected.chunk(2, D::Minus1)?;
+        let x = &chunks[0];
+        let gate = chunks[1].apply(&self.act)?;
+        let gated = (x * gate)?;
+
+        gated.apply(&self.down)
     }
 }
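For this GLU feed-forward to type-check, the up projection must emit twice the hidden width so that chunk(2, D::Minus1) can split it into a value half and a gate half. A standalone shape check with candle, assuming silu as the gate activation (the real module applies its configured act):

    use candle_core::{D, Device, Result, Tensor};

    fn main() -> Result<()> {
        // Toy dims; assumes up: D -> 2*H, as the chunk(2, ..) below requires.
        let (b, t, h) = (1, 4, 8);
        let projected = Tensor::randn(0f32, 1f32, (b, t, 2 * h), &Device::Cpu)?;

        // GLU: split the doubled hidden dim into value and gate halves.
        let chunks = projected.chunk(2, D::Minus1)?;
        let x = &chunks[0];
        let gate = chunks[1].silu()?; // activation applied to the gate only
        let gated = (x * &gate)?;
        assert_eq!(gated.dims(), &[b, t, h]);
        Ok(())
    }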

@@ -209,6 +215,7 @@ struct GLUPointWiseConv {
     ext_pw_conv_1d: Conv1d,
     act: Activation,
     b1_b2: Option<(Tensor, Tensor)>,
+    cfg: ConformerEncoderConfig,
 }
 
 impl GLUPointWiseConv {
@@ -253,16 +260,50 @@
             ext_pw_conv_1d,
             act: cfg.conv_glu_type,
             b1_b2,
+            cfg: cfg.clone(),
         })
     }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // Input is (B, T, D), need (B, D, T) for conv1d
+        let x = x.transpose(1, 2)?;
+        let mut x = x.apply(&self.ext_pw_conv_1d)?;
+
+        // Handle causal padding removal
+        if self.cfg.causal && self.cfg.kernel_size > 1 {
+            let seq_len = x.dim(2)?;
+            x = x.i((.., .., ..(seq_len - (self.cfg.kernel_size - 1))))?;
+        }
+
+        // Split for GLU
+        let chunks = x.chunk(2, 1)?; // Split along channel dim
+        let first_half = &chunks[0];
+        let second_half = &chunks[1];
+
+        let result = if let Some((b1, b2)) = &self.b1_b2 {
+            let first_with_bias = first_half.broadcast_add(b1)?;
+            let second_with_bias = second_half.broadcast_add(b2)?;
+            first_with_bias.mul(&second_with_bias.apply(&self.act)?)?
+        } else {
+            first_half.mul(&second_half.apply(&self.act)?)?
+        };
+
+        // Back to (B, T, D)
+        result.transpose(1, 2)
+    }
 }
 
 struct ConvModule {
     layer_norm: LayerNorm,
     bn_layer: Option<BatchNorm>,
+    ln1: Option<Linear>,
     ln2: Option<Linear>,
     dw_sep_conv_1d: DepthWiseSeperableConv1d,
     glu: GLUPointWiseConv,
+    ext_pw_conv_1d: Conv1d,
+    cfg: ConformerEncoderConfig,
+    act: Activation,
+    fix_len1: bool,
 }
 
 impl ConvModule {
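In the causal branch above, the convolution was padded on both sides, so its output carries kernel_size - 1 extra trailing frames that would mix in future context; the slice drops them. A self-contained sketch of that trim (hypothetical shapes, no model weights):

    use candle_core::{DType, Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        // Suppose a conv with kernel_size = 3 was padded with kernel_size - 1
        // zeros, so its output has 2 trailing frames of future context.
        let (b, c, t) = (1, 2, 10);
        let kernel_size = 3usize;
        let x = Tensor::zeros((b, c, t + kernel_size - 1), DType::F32, &Device::Cpu)?;

        // Causal fix-up, as in the forward above: drop the trailing frames.
        let seq_len = x.dim(2)?;
        let x = x.i((.., .., ..(seq_len - (kernel_size - 1))))?;
        assert_eq!(x.dims(), &[b, c, t]);
        Ok(())
    }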
@@ -287,6 +328,8 @@ impl ConvModule {
 
         let dw_sep_conv_1d = DepthWiseSeperableConv1d::new(cfg, padding, vb.pp("dw_sep_conv_1d"))?;
 
+        assert_ne!(cfg.ext_pw_out_channel, 0);
+
         let ln2 = if cfg.depthwise_seperable_out_channel != 0
             && cfg.attention_dim != cfg.depthwise_seperable_out_channel
         {
@@ -356,11 +399,65 @@
         Ok(Self {
             layer_norm,
             bn_layer,
+            ln1,
             ln2,
             dw_sep_conv_1d,
             glu,
+            ext_pw_conv_1d,
+            cfg: cfg.clone(),
+            act: cfg.conv_activation,
+            fix_len1,
         })
     }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let mut x = x.apply(&self.layer_norm)?;
+
+        // Use GLU
+        x = self.glu.forward(&x)?;
+        if self.cfg.causal && self.cfg.ext_pw_kernel_size > 1 {
+            let seq_len = x.dim(1)?;
+            x = x.i((.., ..(seq_len - (self.cfg.ext_pw_kernel_size - 1)), ..))?;
+        }
+        if let Some(ln1) = &self.ln1 {
+            x = x.apply(ln1)?;
+        }
+
+        // Apply depthwise separable conv
+        x = x.transpose(1, 2)?; // (B, T, D) -> (B, D, T)
+        x = self.dw_sep_conv_1d.forward(&x)?;
+
+        if self.cfg.causal && self.cfg.kernel_size > 1 {
+            let seq_len = x.dim(2)?;
+            x = x.i((.., .., ..(seq_len - (self.cfg.kernel_size - 1))))?;
+        }
+
+        if let Some(ln2) = &self.ln2 {
+            x = x.transpose(1, 2)?; // (B, D, T) -> (B, T, D)
+            x = x.apply(ln2)?;
+            x = x.transpose(1, 2)?; // (B, T, D) -> (B, D, T)
+        }
+
+        if let Some(bn) = &self.bn_layer {
+            x = bn.forward_t(&x, false)?;
+        }
+
+        x = x.apply(&self.act)?;
+
+        x = x.apply(&self.ext_pw_conv_1d)?;
+        if self.fix_len1 {
+            let seq_len = x.dim(2)?;
+            x = x.i((.., .., ..(seq_len - (self.cfg.ext_pw_kernel_size - 1))))?;
+        }
+        if let Some(ln1) = &self.ln1 {
+            x = x.transpose(1, 2)?;
+            x = x.apply(ln1)?;
+            x = x.transpose(1, 2)?;
+        }
+        x = x.transpose(1, 2)?; // Back to (B, T, D)
+
+        Ok(x)
+    }
 }
 
 struct EncoderLayer {
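The transposes bracketing each convolution exist because candle's Conv1d operates on (B, C, T) while the encoder keeps activations as (B, T, D). A minimal round-trip sketch with a pointwise conv built from random weights (toy shapes, not the module's parameters):

    use candle_core::{Device, Result, Tensor};
    use candle_nn::{Conv1d, Conv1dConfig};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, t, d) = (1, 5, 4);

        // Pointwise conv (kernel size 1), weight shape (out, in, k) = (d, d, 1),
        // built from random values rather than loaded weights.
        let w = Tensor::randn(0f32, 1f32, (d, d, 1), &dev)?;
        let conv = Conv1d::new(w, None, Conv1dConfig::default());

        let x = Tensor::randn(0f32, 1f32, (b, t, d), &dev)?;
        let y = x
            .transpose(1, 2)? // (B, T, D) -> (B, D, T) for Conv1d
            .apply(&conv)?
            .transpose(1, 2)?; // back to (B, T, D)
        assert_eq!(y.dims(), &[b, t, d]);
        Ok(())
    }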
@@ -390,6 +487,34 @@
             conv,
         })
     }
+    fn forward(
+        &self,
+        x: &Tensor,
+        mask: Option<&Tensor>,
+        relative_attention_bias: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        // First feed forward (0.5x)
+        let ff_in_out = self.feed_forward_in.forward(x)?;
+        let mut x = (x + (ff_in_out * 0.5)?)?;
+
+        // Self attention with pre-norm
+        let norm_x = x.apply(&self.layer_norm_att)?;
+        let attn_out = self
+            .self_attn
+            .forward(&norm_x, mask, relative_attention_bias)?;
+        x = (x + attn_out)?;
+
+        // Conv module
+        let conv_out = self.conv.forward(&x)?;
+        x = (x + conv_out)?;
+
+        // Second feed forward (0.5x)
+        let ff_out_out = self.feed_forward_out.forward(&x)?;
+        x = (x + (ff_out_out * 0.5)?)?;
+
+        // Final layer norm
+        x.apply(&self.layer_norm)
+    }
 }
 
 pub struct Encoder {
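The layer follows the Macaron-style conformer ordering: half-weighted feed-forward residuals around attention and convolution, then a final norm. A scalar sketch of just the residual arithmetic (toy closures standing in for the real modules; the per-module layer norms are omitted):

    // Residual schedule of one conformer block, on scalars for clarity.
    fn block(
        x: f32,
        ffn: impl Fn(f32) -> f32,
        attn: impl Fn(f32) -> f32,
        conv: impl Fn(f32) -> f32,
    ) -> f32 {
        let x = x + 0.5 * ffn(x); // first half-step FFN
        let x = x + attn(x);      // self-attention residual
        let x = x + conv(x);      // conv-module residual
        x + 0.5 * ffn(x)          // second half-step FFN (final norm omitted)
    }

    fn main() {
        // With identity FFN and zero attention/conv: 1 -> 1.5 -> 1.5 -> 1.5 -> 2.25.
        let y = block(1.0, |v| v, |_| 0.0, |_| 0.0);
        assert_eq!(y, 2.25);
    }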
@@ -437,4 +562,124 @@
             encoders,
         })
     }
+
+    pub fn forward(&self, xs: &Tensor, mask: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
+        // Forward through embeddings (subsampling)
+        let (mut input_tensor, masks) = self.embed.forward(xs, mask)?;
+
+        // Handle long sequences with unfolding
+        let max_seq_len = 500;
+        let (ori_bz, seq_len, d) = input_tensor.dims3()?;
+        let unfolded = seq_len > max_seq_len;
+
+        // Outside of the `if` block as it's needed below
+        let mut chunk_pad_size = 0;
+        if unfolded {
+            // Pad to multiple of max_seq_len
+            chunk_pad_size = if seq_len % max_seq_len > 0 {
+                max_seq_len - (seq_len % max_seq_len)
+            } else {
+                0
+            };
+
+            if chunk_pad_size > 0 {
+                input_tensor = input_tensor.pad_with_zeros(D::Minus1, 0, chunk_pad_size)?;
+            }
+
+            // Unfold into chunks
+            input_tensor = unfold_tensor(&input_tensor, max_seq_len)?;
+        }
+
+        // Apply positional encoding
+        input_tensor = self.pos_embed.forward(&input_tensor)?;
+
+        // Compute relative attention bias if available
+        let relative_attention_bias = self.relative_attention_bias_layer.forward(&input_tensor)?;
+
+        // Apply encoder layers
+        for layer in &self.encoders {
+            input_tensor = layer.forward(
+                &input_tensor,
+                masks.as_ref(),
+                Some(&relative_attention_bias),
+            )?;
+        }
+
+        // Handle unfolding restoration
+        if unfolded {
+            input_tensor = input_tensor.reshape((ori_bz, seq_len + chunk_pad_size, d))?;
+            if chunk_pad_size > 0 {
+                input_tensor = input_tensor.i((.., ..seq_len, ..))?;
+            }
+        }
+
+        Ok((input_tensor, masks))
+    }
+}
+
+fn unfold_tensor(xs_pad: &Tensor, max_seq_len: usize) -> Result<Tensor> {
+    let (n, t, d) = xs_pad.dims3()?;
+
+    // If sequence length is already <= max_seq_len, no need to unfold
+    if t <= max_seq_len {
+        return Ok(xs_pad.clone());
+    }
+
+    // xs_pad.transpose(-1, -2) # convert to N, D, T
+    let xs_pad = xs_pad.transpose(1, 2)?; // (N, T, D) -> (N, D, T)
+
+    // Unfold the last dimension (T) with size=max_seq_len and step=max_seq_len
+    // This creates sliding windows of size max_seq_len with step max_seq_len
+    let xs_pad = xs_pad.unfold(2, max_seq_len, max_seq_len)?;
+    // Shape is now (N, D, T', max_seq_len) where T' = T // max_seq_len
+
+    let (n, d, t_prime, _) = xs_pad.dims4()?;
+
+    // Permute to (N, T', max_seq_len, D) - equivalent to permute(0, 2, 3, 1)
+    let xs_pad = xs_pad.permute((0, 2, 3, 1))?;
+
+    // Reshape to (N*T', max_seq_len, D)
+    let xs_pad = xs_pad.reshape((n * t_prime, max_seq_len, d))?;
+
+    Ok(xs_pad)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use candle_core::Device;
+
+    #[test]
+    fn test_unfold_tensor() -> Result<()> {
+        let device = Device::Cpu;
+
+        // Test case 1: T > max_seq_len
+        let xs = Tensor::arange(0f32, 24f32, &device)?.reshape((2, 6, 2))?; // (N=2, T=6, D=2)
+        let result = unfold_tensor(&xs, 3)?; // max_seq_len=3
+        assert_eq!(result.dims(), &[4, 3, 2]); // (N*T'=2*2, max_seq_len=3, D=2)
+
+        // Test case 2: T <= max_seq_len
+        let xs = Tensor::arange(0f32, 12f32, &device)?.reshape((2, 3, 2))?; // (N=2, T=3, D=2)
+        let result = unfold_tensor(&xs, 5)?; // max_seq_len=5
+        assert_eq!(result.dims(), &[2, 3, 2]); // Should return original shape
+
+        // Test case 3: T == max_seq_len
+        let xs = Tensor::arange(0f32, 12f32, &device)?.reshape((2, 3, 2))?; // (N=2, T=3, D=2)
+        let result = unfold_tensor(&xs, 3)?; // max_seq_len=3
+        assert_eq!(result.dims(), &[2, 3, 2]); // (N*T'=2*1, max_seq_len=3, D=2)
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_unfold_tensor_larger() -> Result<()> {
+        let device = Device::Cpu;
+
+        // Test with larger tensor
+        let xs = Tensor::arange(0f32, 120f32, &device)?.reshape((2, 10, 6))?; // (N=2, T=10, D=6)
+        let result = unfold_tensor(&xs, 4)?; // max_seq_len=4, T'=10//4=2
+        assert_eq!(result.dims(), &[4, 4, 6]); // (N*T'=2*2, max_seq_len=4, D=6)
+
+        Ok(())
+    }
+}
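The unfolding path pads seq_len up to the next multiple of max_seq_len = 500, runs the 500-frame chunks through the layers as a larger batch, then reshapes back and slices off the padding. The padding arithmetic in isolation:

    fn chunk_pad(seq_len: usize, max_seq_len: usize) -> usize {
        // Same arithmetic as Encoder::forward: pad up to a multiple of max_seq_len.
        if seq_len % max_seq_len > 0 {
            max_seq_len - (seq_len % max_seq_len)
        } else {
            0
        }
    }

    fn main() {
        assert_eq!(chunk_pad(1234, 500), 266); // 1234 + 266 = 1500 -> 3 chunks
        assert_eq!(chunk_pad(1000, 500), 0);   // already a multiple, no padding
    }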
