
Commit 08a8c5e

yizhongw authored and matt-gardner committed
Add QaNet model (#2446)

* Add max length limit for passage and question in SQuAD reader.
* Add QaNet model.
* fixes the squad reader and adds doc.
* Move `get_best_span()` function out of bidaf.
* Update the docstring of QANet and BiDAF
* Move `ResidualWithLayerDropout` to a separate module file.
* Update the docstring and test cases for the length limits in squad reader.
* Keep the old `get_best_span` function in `bidaf.py`.
* Add docstring for `get_best_span` function.
* Separate test case for the `get_best_span` function.
* Fixes docs.
* Update the training configuration file.
* ignores pylint error.
* add docs for layer dropout.
* fixes docs.
* Remove the unsqueeze()
1 parent e417486 commit 08a8c5e
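The diff excerpt below does not include the new QaNet model file or the `ResidualWithLayerDropout` module mentioned in the bullets above. For context, a minimal sketch of the stochastic-depth ("layer dropout") idea that module implements; the class name matches the commit message, but the signature and scaling rule are approximated from the QANet recipe rather than copied from the commit:

import torch

class ResidualWithLayerDropout(torch.nn.Module):
    # Sketch only: during training, each residual sublayer is skipped entirely
    # with a probability that grows with depth; at test time we use the
    # expected value of that stochastic connection.
    def __init__(self, undecayed_dropout_prob: float = 0.5) -> None:
        super().__init__()
        self.undecayed_dropout_prob = undecayed_dropout_prob

    def forward(self, layer_input: torch.Tensor, layer_output: torch.Tensor,
                layer_index: int = None, total_layers: int = None) -> torch.Tensor:
        # Deeper layers get a higher dropout probability: p_l = p * l / L.
        if layer_index is not None and total_layers is not None:
            dropout_prob = self.undecayed_dropout_prob * layer_index / total_layers
        else:
            dropout_prob = self.undecayed_dropout_prob
        if self.training:
            if torch.rand(1).item() < dropout_prob:
                return layer_input                    # skip the sublayer this step
            return layer_output + layer_input         # ordinary residual connection
        # Inference: expected value, layer_input + (1 - p_l) * layer_output.
        return (1 - dropout_prob) * layer_output + layer_input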

File tree

22 files changed: +1206, -15 lines


allennlp/data/dataset_readers/reading_comprehension/squad.py (+42, -7)
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 from overrides import overrides
@@ -26,6 +26,14 @@ class SquadReader(DatasetReader):
     ``metadata['token_offsets']``. This is so that we can more easily use the official SQuAD
     evaluation script to get metrics.
 
+    We also support limiting the maximum length for both passage and question. However, some gold
+    answer spans may exceed the maximum passage length, which will cause errors when making
+    instances. We simply skip these spans. If all of the gold answer spans of an example are
+    skipped during training, we skip the whole example. During validation or testing, since we
+    cannot skip examples, we use the last token as a pseudo gold answer span instead. The computed
+    loss will not be accurate as a result, but this does not affect answer evaluation, because we
+    keep all of the original gold answer texts.
+
     Parameters
     ----------
     tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
@@ -34,14 +42,29 @@ class SquadReader(DatasetReader):
     token_indexers : ``Dict[str, TokenIndexer]``, optional
         We similarly use this for both the question and the passage. See :class:`TokenIndexer`.
         Default is ``{"tokens": SingleIdTokenIndexer()}``.
+    lazy : ``bool``, optional (default=False)
+        If this is true, ``instances()`` will return an object whose ``__iter__`` method
+        reloads the dataset each time it's called. Otherwise, ``instances()`` returns a list.
+    passage_length_limit : ``int``, optional (default=None)
+        If specified, we will cut the passage if its length exceeds this limit.
+    question_length_limit : ``int``, optional (default=None)
+        If specified, we will cut the question if its length exceeds this limit.
+    skip_invalid_examples : ``bool``, optional (default=False)
+        If this is true, we will skip invalid examples.
     """
     def __init__(self,
                  tokenizer: Tokenizer = None,
                  token_indexers: Dict[str, TokenIndexer] = None,
-                 lazy: bool = False) -> None:
+                 lazy: bool = False,
+                 passage_length_limit: int = None,
+                 question_length_limit: int = None,
+                 skip_invalid_examples: bool = False) -> None:
         super().__init__(lazy)
         self._tokenizer = tokenizer or WordTokenizer()
         self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
+        self.passage_length_limit = passage_length_limit
+        self.question_length_limit = question_length_limit
+        self.skip_invalid_examples = skip_invalid_examples
 
     @overrides
     def _read(self, file_path: str):
@@ -68,25 +91,32 @@ def _read(self, file_path: str):
                                                                  zip(span_starts, span_ends),
                                                                  answer_texts,
                                                                  tokenized_paragraph)
-                    yield instance
+                    if instance is not None:
+                        yield instance
 
     @overrides
     def text_to_instance(self,  # type: ignore
                          question_text: str,
                          passage_text: str,
                          char_spans: List[Tuple[int, int]] = None,
                          answer_texts: List[str] = None,
-                         passage_tokens: List[Token] = None) -> Instance:
+                         passage_tokens: List[Token] = None) -> Optional[Instance]:
         # pylint: disable=arguments-differ
         if not passage_tokens:
             passage_tokens = self._tokenizer.tokenize(passage_text)
+        question_tokens = self._tokenizer.tokenize(question_text)
+        if self.passage_length_limit is not None:
+            passage_tokens = passage_tokens[: self.passage_length_limit]
+        if self.question_length_limit is not None:
+            question_tokens = question_tokens[: self.question_length_limit]
         char_spans = char_spans or []
-
         # We need to convert character indices in `passage_text` to token indices in
         # `passage_tokens`, as the latter is what we'll actually use for supervision.
         token_spans: List[Tuple[int, int]] = []
         passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
         for char_span_start, char_span_end in char_spans:
+            if char_span_end > passage_offsets[-1][1]:
+                continue
             (span_start, span_end), error = util.char_span_to_token_span(passage_offsets,
                                                                           (char_span_start, char_span_end))
             if error:
@@ -98,8 +128,13 @@ def text_to_instance(self,  # type: ignore
                 logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1])
                 logger.debug("Answer: %s", passage_text[char_span_start:char_span_end])
             token_spans.append((span_start, span_end))
-
-        return util.make_reading_comprehension_instance(self._tokenizer.tokenize(question_text),
+        # All of the original answer spans were filtered out by the length limit.
+        if char_spans and not token_spans:
+            if self.skip_invalid_examples:
+                return None
+            else:
+                token_spans.append((len(passage_tokens) - 1, len(passage_tokens) - 1))
+        return util.make_reading_comprehension_instance(question_tokens,
                                                         passage_tokens,
                                                         self._token_indexers,
                                                         passage_text,
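To make the new reader options concrete, here is a minimal usage sketch; the limit values and the file path are invented for illustration and are not taken from the commit:

from allennlp.data.dataset_readers.reading_comprehension.squad import SquadReader

# Hypothetical limits; the commit does not prescribe particular values.
reader = SquadReader(passage_length_limit=400,
                     question_length_limit=50,
                     skip_invalid_examples=True)  # training-style behavior

# Passages and questions longer than the limits are truncated before instances
# are built. With skip_invalid_examples=False (the default, appropriate for
# validation/testing), an example whose gold spans were all truncated away gets
# the last passage token as a pseudo answer span instead of being dropped.
instances = reader.read('/path/to/squad-train-v1.1.json')  # illustrative path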

allennlp/models/__init__.py (+1)
@@ -14,6 +14,7 @@
 from allennlp.models.event2mind import Event2Mind
 from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
 from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
+from allennlp.models.reading_comprehension.qanet import QaNet
 from allennlp.models.semantic_parsing.nlvr.nlvr_coverage_semantic_parser import NlvrCoverageSemanticParser
 from allennlp.models.semantic_parsing.nlvr.nlvr_direct_semantic_parser import NlvrDirectSemanticParser
 from allennlp.models.semantic_parsing.quarel.quarel_semantic_parser import QuarelSemanticParser

allennlp/models/reading_comprehension/__init__.py (+1)
@@ -8,3 +8,4 @@
 from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
 from allennlp.models.reading_comprehension.bidaf_ensemble import BidafEnsemble
 from allennlp.models.reading_comprehension.dialog_qa import DialogQA
+from allennlp.models.reading_comprehension.qanet import QaNet
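These `__init__` imports are what make the new model visible through AllenNLP's registry, so a training configuration can refer to it by name. A small hedged illustration, assuming `qanet.py` registers the class under the name "qanet" (the registration itself is not shown in this excerpt):

from allennlp.models import Model

# Importing the models package triggers qanet.py's @Model.register(...) call,
# after which the class can be resolved by name, e.g. from a config stanza
# like "model": {"type": "qanet"}.
qanet_class = Model.by_name("qanet")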

allennlp/models/reading_comprehension/bidaf.py (+7, -7)
@@ -7,6 +7,7 @@
 from allennlp.common.checks import check_dimensions_match
 from allennlp.data import Vocabulary
 from allennlp.models.model import Model
+from allennlp.models.reading_comprehension.util import get_best_span
 from allennlp.modules import Highway
 from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder
 from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
@@ -139,12 +140,11 @@ def forward(self,  # type: ignore
             ending position of the answer with the passage. This is an `inclusive` token index.
             If this is given, we will compute a loss that gets included in the output dictionary.
         metadata : ``List[Dict[str, Any]]``, optional
-            If present, this should contain the question ID, original passage text, and token
-            offsets into the passage for each instance in the batch. We use this for computing
-            official metrics using the official SQuAD evaluation script. The length of this list
-            should be the batch size, and each dictionary should have the keys ``id``,
-            ``original_passage``, and ``token_offsets``. If you only want the best span string and
-            don't care about official metrics, you can omit the ``id`` key.
+        metadata : ``List[Dict[str, Any]]``, optional
+            If present, this should contain the question tokens, passage tokens, original passage
+            text, and token offsets into the passage for each instance in the batch. The length
+            of this list should be the batch size, and each dictionary should have the keys
+            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
 
         Returns
         -------
@@ -245,7 +245,7 @@ def forward(self,  # type: ignore
         span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
         span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
         span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
-        best_span = self.get_best_span(span_start_logits, span_end_logits)
+        best_span = get_best_span(span_start_logits, span_end_logits)
 
         output_dict = {
             "passage_question_attention": passage_question_attention,

allennlp/models/reading_comprehension/bidaf_ensemble.py (+2, -1)
@@ -8,6 +8,7 @@
 from allennlp.models.archival import load_archive
 from allennlp.models.model import Model
 from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
+from allennlp.models.reading_comprehension.util import get_best_span
 from allennlp.common import Params
 from allennlp.data import Vocabulary
 from allennlp.training.metrics import SquadEmAndF1
@@ -140,4 +141,4 @@ def ensemble(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
 
     span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
     span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
-    return BidirectionalAttentionFlow.get_best_span(span_start_probs.log(), span_end_probs.log())  # type: ignore
+    return get_best_span(span_start_probs.log(), span_end_probs.log())  # type: ignore
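One detail worth noting in the second hunk: `get_best_span` adds its two score tensors, so the ensemble averages the members' probabilities first and then takes logs, which makes the selected span the one maximizing P(start) * P(end) under the averaged distributions. A tiny numeric check, with invented values:

import torch
from allennlp.models.reading_comprehension.util import get_best_span

# Averaged ensemble distributions over a 4-token passage (made-up numbers).
start_probs = torch.tensor([[0.1, 0.6, 0.2, 0.1]])
end_probs = torch.tensor([[0.1, 0.1, 0.7, 0.1]])
# Picks span (1, 2): log(0.6) + log(0.7) is the best valid start/end sum.
print(get_best_span(start_probs.log(), end_probs.log()))  # tensor([[1, 2]])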
