This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit a30dd77

fix distilbert test
1 parent 778b81a commit a30dd77

File tree

1 file changed: +1 -1 lines changed

tests/modules/transformer/self_attention_test.py

+1 -1
@@ -71,7 +71,7 @@ def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relev
     seq_len = 3
     dim = module.query.in_features
     hidden_states = torch.randn(batch_size, seq_len, dim)
-    attention_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len))
+    attention_mask = torch.tensor([[1, 1, 0], [1, 0, 1]])[:, None, None, :]
 
     # setting to eval mode to avoid non-deterministic dropout.
     module = module.eval()
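
The change swaps a randomly sampled attention mask for a fixed one. Like the eval-mode call noted in the diff's own comment, this removes a source of non-determinism: torch.randint draws fresh values on every run, and could occasionally produce an all-zero row (every position masked). A minimal sketch of the shape equivalence, assuming batch_size = 2 (implied by the two rows of the fixed mask) and seq_len = 3 from the surrounding test:

import torch

batch_size, seq_len = 2, 3  # assumed from the surrounding test

# Old: mask values sampled at random on every run; a row could come out all zeros.
random_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len))

# New: fixed values with the same broadcastable (batch, 1, 1, seq_len) shape,
# and at least one unmasked position in every row.
fixed_mask = torch.tensor([[1, 1, 0], [1, 0, 1]])[:, None, None, :]

assert random_mask.shape == fixed_mask.shape == (batch_size, 1, 1, seq_len)

The [:, None, None, :] indexing inserts the two singleton dimensions that attention implementations typically expect for broadcasting a mask over heads and query positions.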
