File tree Expand file tree Collapse file tree 1 file changed +13
-2
lines changed
mistralrs-core/src/vision_models/conformer Expand file tree Collapse file tree 1 file changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -71,7 +71,6 @@ impl Attention {
71
71
attention_mask : Option < & Tensor > ,
72
72
relative_attention_bias : Option < & Tensor > ,
73
73
) -> Result < Tensor > {
74
- todo ! ( "relative_attention_bias" ) ;
75
74
let ( b_sz, q_len, _) = xs. dims3 ( ) ?;
76
75
77
76
let mut q = self . q_proj . forward ( xs) ?;
@@ -88,11 +87,23 @@ impl Attention {
88
87
. reshape ( ( b_sz, q_len, self . num_heads , self . head_dim ) ) ?
89
88
. transpose ( 1 , 2 ) ?;
90
89
90
+ let attention_mask = match ( attention_mask, relative_attention_bias) {
91
+ ( Some ( attention_mask) , Some ( relative_attention_bias) ) => Some (
92
+ attention_mask
93
+ . unsqueeze ( 1 ) ?
94
+ . broadcast_add ( relative_attention_bias) ?,
95
+ ) ,
96
+ ( Some ( attention_mask) , None ) => Some ( attention_mask. unsqueeze ( 1 ) ?) ,
97
+ ( None , None ) => None ,
98
+ ( None , Some ( _) ) => {
99
+ candle_core:: bail!( "Got `relative_attention_bias` but no `attention_mask`" )
100
+ }
101
+ } ;
91
102
let attn_weights = Sdpa . run_attention (
92
103
& q,
93
104
& k,
94
105
& v,
95
- attention_mask,
106
+ attention_mask. as_ref ( ) ,
96
107
None ,
97
108
& SdpaParams {
98
109
n_kv_groups : 1 ,
You can’t perform that action at this time.
0 commit comments