This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 4df4638

eunsol authored and DeNeutoy committed
Data reader for QUAC (#1624)
* quac initial commit
* random files
* l
* pull
* pushing
* Attention version, but seems to be buggy.. ( no effect on scores )
* all
* pushing
* upload
* DQA SEQ working version
* Added followup info
* fixed V2 bug
* fixed V2 bug
* question num marker
* push to taranis
* fix eval to match official script
* fixing eval script
* fixing eval
* eval mismatch bug
* push to taranis
* final?
* remote data
* removing
* preparing pull request
* removing diffs
* removing query addition
* mask bug fixed
* cpu support
* fixed cuda
* making dataset only pull request
* adding back some files
* removing tag
* moved
* added api doc
* adding dataset rst
* addressed comments from deneutoy
* fixed some lint errors
* fixed doc api, wrong continuation from pylint
* fixing mypy errors
* fixing one more mypy error
* removing repeated doc
* fixed mypy error
* last minute changes
* indented the data
1 parent bf75c9b commit 4df4638

File tree

7 files changed: +521 -8 lines changed


allennlp/data/dataset_readers/__init__.py

+1-1
@@ -16,7 +16,7 @@
from allennlp.data.dataset_readers.language_modeling import LanguageModelingReader
from allennlp.data.dataset_readers.nlvr import NlvrDatasetReader
from allennlp.data.dataset_readers.penn_tree_bank import PennTreeBankConstituencySpanDatasetReader
-from allennlp.data.dataset_readers.reading_comprehension import SquadReader, TriviaQaReader
+from allennlp.data.dataset_readers.reading_comprehension import SquadReader, TriviaQaReader, QuACReader
from allennlp.data.dataset_readers.semantic_role_labeling import SrlReader
from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
from allennlp.data.dataset_readers.sequence_tagging import SequenceTaggingDatasetReader

allennlp/data/dataset_readers/reading_comprehension/__init__.py

+1
@@ -6,4 +6,5 @@
"""

from allennlp.data.dataset_readers.reading_comprehension.squad import SquadReader
+from allennlp.data.dataset_readers.reading_comprehension.quac import QuACReader
from allennlp.data.dataset_readers.reading_comprehension.triviaqa import TriviaQaReader
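With these two __init__.py changes, the new reader is exposed at the package level. A minimal import check (illustrative only, not part of the commit):

from allennlp.data.dataset_readers import QuACReader
# The class is equally importable from the subpackage:
# from allennlp.data.dataset_readers.reading_comprehension import QuACReader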
allennlp/data/dataset_readers/reading_comprehension/quac.py

+130

@@ -0,0 +1,130 @@
import json
import logging
from typing import Any, Dict, List, Tuple

from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.dataset_readers.reading_comprehension import util
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DatasetReader.register("quac")
class QuACReader(DatasetReader):
    """
    Reads a JSON-formatted Question Answering in Context (QuAC) data file
    and returns a ``Dataset`` where the ``Instances`` have four fields: ``question``, a ``ListField``,
    ``passage``, another ``TextField``, and ``span_start`` and ``span_end``, both ``ListFields`` composed
    of ``IndexFields`` into the ``passage`` ``TextField``.
    Two further ``ListFields`` composed of ``LabelFields``, ``yesno_list`` and ``followup_list``, are also added.
    We also add a ``MetadataField`` that stores the instance's ID, the original passage text, gold answer strings,
    and token offsets into the original passage, accessible as ``metadata['instance_id']``,
    ``metadata['original_passage']``, ``metadata['answer_texts_list']`` and ``metadata['token_offsets']``.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
        We use this ``Tokenizer`` for both the question and the passage.  See :class:`Tokenizer`.
        Default is ``WordTokenizer()``.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        We similarly use this for both the question and the passage.  See :class:`TokenIndexer`.
        Default is ``{"tokens": SingleIdTokenIndexer()}``.
    num_context_answers : ``int``, optional
        How many previous question-answer pairs to consider as context.
    """

    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False,
                 num_context_answers: int = 0) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
        self._num_context_answers = num_context_answers

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)
        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset_json = json.load(dataset_file)
            dataset = dataset_json['data']
        logger.info("Reading the dataset")
        for article in dataset:
            for paragraph_json in article['paragraphs']:
                paragraph = paragraph_json["context"]
                tokenized_paragraph = self._tokenizer.tokenize(paragraph)
                qas = paragraph_json['qas']
                metadata = {}
                metadata["instance_id"] = [qa['id'] for qa in qas]
                question_text_list = [qa["question"].strip().replace("\n", "") for qa in qas]
                answer_texts_list = [[answer['text'] for answer in qa['answers']] for qa in qas]
                metadata["question"] = question_text_list
                metadata['answer_texts_list'] = answer_texts_list
                span_starts_list = [[answer['answer_start'] for answer in qa['answers']] for qa in qas]
                span_ends_list = []
                for answer_starts, an_list in zip(span_starts_list, answer_texts_list):
                    span_ends = [start + len(answer) for start, answer in zip(answer_starts, an_list)]
                    span_ends_list.append(span_ends)
                yesno_list = [str(qa['yesno']) for qa in qas]
                followup_list = [str(qa['followup']) for qa in qas]
                instance = self.text_to_instance(question_text_list,
                                                 paragraph,
                                                 span_starts_list,
                                                 span_ends_list,
                                                 tokenized_paragraph,
                                                 yesno_list,
                                                 followup_list,
                                                 metadata)
                yield instance

    @overrides
    def text_to_instance(self,  # type: ignore
                         question_text_list: List[str],
                         passage_text: str,
                         start_span_list: List[List[int]] = None,
                         end_span_list: List[List[int]] = None,
                         passage_tokens: List[Token] = None,
                         yesno_list: List[int] = None,
                         followup_list: List[int] = None,
                         additional_metadata: Dict[str, Any] = None) -> Instance:
        # pylint: disable=arguments-differ
        # We need to convert character indices in `passage_text` to token indices in
        # `passage_tokens`, as the latter is what we'll actually use for supervision.
        answer_token_span_list = []
        passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
        for start_list, end_list in zip(start_span_list, end_span_list):
            token_spans: List[Tuple[int, int]] = []
            for char_span_start, char_span_end in zip(start_list, end_list):
                (span_start, span_end), error = util.char_span_to_token_span(passage_offsets,
                                                                             (char_span_start, char_span_end))
                if error:
                    logger.debug("Passage: %s", passage_text)
                    logger.debug("Passage tokens: %s", passage_tokens)
                    logger.debug("Answer span: (%d, %d)", char_span_start, char_span_end)
                    logger.debug("Token span: (%d, %d)", span_start, span_end)
                    logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1])
                    logger.debug("Answer: %s", passage_text[char_span_start:char_span_end])
                token_spans.append((span_start, span_end))
            answer_token_span_list.append(token_spans)
        question_list_tokens = [self._tokenizer.tokenize(q) for q in question_text_list]
        # Map answer texts to "CANNOTANSWER" if at least half of them are marked as such.
        additional_metadata['answer_texts_list'] = [util.handle_cannot(ans_list) for ans_list \
                                                    in additional_metadata['answer_texts_list']]
        return util.make_reading_comprehension_instance_quac(question_list_tokens,
                                                             passage_tokens,
                                                             self._token_indexers,
                                                             passage_text,
                                                             answer_token_span_list,
                                                             yesno_list,
                                                             followup_list,
                                                             additional_metadata,
                                                             self._num_context_answers)
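For orientation, a minimal usage sketch of the reader added above. The file path is hypothetical, and the listed field names assume gold answers are present and num_context_answers=1; this is illustrative, not part of the commit.

from allennlp.data.dataset_readers.reading_comprehension import QuACReader

# Mark one previous answer span in the passage (hypothetical settings).
reader = QuACReader(num_context_answers=1)
instances = reader.read("quac_train_v0.2.json")  # hypothetical path to a QuAC-format JSON file

first = next(iter(instances))
# With supervision present, the instance carries: question, passage, span_start, span_end,
# p1_answer_marker, yesno_list, followup_list, and metadata.
print(sorted(first.fields.keys()))

# In a training config, the reader would be selected by its registered name:
# "dataset_reader": {"type": "quac", "num_context_answers": 1}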

allennlp/data/dataset_readers/reading_comprehension/util.py

+163-7
@@ -7,7 +7,8 @@
import string
from typing import Any, Dict, List, Tuple

-from allennlp.data.fields import Field, TextField, IndexField, MetadataField
+from allennlp.data.fields import Field, TextField, IndexField, \
+    MetadataField, LabelField, ListField, SequenceLabelField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
@@ -19,6 +20,7 @@
IGNORED_TOKENS = {'a', 'an', 'the'}
STRIPPED_CHARACTERS = string.punctuation + ''.join([u"‘", u"’", u"´", u"`", "_"])

+
def normalize_text(text: str) -> str:
    """
    Performs a normalization that is very similar to that done by the normalization functions in
@@ -187,12 +189,9 @@ def make_reading_comprehension_instance(question_tokens: List[Token],
    passage_field = TextField(passage_tokens, token_indexers)
    fields['passage'] = passage_field
    fields['question'] = TextField(question_tokens, token_indexers)
-    metadata = {
-            'original_passage': passage_text,
-            'token_offsets': passage_offsets,
-            'question_tokens': [token.text for token in question_tokens],
-            'passage_tokens': [token.text for token in passage_tokens],
-            }
+    metadata = {'original_passage': passage_text, 'token_offsets': passage_offsets,
+                'question_tokens': [token.text for token in question_tokens],
+                'passage_tokens': [token.text for token in passage_tokens], }
    if answer_texts:
        metadata['answer_texts'] = answer_texts

@@ -213,3 +212,160 @@ def make_reading_comprehension_instance(question_tokens: List[Token],
    metadata.update(additional_metadata)
    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)


def make_reading_comprehension_instance_quac(question_list_tokens: List[List[Token]],
                                             passage_tokens: List[Token],
                                             token_indexers: Dict[str, TokenIndexer],
                                             passage_text: str,
                                             token_span_lists: List[List[Tuple[int, int]]] = None,
                                             yesno_list: List[int] = None,
                                             followup_list: List[int] = None,
                                             additional_metadata: Dict[str, Any] = None,
                                             num_context_answers: int = 0) -> Instance:
"""
227+
Converts a question, a passage, and an optional answer (or answers) to an ``Instance`` for use
228+
in a reading comprehension model.
229+
230+
Creates an ``Instance`` with at least these fields: ``question`` and ``passage``, both
231+
``TextFields``; and ``metadata``, a ``MetadataField``. Additionally, if both ``answer_texts``
232+
and ``char_span_starts`` are given, the ``Instance`` has ``span_start`` and ``span_end``
233+
fields, which are both ``IndexFields``.
234+
235+
Parameters
236+
----------
237+
question_list_tokens : ``List[List[Token]]``
238+
An already-tokenized list of questions. Each dialog have multiple questions.
239+
passage_tokens : ``List[Token]``
240+
An already-tokenized passage that contains the answer to the given question.
241+
token_indexers : ``Dict[str, TokenIndexer]``
242+
Determines how the question and passage ``TextFields`` will be converted into tensors that
243+
get input to a model. See :class:`TokenIndexer`.
244+
passage_text : ``str``
245+
The original passage text. We need this so that we can recover the actual span from the
246+
original passage that the model predicts as the answer to the question. This is used in
247+
official evaluation scripts.
248+
token_spans_lists : ``List[List[Tuple[int, int]]]``, optional
249+
Indices into ``passage_tokens`` to use as the answer to the question for training. This is
250+
a list of list, first because there is multiple questions per dialog, and
251+
because there might be several possible correct answer spans in the passage.
252+
Currently, we just select the last span in this list (i.e., QuAC has multiple
253+
annotations on the dev set; this will select the last span, which was given by the original annotator).
254+
yesno_list : ``List[int]``
255+
List of the affirmation bit for each question answer pairs.
256+
followup_list : ``List[int]``
257+
List of the continuation bit for each question answer pairs.
258+
num_context_answers : ``int``, optional
259+
How many answers to encode into the passage.
260+
additional_metadata : ``Dict[str, Any]``, optional
261+
The constructed ``metadata`` field will by default contain ``original_passage``,
262+
``token_offsets``, ``question_tokens``, ``passage_tokens``, and ``answer_texts`` keys. If
263+
you want any other metadata to be associated with each instance, you can pass that in here.
264+
This dictionary will get added to the ``metadata`` dictionary we already construct.
265+
"""
    additional_metadata = additional_metadata or {}
    fields: Dict[str, Field] = {}
    passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
    # This is separate so we can reference it later with a known type.
    passage_field = TextField(passage_tokens, token_indexers)
    fields['passage'] = passage_field
    fields['question'] = ListField([TextField(q_tokens, token_indexers) for q_tokens in question_list_tokens])
    metadata = {'original_passage': passage_text,
                'token_offsets': passage_offsets,
                'question_tokens': [[token.text for token in question_tokens] \
                                    for question_tokens in question_list_tokens],
                'passage_tokens': [token.text for token in passage_tokens], }
    p1_answer_marker_list: List[Field] = []
    p2_answer_marker_list: List[Field] = []
    p3_answer_marker_list: List[Field] = []

    def get_tag(i, i_name):
        # Generate a tag to mark a previous answer span in the passage.
        return "<{0:d}_{1:s}>".format(i, i_name)

    def mark_tag(span_start, span_end, passage_tags, prev_answer_distance):
        try:
            assert span_start > 0
            assert span_end > 0
        except:
            raise ValueError("Previous {0:d}th answer span should have been updated!".format(prev_answer_distance))
        # Modify "tags" to mark the previous answer span.
        if span_start == span_end:
            passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "")
        else:
            passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "start")
            passage_tags[prev_answer_distance][span_end] = get_tag(prev_answer_distance, "end")
            for passage_index in range(span_start + 1, span_end):
                passage_tags[prev_answer_distance][passage_index] = get_tag(prev_answer_distance, "in")

    if token_span_lists:
        span_start_list: List[Field] = []
        span_end_list: List[Field] = []
        p1_span_start, p1_span_end, p2_span_start = -1, -1, -1
        p2_span_end, p3_span_start, p3_span_end = -1, -1, -1
        # Loop over each question's answer spans.
        for question_index, answer_span_lists in enumerate(token_span_lists):
            span_start, span_end = answer_span_lists[-1]  # Last one is the original answer
            span_start_list.append(IndexField(span_start, passage_field))
            span_end_list.append(IndexField(span_end, passage_field))
            prev_answer_marker_lists = [["O"] * len(passage_tokens), ["O"] * len(passage_tokens),
                                        ["O"] * len(passage_tokens), ["O"] * len(passage_tokens)]
            if question_index > 0 and num_context_answers > 0:
                mark_tag(p1_span_start, p1_span_end, prev_answer_marker_lists, 1)
                if question_index > 1 and num_context_answers > 1:
                    mark_tag(p2_span_start, p2_span_end, prev_answer_marker_lists, 2)
                    if question_index > 2 and num_context_answers > 2:
                        mark_tag(p3_span_start, p3_span_end, prev_answer_marker_lists, 3)
            p3_span_start = p2_span_start
            p3_span_end = p2_span_end
            p2_span_start = p1_span_start
            p2_span_end = p1_span_end
            p1_span_start = span_start
            p1_span_end = span_end
            if num_context_answers > 2:
                p3_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[3],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
            if num_context_answers > 1:
                p2_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[2],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
            if num_context_answers > 0:
                p1_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[1],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
        fields['span_start'] = ListField(span_start_list)
        fields['span_end'] = ListField(span_end_list)
        if num_context_answers > 0:
            fields['p1_answer_marker'] = ListField(p1_answer_marker_list)
        if num_context_answers > 1:
            fields['p2_answer_marker'] = ListField(p2_answer_marker_list)
        if num_context_answers > 2:
            fields['p3_answer_marker'] = ListField(p3_answer_marker_list)
        fields['yesno_list'] = ListField( \
            [LabelField(yesno, label_namespace="yesno_labels") for yesno in yesno_list])
        fields['followup_list'] = ListField([LabelField(followup, label_namespace="followup_labels") \
                                             for followup in followup_list])
    metadata.update(additional_metadata)
    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)
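To make the previous-answer marker scheme concrete, a small worked example under assumed inputs (not part of the commit): suppose the previous gold answer covered passage token indices 5 through 7 in a 10-token passage, and num_context_answers is at least 1.

# Illustrative only: what mark_tag(5, 7, prev_answer_marker_lists, 1) writes into the
# marker list for the most recent answer (distance 1).
prev_answer_marker_lists = [["O"] * 10 for _ in range(4)]
# After the call, prev_answer_marker_lists[1] reads:
#   ["O", "O", "O", "O", "O", "<1_start>", "<1_in>", "<1_end>", "O", "O"]
# A single-token previous answer (span_start == span_end) is tagged "<1_>" instead.
# Each marker list is then wrapped in a SequenceLabelField over the passage tokens
# with label_namespace="answer_tags".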


def handle_cannot(reference_answers: List[str]):
    """
    Process a list of reference answers.
    If half or more of the reference answers are "CANNOTANSWER", take it as the gold answer.
    Otherwise, return the answers that are not "CANNOTANSWER".
    """
    num_cannot = 0
    num_spans = 0
    for ref in reference_answers:
        if ref == 'CANNOTANSWER':
            num_cannot += 1
        else:
            num_spans += 1
    if num_cannot >= num_spans:
        reference_answers = ['CANNOTANSWER']
    else:
        reference_answers = [x for x in reference_answers if x != 'CANNOTANSWER']
    return reference_answers
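A quick behavioral sketch of handle_cannot; the answer strings are made up for illustration.

handle_cannot(['CANNOTANSWER', 'CANNOTANSWER', 'in the late 1960s'])
# -> ['CANNOTANSWER']                 (cannot-answer references are at least half, so it becomes gold)
handle_cannot(['in 1969', 'the late 1960s', 'CANNOTANSWER'])
# -> ['in 1969', 'the late 1960s']    (span answers are the majority, so CANNOTANSWER is dropped)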
