Skip to content

Commit d94134e

Browse files
committed
Add encoder embedding
1 parent 3853cfe commit d94134e

File tree

1 file changed

+25
-1
lines changed
  • mistralrs-core/src/vision_models/conformer

1 file changed

+25
-1
lines changed

mistralrs-core/src/vision_models/conformer/encoder.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,27 @@ impl EncoderLayer {
517517
}
518518
}
519519

520+
struct EncoderEmbedding {
521+
global_invstd: Tensor,
522+
global_mean: Tensor,
523+
}
524+
525+
impl EncoderEmbedding {
526+
fn new(vb: ShardedVarBuilder) -> Result<Self> {
527+
Ok(Self {
528+
global_invstd: vb.get_unchecked("global_invstd")?,
529+
global_mean: vb.get_unchecked("global_mean")?,
530+
})
531+
}
532+
533+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
534+
xs.broadcast_sub(&self.global_mean)?
535+
.broadcast_mul(&self.global_invstd)
536+
}
537+
}
538+
520539
pub struct Encoder {
540+
encoder_embedding: EncoderEmbedding,
521541
embed: NemoConvSubsampling,
522542
pos_embed: AbsolutePositionalEncoding,
523543
relative_attention_bias_layer: T5RelativeAttentionLogitBias,
@@ -555,7 +575,10 @@ impl Encoder {
555575
encoders.push(EncoderLayer::new(&cfg, vb.pp("encoders").pp(i))?);
556576
}
557577

578+
let encoder_embedding = EncoderEmbedding::new(vb.pp("encoder_embedding"))?;
579+
558580
Ok(Self {
581+
encoder_embedding,
559582
embed,
560583
pos_embed,
561584
relative_attention_bias_layer,
@@ -565,7 +588,8 @@ impl Encoder {
565588

566589
pub fn forward(&self, xs: &Tensor, mask: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
567590
// Forward through embeddings (subsampling)
568-
let (mut input_tensor, masks) = self.embed.forward(xs, mask)?;
591+
let xs = self.encoder_embedding.forward(xs)?;
592+
let (mut input_tensor, masks) = self.embed.forward(&xs, mask)?;
569593

570594
// Handle long sequences with unfolding
571595
let max_seq_len = 500;

0 commit comments

Comments
 (0)