
Commit bbb67e9

Add dataset reader for DROP (#2556)
* Add dataset reader for DROP
* Add missing dependency
* Fix pylint, mypy and docs
* Add some more tests
* Address PR feedback
* mypy again
1 parent 4d5eade commit bbb67e9

File tree

8 files changed: +987 −2 lines

allennlp/data/dataset_readers/__init__.py (+2 −1)

@@ -18,7 +18,8 @@
 from allennlp.data.dataset_readers.language_modeling import LanguageModelingReader
 from allennlp.data.dataset_readers.multiprocess_dataset_reader import MultiprocessDatasetReader
 from allennlp.data.dataset_readers.penn_tree_bank import PennTreeBankConstituencySpanDatasetReader
-from allennlp.data.dataset_readers.reading_comprehension import SquadReader, TriviaQaReader, QuACReader, QangarooReader
+from allennlp.data.dataset_readers.reading_comprehension import (
+    DropReader, SquadReader, TriviaQaReader, QuACReader, QangarooReader)
 from allennlp.data.dataset_readers.semantic_role_labeling import SrlReader
 from allennlp.data.dataset_readers.semantic_dependency_parsing import SemanticDependenciesDatasetReader
 from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader

allennlp/data/dataset_readers/reading_comprehension/__init__.py (+1)

@@ -5,6 +5,7 @@
 These submodules contain readers for things that are predominantly reading comprehension datasets.
 """

+from allennlp.data.dataset_readers.reading_comprehension.drop import DropReader
 from allennlp.data.dataset_readers.reading_comprehension.squad import SquadReader
 from allennlp.data.dataset_readers.reading_comprehension.quac import QuACReader
 from allennlp.data.dataset_readers.reading_comprehension.triviaqa import TriviaQaReader

allennlp/data/dataset_readers/reading_comprehension/drop.py (+516)

Large diffs are not rendered by default.
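Since the 516-line reader itself is not rendered, here is a rough usage sketch pieced together from the tests further down. The `lazy` and `instance_format` constructor arguments and the field names come straight from those tests; the fixture path and everything else here is illustrative, not the reader's actual source:

from allennlp.common.util import ensure_list
from allennlp.data.dataset_readers import DropReader

# Default instance format: rich fields for arithmetic-style answers
# ('number_indices', 'numbers_in_passage', 'answer_as_add_sub_expressions',
# 'answer_as_counts', ...), as asserted in test_read_from_file below.
reader = DropReader(lazy=False)
instances = ensure_list(reader.read('fixtures/data/drop.json'))  # illustrative path

# Alternative formats exercised by the tests:
bert_reader = DropReader(instance_format="bert")    # adds a joint 'question_and_passage' field
squad_reader = DropReader(instance_format="squad")  # SQuAD-style 'span_start' / 'span_end' fields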

allennlp/data/dataset_readers/reading_comprehension/util.py (+39)

@@ -369,3 +369,42 @@ def handle_cannot(reference_answers: List[str]):
     else:
         reference_answers = [x for x in reference_answers if x != 'CANNOTANSWER']
     return reference_answers
+
+
+def split_token_by_delimiter(token: Token, delimiter: str) -> List[Token]:
+    split_tokens = []
+    char_offset = token.idx
+    for sub_str in token.text.split(delimiter):
+        if sub_str:
+            split_tokens.append(Token(text=sub_str, idx=char_offset))
+            char_offset += len(sub_str)
+        split_tokens.append(Token(text=delimiter, idx=char_offset))
+        char_offset += len(delimiter)
+    if split_tokens:
+        split_tokens.pop(-1)
+        char_offset -= len(delimiter)
+        return split_tokens
+    else:
+        return [token]
+
+
+def split_tokens_by_hyphen(tokens: List[Token]) -> List[Token]:
+    hyphens = ["-", "–", "~"]
+    new_tokens: List[Token] = []
+
+    for token in tokens:
+        if any(hyphen in token.text for hyphen in hyphens):
+            unsplit_tokens = [token]
+            split_tokens: List[Token] = []
+            for hyphen in hyphens:
+                for unsplit_token in unsplit_tokens:
+                    if hyphen in token.text:
+                        split_tokens += split_token_by_delimiter(unsplit_token, hyphen)
+                    else:
+                        split_tokens.append(unsplit_token)
+                unsplit_tokens, split_tokens = split_tokens, []
+            new_tokens += unsplit_tokens
+        else:
+            new_tokens.append(token)
+
+    return new_tokens
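To make the new helpers concrete, a small illustration (the token texts and offsets are made up; `Token(text=..., idx=...)` is the AllenNLP token class these functions already operate on):

from allennlp.data.tokenizers import Token
from allennlp.data.dataset_readers.reading_comprehension.util import split_tokens_by_hyphen

# "1996-2010" is split into three tokens whose character offsets still
# point into the original passage text.
tokens = [Token(text="1996-2010", idx=0), Token(text="season", idx=10)]
print([(t.text, t.idx) for t in split_tokens_by_hyphen(tokens)])
# [('1996', 0), ('-', 4), ('2010', 5), ('season', 10)]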
New file: DROP reader tests (+142)

@@ -0,0 +1,142 @@
+# pylint: disable=no-self-use,invalid-name, protected-access
+import pytest
+
+from allennlp.common import Params
+from allennlp.common.util import ensure_list
+from allennlp.data.dataset_readers import DropReader
+from allennlp.common.testing import AllenNlpTestCase
+
+
+class TestDropReader:
+    @pytest.mark.parametrize("lazy", (True, False))
+    def test_read_from_file(self, lazy):
+        reader = DropReader(lazy=lazy)
+        instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'drop.json'))
+        assert len(instances) == 19
+
+        instance = instances[0]
+        assert set(instance.fields.keys()) == {
+            'question',
+            'passage',
+            'number_indices',
+            'numbers_in_passage',
+            'answer_as_passage_spans',
+            'answer_as_question_spans',
+            'answer_as_add_sub_expressions',
+            'answer_as_counts',
+            'metadata',
+        }
+
+        assert [t.text for t in instance["question"][:3]] == ["What", "happened", "second"]
+        assert [t.text for t in instance["passage"][:3]] == ["The", "Port", "of"]
+        assert [t.text for t in instance["passage"][-3:]] == ["cruise", "ships", "."]
+
+        # Note that the last number in here is added as padding in case we don't find any numbers
+        # in a particular passage.
+        assert [f.sequence_index for f in instance["number_indices"]] == [
+            16, 30, 36, 41, 52, 64, 80, 89, 147, 153, 166, 174, 177, 206, 245, 252, 267, 279,
+            283, 288, 296, -1
+        ]
+        assert [t.text for t in instance["numbers_in_passage"]] == [
+            "1", "25", "2014", "5", "2018", "1", "2", "1", "54", "52", "6", "60", "58", "2010",
+            "67", "2010", "1996", "3", "1", "6", "1", "0"]
+        assert len(instance["answer_as_passage_spans"]) == 1
+        assert instance["answer_as_passage_spans"][0] == (46, 47)
+        assert len(instance["answer_as_question_spans"]) == 1
+        assert instance["answer_as_question_spans"][0] == (5, 6)
+        assert len(instance["answer_as_add_sub_expressions"]) == 1
+        assert instance["answer_as_add_sub_expressions"][0].labels == [0,] * 22
+        assert len(instance["answer_as_counts"]) == 1
+        assert instance["answer_as_counts"][0].label == -1
+        assert set(instance['metadata'].metadata.keys()) == {
+            'answer_annotations',
+            'answer_info',
+            'answer_texts',
+            'number_indices',
+            'number_tokens',
+            'original_numbers',
+            'original_passage',
+            'original_question',
+            'passage_id',
+            'passage_token_offsets',
+            'passage_tokens',
+            'question_id',
+            'question_token_offsets',
+            'question_tokens',
+        }
+
+    def test_read_in_bert_format(self):
+        reader = DropReader(instance_format="bert")
+        instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'drop.json'))
+        assert len(instances) == 19
+
+        print(instances[0])
+        instance = instances[0]
+        assert set(instance.fields.keys()) == {
+            'answer_as_passage_spans',
+            'metadata',
+            'passage',
+            'question',
+            'question_and_passage',
+        }
+
+        assert [t.text for t in instance["question"][:3]] == ["What", "happened", "second"]
+        assert [t.text for t in instance["passage"][:3]] == ["The", "Port", "of"]
+        assert [t.text for t in instance["passage"][-3:]] == ["cruise", "ships", "."]
+        question_length = len(instance['question'])
+        passage_length = len(instance['passage'])
+        assert len(instance['question_and_passage']) == question_length + passage_length + 1
+
+        assert len(instance["answer_as_passage_spans"]) == 1
+        assert instance["answer_as_passage_spans"][0] == (question_length + 1 + 46,
+                                                          question_length + 1 + 47)
+        assert set(instance['metadata'].metadata.keys()) == {
+            'answer_annotations',
+            'answer_texts',
+            'original_passage',
+            'original_question',
+            'passage_id',
+            'passage_token_offsets',
+            'passage_tokens',
+            'question_id',
+            'question_tokens',
+        }
+
+    def test_read_in_squad_format(self):
+        reader = DropReader(instance_format="squad")
+        instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'drop.json'))
+        assert len(instances) == 19
+
+        print(instances[0])
+        instance = instances[0]
+        assert set(instance.fields.keys()) == {
+            'question',
+            'passage',
+            'span_start',
+            'span_end',
+            'metadata',
+        }
+
+        assert [t.text for t in instance["question"][:3]] == ["What", "happened", "second"]
+        assert [t.text for t in instance["passage"][:3]] == ["The", "Port", "of"]
+        assert [t.text for t in instance["passage"][-3:]] == ["cruise", "ships", "."]
+
+        assert instance["span_start"] == 46
+        assert instance["span_end"] == 47
+        assert set(instance['metadata'].metadata.keys()) == {
+            'answer_annotations',
+            'answer_texts',
+            'original_passage',
+            'original_question',
+            'passage_id',
+            'token_offsets',
+            'passage_tokens',
+            'question_id',
+            'question_tokens',
+            'valid_passage_spans',
+        }
+
+    def test_can_build_from_params(self):
+        reader = DropReader.from_params(Params({}))
+        assert reader._tokenizer.__class__.__name__ == 'WordTokenizer'
+        assert reader._token_indexers["tokens"].__class__.__name__ == 'SingleIdTokenIndexer'
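The last test shows the reader builds from empty `Params` with `WordTokenizer` and `SingleIdTokenIndexer` defaults. A hedged configuration sketch with those dependencies spelled out explicitly; the `tokenizer` / `token_indexers` constructor keys follow the usual DatasetReader pattern implied by `reader._tokenizer` and `reader._token_indexers` above, but since drop.py is not rendered, treat the exact signature as an assumption:

from allennlp.common import Params
from allennlp.data.dataset_readers import DropReader

# Equivalent to DropReader.from_params(Params({})) but explicit about the
# defaults the test asserts ("word" tokenizer, "single_id" token indexer).
params = Params({
    "lazy": False,
    "tokenizer": {"type": "word"},
    "token_indexers": {"tokens": {"type": "single_id"}},
})
reader = DropReader.from_params(params)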
