Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit a1476c0

Browse files
Eric-Wallace authored and matt-gardner committed
add equality check for index field; allennlp interpret (#3073)
* add equality check for index field; allennlp interpret
* add test
* change hotflip to use equals method
* tests per matt
* newline
* change input reduction to eq also
* undo
* add newline
* fix pylint
1 parent 5014d02 commit a1476c0

File tree

4 files changed: +57 -5 lines changed

allennlp/data/fields/index_field.py

+1 -3
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,4 @@ def __eq__(self, other) -> bool:
5656
# Allow equality checks to ints that are the sequence index
5757
if isinstance(other, int):
5858
return self.sequence_index == other
59-
# Otherwise it has to be the same object
60-
else:
61-
return id(other) == id(self)
59+
return super().__eq__(other)

allennlp/data/tokenizers/character_tokenizer.py

+5
Original file line number | Diff line number | Diff line change
@@ -73,3 +73,8 @@ def tokenize(self, text: str) -> List[Token]:
7373
token = Token(text=end_token, idx=0)
7474
tokens.append(token)
7575
return tokens
76+
77+
def __eq__(self, other) -> bool:
78+
if isinstance(self, other.__class__):
79+
return self.__dict__ == other.__dict__
80+
return NotImplemented

allennlp/tests/data/fields/index_field_test.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,13 @@ def test_printing_doesnt_crash(self):
3535
def test_equality(self):
3636
index_field1 = IndexField(4, self.text)
3737
index_field2 = IndexField(4, self.text)
38+
index_field3 = IndexField(4, TextField([Token(t) for t in ["AllenNLP", "is", "the", "bomb", "!"]],
39+
{"words": SingleIdTokenIndexer("words")}))
3840

3941
assert index_field1 == 4
4042
assert index_field1 == index_field1
41-
assert index_field1 != index_field2
43+
assert index_field1 == index_field2
44+
45+
assert index_field1 != index_field3
46+
assert index_field2 != index_field3
47+
assert index_field3 == index_field3

allennlp/tests/interpret/hotflip_test.py

+44 -1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# pylint: disable=no-self-use,invalid-name
1+
# pylint: disable=no-self-use,invalid-name,protected-access
22
from allennlp.common.testing import AllenNlpTestCase
33
from allennlp.models.archival import load_archive
44
from allennlp.predictors import Predictor
@@ -22,3 +22,46 @@ def test_hotflip(self):
2222
assert 'original' in attack
2323
assert 'outputs' in attack
2424
assert len(attack['final'][0]) == len(attack['original']) # hotflip replaces words without removing
25+
26+
# test using SQuAD model (tests different equals method)
27+
inputs = {
28+
"question": "OMG, I heard you coded a test that succeeded on its first attempt, is that true?",
29+
"passage": "Bro, never doubt a coding wizard! I am the king of software, MWAHAHAHA"
30+
}
31+
32+
archive = load_archive(self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz')
33+
predictor = Predictor.from_archive(archive, 'machine-comprehension')
34+
35+
hotflipper = Hotflip(predictor)
36+
hotflipper.initialize()
37+
ignore_tokens = ["@@NULL@@", '.', ',', ';', '!', '?']
38+
attack = hotflipper.attack_from_json(inputs,
39+
'question',
40+
'grad_input_2')
41+
assert attack is not None
42+
assert 'final' in attack
43+
assert 'original' in attack
44+
assert 'outputs' in attack
45+
assert len(attack['final'][0]) == len(attack['original']) # hotflip replaces words without removing
46+
47+
instance = predictor._json_to_instance(inputs)
48+
assert instance['question'] != attack['final'][0] # check that the input has changed.
49+
50+
outputs = predictor._model.forward_on_instance(instance)
51+
original_labeled_instance = predictor.predictions_to_labeled_instances(instance, outputs)[0]
52+
original_span_start = original_labeled_instance['span_start'].sequence_index
53+
original_span_end = original_labeled_instance['span_end'].sequence_index
54+
55+
flipped_span_start = attack['outputs']['best_span'][0]
56+
flipped_span_end = attack['outputs']['best_span'][1]
57+
58+
for token in instance['question']:
59+
token = str(token)
60+
if token in ignore_tokens:
61+
assert token in attack['final'][0] # ignore tokens should not be changed
62+
# HotFlip keeps changing tokens until either the predictions changes or all tokens have
63+
# been changed. If there are tokens in the HotFlip final result that were in the original
64+
# (i.e., not all tokens were flipped), then the prediction should be different.
65+
else:
66+
if token in attack['final'][0]:
67+
assert original_span_start != flipped_span_start or original_span_end != flipped_span_end

0 commit comments

Comments
 (0)