Commit e5f88ae

IlyasMoutawwakil authored and ArthurZucker committed

Fix is_causal being a tensor (#35791)

* fix is_causal being a tensor
* convert in sdpa attention only when jit tracing

1 parent 163c8bb commit e5f88ae

File tree

1 file changed: +5 additions, -0 deletions


src/transformers/integrations/sdpa_attention.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ def sdpa_attention_forward(
     if is_causal is None:
         is_causal = causal_mask is None and query.shape[2] > 1

+    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
+    # We convert it to a bool for the SDPA kernel that only accepts bools.
+    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
+        is_causal = is_causal.item()
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
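The behavior the patch guards against can be sketched in isolation: under `torch.jit.trace`, tensor shapes are traced symbolically, so a comparison like `query.shape[2] > 1` can yield a 0-dim bool tensor rather than a Python bool, which the SDPA kernel rejects. A minimal sketch of the same guard (`check_causal` is a hypothetical standalone helper, not part of transformers):

```python
import torch

def check_causal(query: torch.Tensor) -> bool:
    # Outside tracing, query.shape[2] is a plain int and the comparison
    # returns a Python bool; during jit tracing it may be a Tensor.
    is_causal = query.shape[2] > 1
    # Mirror the patch: unwrap a 0-dim tensor into a Python bool so a
    # bool-only argument (like SDPA's `is_causal`) is satisfied.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()
    return is_causal

q = torch.zeros(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
print(check_causal(q))  # True: seq_len 4 > 1, returned as a Python bool
```

The `torch.jit.is_tracing()` check keeps the eager path untouched; `.item()` is only paid during tracing, where it converts the traced shape comparison into a concrete bool.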
