Skip to content

Commit 8244e07

Browse files
committed
Implement relative attn bias
1 parent 4cd9196 commit 8244e07

File tree

1 file changed

+13
-2
lines changed
  • mistralrs-core/src/vision_models/conformer

1 file changed

+13
-2
lines changed

mistralrs-core/src/vision_models/conformer/encoder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ impl Attention {
         attention_mask: Option<&Tensor>,
         relative_attention_bias: Option<&Tensor>,
     ) -> Result<Tensor> {
-        todo!("relative_attention_bias");
         let (b_sz, q_len, _) = xs.dims3()?;

         let mut q = self.q_proj.forward(xs)?;
@@ -88,11 +87,23 @@ impl Attention {
             .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
             .transpose(1, 2)?;

+        let attention_mask = match (attention_mask, relative_attention_bias) {
+            (Some(attention_mask), Some(relative_attention_bias)) => Some(
+                attention_mask
+                    .unsqueeze(1)?
+                    .broadcast_add(relative_attention_bias)?,
+            ),
+            (Some(attention_mask), None) => Some(attention_mask.unsqueeze(1)?),
+            (None, None) => None,
+            (None, Some(_)) => {
+                candle_core::bail!("Got `relative_attention_bias` but no `attention_mask`")
+            }
+        };
         let attn_weights = Sdpa.run_attention(
             &q,
             &k,
             &v,
-            attention_mask,
+            attention_mask.as_ref(),
             None,
             &SdpaParams {
                 n_kv_groups: 1,

0 commit comments

Comments
 (0)