1
- # pylint: disable=no-self-use,invalid-name
1
+ # pylint: disable=no-self-use,invalid-name,protected-access
2
2
from allennlp .common .testing import AllenNlpTestCase
3
3
from allennlp .models .archival import load_archive
4
4
from allennlp .predictors import Predictor
@@ -22,3 +22,46 @@ def test_hotflip(self):
22
22
assert 'original' in attack
23
23
assert 'outputs' in attack
24
24
assert len (attack ['final' ][0 ]) == len (attack ['original' ]) # hotflip replaces words without removing
25
+
26
+ # test using SQuAD model (tests different equals method)
27
+ inputs = {
28
+ "question" : "OMG, I heard you coded a test that succeeded on its first attempt, is that true?" ,
29
+ "passage" : "Bro, never doubt a coding wizard! I am the king of software, MWAHAHAHA"
30
+ }
31
+
32
+ archive = load_archive (self .FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz' )
33
+ predictor = Predictor .from_archive (archive , 'machine-comprehension' )
34
+
35
+ hotflipper = Hotflip (predictor )
36
+ hotflipper .initialize ()
37
+ ignore_tokens = ["@@NULL@@" , '.' , ',' , ';' , '!' , '?' ]
38
+ attack = hotflipper .attack_from_json (inputs ,
39
+ 'question' ,
40
+ 'grad_input_2' )
41
+ assert attack is not None
42
+ assert 'final' in attack
43
+ assert 'original' in attack
44
+ assert 'outputs' in attack
45
+ assert len (attack ['final' ][0 ]) == len (attack ['original' ]) # hotflip replaces words without removing
46
+
47
+ instance = predictor ._json_to_instance (inputs )
48
+ assert instance ['question' ] != attack ['final' ][0 ] # check that the input has changed.
49
+
50
+ outputs = predictor ._model .forward_on_instance (instance )
51
+ original_labeled_instance = predictor .predictions_to_labeled_instances (instance , outputs )[0 ]
52
+ original_span_start = original_labeled_instance ['span_start' ].sequence_index
53
+ original_span_end = original_labeled_instance ['span_end' ].sequence_index
54
+
55
+ flipped_span_start = attack ['outputs' ]['best_span' ][0 ]
56
+ flipped_span_end = attack ['outputs' ]['best_span' ][1 ]
57
+
58
+ for token in instance ['question' ]:
59
+ token = str (token )
60
+ if token in ignore_tokens :
61
+ assert token in attack ['final' ][0 ] # ignore tokens should not be changed
62
+ # HotFlip keeps changing tokens until either the predictions changes or all tokens have
63
+ # been changed. If there are tokens in the HotFlip final result that were in the original
64
+ # (i.e., not all tokens were flipped), then the prediction should be different.
65
+ else :
66
+ if token in attack ['final' ][0 ]:
67
+ assert original_span_start != flipped_span_start or original_span_end != flipped_span_end
0 commit comments