use std::sync::Arc;

-use candle_core::{Result, Tensor};
-use candle_nn::{BatchNorm, Conv1d, Conv1dConfig, LayerNorm, Linear};
+use candle_core::{IndexOp, Result, Tensor, D};
+use candle_nn::{BatchNorm, Conv1d, Conv1dConfig, LayerNorm, Linear, ModuleT};

use mistralrs_quant::{QuantMethod, ShardedVarBuilder};

use crate::{
@@ -148,10 +148,16 @@ impl FeedForward {
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.layer_norm)?
-            .apply(&self.up)?
-            .apply(&self.act)?
-            .apply(&self.down)
+        let normed = xs.apply(&self.layer_norm)?;
+        let projected = normed.apply(&self.up)?;
+
+        // GLU: split in half and gate
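+        // (`up` projects to twice the inner width; the first half is the value, the second the gate)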
+        let chunks = projected.chunk(2, D::Minus1)?;
+        let x = &chunks[0];
+        let gate = chunks[1].apply(&self.act)?;
+        let gated = (x * gate)?;
+
+        gated.apply(&self.down)
    }
}

@@ -209,6 +215,7 @@ struct GLUPointWiseConv {
    ext_pw_conv_1d: Conv1d,
    act: Activation,
    b1_b2: Option<(Tensor, Tensor)>,
+    cfg: ConformerEncoderConfig,
}

impl GLUPointWiseConv {
@@ -253,16 +260,50 @@ impl GLUPointWiseConv {
            ext_pw_conv_1d,
            act: cfg.conv_glu_type,
            b1_b2,
+            cfg: cfg.clone(),
        })
    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // Input is (B, T, D), need (B, D, T) for conv1d
+        let x = x.transpose(1, 2)?;
+        let mut x = x.apply(&self.ext_pw_conv_1d)?;
+
+        // Handle causal padding removal
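+        // (Assuming the conv was built with symmetric padding of kernel_size - 1, trimming the
+        // trailing frames restores the input length and leaves only left context, i.e. a causal conv.)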
+        if self.cfg.causal && self.cfg.kernel_size > 1 {
+            let seq_len = x.dim(2)?;
+            x = x.i((.., .., ..(seq_len - (self.cfg.kernel_size - 1))))?;
+        }
+
+        // Split for GLU
+        let chunks = x.chunk(2, 1)?; // Split along channel dim
+        let first_half = &chunks[0];
+        let second_half = &chunks[1];
+
+        let result = if let Some((b1, b2)) = &self.b1_b2 {
+            let first_with_bias = first_half.broadcast_add(b1)?;
+            let second_with_bias = second_half.broadcast_add(b2)?;
+            first_with_bias.mul(&second_with_bias.apply(&self.act)?)?
+        } else {
+            first_half.mul(&second_half.apply(&self.act)?)?
+        };
+
+        // Back to (B, T, D)
+        result.transpose(1, 2)
+    }
}

struct ConvModule {
    layer_norm: LayerNorm,
    bn_layer: Option<BatchNorm>,
+    ln1: Option<Linear>,
    ln2: Option<Linear>,
    dw_sep_conv_1d: DepthWiseSeperableConv1d,
    glu: GLUPointWiseConv,
+    ext_pw_conv_1d: Conv1d,
+    cfg: ConformerEncoderConfig,
+    act: Activation,
+    fix_len1: bool,
}

impl ConvModule {
@@ -287,6 +328,8 @@ impl ConvModule {

        let dw_sep_conv_1d = DepthWiseSeperableConv1d::new(cfg, padding, vb.pp("dw_sep_conv_1d"))?;

+        assert_ne!(cfg.ext_pw_out_channel, 0);
+
        let ln2 = if cfg.depthwise_seperable_out_channel != 0
            && cfg.attention_dim != cfg.depthwise_seperable_out_channel
        {
@@ -356,11 +399,65 @@ impl ConvModule {
        Ok(Self {
            layer_norm,
            bn_layer,
+            ln1,
            ln2,
            dw_sep_conv_1d,
            glu,
+            ext_pw_conv_1d,
+            cfg: cfg.clone(),
+            act: cfg.conv_activation,
+            fix_len1,
        })
    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
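+        // ConvModule pipeline: LayerNorm -> GLU pointwise conv -> depthwise separable conv
+        // -> optional LayerNorm/BatchNorm -> activation -> pointwise conv, then back to (B, T, D)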
+        let mut x = x.apply(&self.layer_norm)?;
+
+        // Use GLU
+        x = self.glu.forward(&x)?;
+        if self.cfg.causal && self.cfg.ext_pw_kernel_size > 1 {
+            let seq_len = x.dim(1)?;
+            x = x.i((.., ..(seq_len - (self.cfg.ext_pw_kernel_size - 1)), ..))?;
+        }
+        if let Some(ln1) = &self.ln1 {
+            x = x.apply(ln1)?;
+        }
+
+        // Apply depthwise separable conv
+        x = x.transpose(1, 2)?; // (B, T, D) -> (B, D, T)
+        x = self.dw_sep_conv_1d.forward(&x)?;
+
+        if self.cfg.causal && self.cfg.kernel_size > 1 {
+            let seq_len = x.dim(2)?;
+            x = x.i((.., .., ..(seq_len - (self.cfg.kernel_size - 1))))?;
+        }
+
+        if let Some(ln2) = &self.ln2 {
+            x = x.transpose(1, 2)?; // (B, D, T) -> (B, T, D)
+            x = x.apply(ln2)?;
+            x = x.transpose(1, 2)?; // (B, T, D) -> (B, D, T)
+        }
+
+        if let Some(bn) = &self.bn_layer {
+            x = bn.forward_t(&x, false)?;
+        }
+
+        x = x.apply(&self.act)?;
+
+        x = x.apply(&self.ext_pw_conv_1d)?;
+        if self.fix_len1 {
+            let seq_len = x.dim(2)?;
+            x = x.i((.., .., ..(seq_len - (self.cfg.ext_pw_kernel_size - 1))))?;
+        }
+        if let Some(ln1) = &self.ln1 {
+            x = x.transpose(1, 2)?;
+            x = x.apply(ln1)?;
+            x = x.transpose(1, 2)?;
+        }
+        x = x.transpose(1, 2)?; // Back to (B, T, D)
+
+        Ok(x)
+    }
}

struct EncoderLayer {
@@ -390,6 +487,34 @@ impl EncoderLayer {
            conv,
        })
    }
+    fn forward(
+        &self,
+        x: &Tensor,
+        mask: Option<&Tensor>,
+        relative_attention_bias: Option<&Tensor>,
+    ) -> Result<Tensor> {
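+        // Conformer block: half-step FFN -> self-attention -> conv module -> half-step FFN,
+        // each with a residual connection, followed by a final LayerNorm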
+        // First feed forward (0.5x)
+        let ff_in_out = self.feed_forward_in.forward(x)?;
+        let mut x = (x + (ff_in_out * 0.5)?)?;
+
+        // Self attention with pre-norm
+        let norm_x = x.apply(&self.layer_norm_att)?;
+        let attn_out = self
+            .self_attn
+            .forward(&norm_x, mask, relative_attention_bias)?;
+        x = (x + attn_out)?;
+
+        // Conv module
+        let conv_out = self.conv.forward(&x)?;
+        x = (x + conv_out)?;
+
+        // Second feed forward (0.5x)
+        let ff_out_out = self.feed_forward_out.forward(&x)?;
+        x = (x + (ff_out_out * 0.5)?)?;
+
+        // Final layer norm
+        x.apply(&self.layer_norm)
+    }
}

pub struct Encoder {
@@ -437,4 +562,124 @@ impl Encoder {
            encoders,
        })
    }
+
+    pub fn forward(&self, xs: &Tensor, mask: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
+        // Forward through embeddings (subsampling)
+        let (mut input_tensor, masks) = self.embed.forward(xs, mask)?;
+
+        // Handle long sequences with unfolding
+        let max_seq_len = 500;
+        let (ori_bz, seq_len, d) = input_tensor.dims3()?;
+        let unfolded = seq_len > max_seq_len;
+
+        // Outside of the `if` block as it's needed below
+        let mut chunk_pad_size = 0;
+        if unfolded {
+            // Pad to multiple of max_seq_len
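+            // e.g. seq_len = 1234 -> chunk_pad_size = 266, so 1500 frames = 3 chunks of 500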
+            chunk_pad_size = if seq_len % max_seq_len > 0 {
+                max_seq_len - (seq_len % max_seq_len)
+            } else {
+                0
+            };
+
+            if chunk_pad_size > 0 {
+                // Pad along the time dimension so seq_len becomes a multiple of max_seq_len
+                input_tensor = input_tensor.pad_with_zeros(D::Minus2, 0, chunk_pad_size)?;
+            }
+
+            // Unfold into chunks
+            input_tensor = unfold_tensor(&input_tensor, max_seq_len)?;
+        }
+
+        // Apply positional encoding
+        input_tensor = self.pos_embed.forward(&input_tensor)?;
+
+        // Compute relative attention bias if available
+        let relative_attention_bias = self.relative_attention_bias_layer.forward(&input_tensor)?;
+
+        // Apply encoder layers
+        for layer in &self.encoders {
+            input_tensor = layer.forward(
+                &input_tensor,
+                masks.as_ref(),
+                Some(&relative_attention_bias),
+            )?;
+        }
+
+        // Handle unfolding restoration
+        if unfolded {
+            input_tensor = input_tensor.reshape((ori_bz, seq_len + chunk_pad_size, d))?;
+            if chunk_pad_size > 0 {
+                input_tensor = input_tensor.i((.., ..seq_len, ..))?;
+            }
+        }
+
+        Ok((input_tensor, masks))
+    }
+}
+
+fn unfold_tensor(xs_pad: &Tensor, max_seq_len: usize) -> Result<Tensor> {
+    let (n, t, d) = xs_pad.dims3()?;
+
+    // If sequence length is already <= max_seq_len, no need to unfold
+    if t <= max_seq_len {
+        return Ok(xs_pad.clone());
+    }
+
+    // xs_pad.transpose(-1, -2) # convert to N, D, T
+    let xs_pad = xs_pad.transpose(1, 2)?; // (N, T, D) -> (N, D, T)
+
+    // Unfold the last dimension (T) with size=max_seq_len and step=max_seq_len
+    // This creates sliding windows of size max_seq_len with step max_seq_len
+    let xs_pad = xs_pad.unfold(2, max_seq_len, max_seq_len)?;
+    // Shape is now (N, D, T', max_seq_len) where T' = T // max_seq_len
+
+    let (n, d, t_prime, _) = xs_pad.dims4()?;
+
+    // Permute to (N, T', max_seq_len, D) - equivalent to permute(0, 2, 3, 1)
+    let xs_pad = xs_pad.permute((0, 2, 3, 1))?;
+
+    // Reshape to (N*T', max_seq_len, D)
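+    // e.g. (2, 1500, 512) with max_seq_len = 500 -> (6, 500, 512)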
+    let xs_pad = xs_pad.reshape((n * t_prime, max_seq_len, d))?;
+
+    Ok(xs_pad)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use candle_core::Device;
+
+    #[test]
+    fn test_unfold_tensor() -> Result<()> {
+        let device = Device::Cpu;
+
+        // Test case 1: T > max_seq_len
+        let xs = Tensor::arange(0f32, 24f32, &device)?.reshape((2, 6, 2))?; // (N=2, T=6, D=2)
+        let result = unfold_tensor(&xs, 3)?; // max_seq_len=3
+        assert_eq!(result.dims(), &[4, 3, 2]); // (N*T'=2*2, max_seq_len=3, D=2)
+
+        // Test case 2: T <= max_seq_len
+        let xs = Tensor::arange(0f32, 12f32, &device)?.reshape((2, 3, 2))?; // (N=2, T=3, D=2)
+        let result = unfold_tensor(&xs, 5)?; // max_seq_len=5
+        assert_eq!(result.dims(), &[2, 3, 2]); // Should return original shape
+
+        // Test case 3: T == max_seq_len
+        let xs = Tensor::arange(0f32, 12f32, &device)?.reshape((2, 3, 2))?; // (N=2, T=3, D=2)
+        let result = unfold_tensor(&xs, 3)?; // max_seq_len=3
+        assert_eq!(result.dims(), &[2, 3, 2]); // (N*T'=2*1, max_seq_len=3, D=2)
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_unfold_tensor_larger() -> Result<()> {
+        let device = Device::Cpu;
+
+        // Test with larger tensor
+        let xs = Tensor::arange(0f32, 120f32, &device)?.reshape((2, 10, 6))?; // (N=2, T=10, D=6)
+        let result = unfold_tensor(&xs, 4)?; // max_seq_len=4, T'=10//4=2
+        assert_eq!(result.dims(), &[4, 4, 6]); // (N*T'=2*2, max_seq_len=4, D=6)
+
+        Ok(())
+    }
}