import json
import logging
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional

from overrides import overrides
@@ -26,6 +26,14 @@ class SquadReader(DatasetReader):
    ``metadata['token_offsets']``. This is so that we can more easily use the official SQuAD
    evaluation script to get metrics.

+    We also support limiting the maximum length of both the passage and the question. Some gold
+    answer spans may exceed the maximum passage length, which would cause errors when making
+    instances, so we simply skip those spans. If all of the gold answer spans of an example are
+    skipped, we skip the whole example during training. During validation or testing we cannot
+    skip examples, so we instead use the last token of the passage as a pseudo gold answer span.
+    The computed loss will not be accurate as a result, but this does not affect answer
+    evaluation, because we keep all of the original gold answer texts.
+
    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
@@ -34,14 +42,29 @@ class SquadReader(DatasetReader):
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        We similarly use this for both the question and the passage. See :class:`TokenIndexer`.
        Default is ``{"tokens": SingleIdTokenIndexer()}``.
+    lazy : ``bool``, optional (default=False)
+        If this is true, ``instances()`` will return an object whose ``__iter__`` method
+        reloads the dataset each time it's called. Otherwise, ``instances()`` returns a list.
+    passage_length_limit : ``int``, optional (default=None)
+        If specified, we will cut the passage if its length exceeds this limit.
+    question_length_limit : ``int``, optional (default=None)
+        If specified, we will cut the question if its length exceeds this limit.
+    skip_invalid_examples : ``bool``, optional (default=False)
+        If this is true, we will skip examples whose gold answer spans have all been filtered out.
    """
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
-                 lazy: bool = False) -> None:
+                 lazy: bool = False,
+                 passage_length_limit: int = None,
+                 question_length_limit: int = None,
+                 skip_invalid_examples: bool = False) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
+        self.passage_length_limit = passage_length_limit
+        self.question_length_limit = question_length_limit
+        self.skip_invalid_examples = skip_invalid_examples

    @overrides
    def _read(self, file_path: str):
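As a quick illustration of the new constructor arguments, here is a minimal usage sketch; the limit values and the data path are made up, and ``SquadReader`` is assumed to be exported from ``allennlp.data.dataset_readers``:

    from allennlp.data.dataset_readers import SquadReader

    # Hypothetical limits: truncate passages to 400 tokens and questions to 50.
    reader = SquadReader(passage_length_limit=400,
                         question_length_limit=50,
                         skip_invalid_examples=True)  # drop unusable examples while training
    instances = reader.read("data/squad-train-v1.1.json")  # made-up local path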
@@ -68,25 +91,32 @@ def _read(self, file_path: str):
                                                 zip(span_starts, span_ends),
                                                 answer_texts,
                                                 tokenized_paragraph)
-                    yield instance
+                    if instance is not None:
+                        yield instance

    @overrides
    def text_to_instance(self,  # type: ignore
                         question_text: str,
                         passage_text: str,
                         char_spans: List[Tuple[int, int]] = None,
                         answer_texts: List[str] = None,
-                         passage_tokens: List[Token] = None) -> Instance:
+                         passage_tokens: List[Token] = None) -> Optional[Instance]:
        # pylint: disable=arguments-differ
        if not passage_tokens:
            passage_tokens = self._tokenizer.tokenize(passage_text)
+        question_tokens = self._tokenizer.tokenize(question_text)
+        if self.passage_length_limit is not None:
+            passage_tokens = passage_tokens[: self.passage_length_limit]
+        if self.question_length_limit is not None:
+            question_tokens = question_tokens[: self.question_length_limit]
        char_spans = char_spans or []
-
        # We need to convert character indices in `passage_text` to token indices in
        # `passage_tokens`, as the latter is what we'll actually use for supervision.
        token_spans: List[Tuple[int, int]] = []
        passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
        for char_span_start, char_span_end in char_spans:
+            if char_span_end > passage_offsets[-1][1]:
+                continue
            (span_start, span_end), error = util.char_span_to_token_span(passage_offsets,
                                                                         (char_span_start, char_span_end))
            if error:
@@ -98,8 +128,13 @@ def text_to_instance(self, # type: ignore
                logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1])
                logger.debug("Answer: %s", passage_text[char_span_start:char_span_end])
            token_spans.append((span_start, span_end))
-
-        return util.make_reading_comprehension_instance(self._tokenizer.tokenize(question_text),
+        # All of the original gold answer spans have been filtered out.
+        if char_spans and not token_spans:
+            if self.skip_invalid_examples:
+                return None
+            else:
+                token_spans.append((len(passage_tokens) - 1, len(passage_tokens) - 1))
+        return util.make_reading_comprehension_instance(question_tokens,
                                                        passage_tokens,
                                                        self._token_indexers,
                                                        passage_text,
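To make the new span handling in ``text_to_instance`` concrete, here is a small self-contained sketch of the filtering and fallback logic; the function name and the toy offsets are invented for illustration, and the real reader uses ``util.char_span_to_token_span`` rather than the simple conversion shown here:

    from typing import List, Optional, Tuple

    def resolve_token_spans(char_spans: List[Tuple[int, int]],
                            passage_offsets: List[Tuple[int, int]],
                            skip_invalid_examples: bool) -> Optional[List[Tuple[int, int]]]:
        token_spans: List[Tuple[int, int]] = []
        for char_start, char_end in char_spans:
            if char_end > passage_offsets[-1][1]:
                continue  # the gold answer lies beyond the (possibly truncated) passage
            start = next(i for i, (_, end) in enumerate(passage_offsets) if end > char_start)
            end = max(i for i, (begin, _) in enumerate(passage_offsets) if begin < char_end)
            token_spans.append((start, end))
        if char_spans and not token_spans:    # every gold span was filtered out
            if skip_invalid_examples:
                return None                   # training: drop the whole example
            last = len(passage_offsets) - 1
            token_spans.append((last, last))  # validation/test: pseudo gold span
        return token_spans

    # Three tokens covering characters 0-14; an answer at characters 20-25 falls
    # outside the passage, so the last token becomes the pseudo gold span.
    print(resolve_token_spans([(20, 25)], [(0, 4), (5, 9), (10, 14)],
                              skip_invalid_examples=False))  # -> [(2, 2)]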