@@ -517,7 +517,27 @@ impl EncoderLayer {
517
517
}
518
518
}
519
519
520
+ struct EncoderEmbedding {
521
+ global_invstd : Tensor ,
522
+ global_mean : Tensor ,
523
+ }
524
+
525
+ impl EncoderEmbedding {
526
+ fn new ( vb : ShardedVarBuilder ) -> Result < Self > {
527
+ Ok ( Self {
528
+ global_invstd : vb. get_unchecked ( "global_invstd" ) ?,
529
+ global_mean : vb. get_unchecked ( "global_mean" ) ?,
530
+ } )
531
+ }
532
+
533
+ fn forward ( & self , xs : & Tensor ) -> Result < Tensor > {
534
+ xs. broadcast_sub ( & self . global_mean ) ?
535
+ . broadcast_mul ( & self . global_invstd )
536
+ }
537
+ }
538
+
520
539
pub struct Encoder {
540
+ encoder_embedding : EncoderEmbedding ,
521
541
embed : NemoConvSubsampling ,
522
542
pos_embed : AbsolutePositionalEncoding ,
523
543
relative_attention_bias_layer : T5RelativeAttentionLogitBias ,
@@ -555,7 +575,10 @@ impl Encoder {
555
575
encoders. push ( EncoderLayer :: new ( & cfg, vb. pp ( "encoders" ) . pp ( i) ) ?) ;
556
576
}
557
577
578
+ let encoder_embedding = EncoderEmbedding :: new ( vb. pp ( "encoder_embedding" ) ) ?;
579
+
558
580
Ok ( Self {
581
+ encoder_embedding,
559
582
embed,
560
583
pos_embed,
561
584
relative_attention_bias_layer,
@@ -565,7 +588,8 @@ impl Encoder {
565
588
566
589
pub fn forward ( & self , xs : & Tensor , mask : Option < & Tensor > ) -> Result < ( Tensor , Option < Tensor > ) > {
567
590
// Forward through embeddings (subsampling)
568
- let ( mut input_tensor, masks) = self . embed . forward ( xs, mask) ?;
591
+ let xs = self . encoder_embedding . forward ( xs) ?;
592
+ let ( mut input_tensor, masks) = self . embed . forward ( & xs, mask) ?;
569
593
570
594
// Handle long sequences with unfolding
571
595
let max_seq_len = 500 ;
0 commit comments