7
7
import string
8
8
from typing import Any , Dict , List , Tuple
9
9
10
- from allennlp .data .fields import Field , TextField , IndexField , MetadataField
10
+ from allennlp .data .fields import Field , TextField , IndexField , \
11
+ MetadataField , LabelField , ListField , SequenceLabelField
11
12
from allennlp .data .instance import Instance
12
13
from allennlp .data .token_indexers import TokenIndexer
13
14
from allennlp .data .tokenizers import Token
19
20
IGNORED_TOKENS = {'a' , 'an' , 'the' }
20
21
STRIPPED_CHARACTERS = string .punctuation + '' .join ([u"‘" , u"’" , u"´" , u"`" , "_" ])
21
22
23
+
22
24
def normalize_text (text : str ) -> str :
23
25
"""
24
26
Performs a normalization that is very similar to that done by the normalization functions in
@@ -187,12 +189,9 @@ def make_reading_comprehension_instance(question_tokens: List[Token],
187
189
passage_field = TextField (passage_tokens , token_indexers )
188
190
fields ['passage' ] = passage_field
189
191
fields ['question' ] = TextField (question_tokens , token_indexers )
190
- metadata = {
191
- 'original_passage' : passage_text ,
192
- 'token_offsets' : passage_offsets ,
193
- 'question_tokens' : [token .text for token in question_tokens ],
194
- 'passage_tokens' : [token .text for token in passage_tokens ],
195
- }
192
+ metadata = {'original_passage' : passage_text , 'token_offsets' : passage_offsets ,
193
+ 'question_tokens' : [token .text for token in question_tokens ],
194
+ 'passage_tokens' : [token .text for token in passage_tokens ], }
196
195
if answer_texts :
197
196
metadata ['answer_texts' ] = answer_texts
198
197
@@ -213,3 +212,160 @@ def make_reading_comprehension_instance(question_tokens: List[Token],
213
212
metadata .update (additional_metadata )
214
213
fields ['metadata' ] = MetadataField (metadata )
215
214
return Instance (fields )
215
+
216
+
217
def make_reading_comprehension_instance_quac(question_list_tokens: List[List[Token]],
                                             passage_tokens: List[Token],
                                             token_indexers: Dict[str, TokenIndexer],
                                             passage_text: str,
                                             token_span_lists: List[List[Tuple[int, int]]] = None,
                                             yesno_list: List[int] = None,
                                             followup_list: List[int] = None,
                                             additional_metadata: Dict[str, Any] = None,
                                             num_context_answers: int = 0) -> Instance:
    """
    Converts a question, a passage, and an optional answer (or answers) to an ``Instance`` for use
    in a reading comprehension model.

    Creates an ``Instance`` with at least these fields: ``question`` and ``passage``, both
    ``TextFields``; and ``metadata``, a ``MetadataField``. Additionally, if both ``answer_texts``
    and ``char_span_starts`` are given, the ``Instance`` has ``span_start`` and ``span_end``
    fields, which are both ``IndexFields``.

    Parameters
    ----------
    question_list_tokens : ``List[List[Token]]``
        An already-tokenized list of questions. Each dialog have multiple questions.
    passage_tokens : ``List[Token]``
        An already-tokenized passage that contains the answer to the given question.
    token_indexers : ``Dict[str, TokenIndexer]``
        Determines how the question and passage ``TextFields`` will be converted into tensors that
        get input to a model. See :class:`TokenIndexer`.
    passage_text : ``str``
        The original passage text. We need this so that we can recover the actual span from the
        original passage that the model predicts as the answer to the question. This is used in
        official evaluation scripts.
    token_span_lists : ``List[List[Tuple[int, int]]]``, optional
        Indices into ``passage_tokens`` to use as the answer to the question for training. This is
        a list of list, first because there is multiple questions per dialog, and
        because there might be several possible correct answer spans in the passage.
        Currently, we just select the last span in this list (i.e., QuAC has multiple
        annotations on the dev set; this will select the last span, which was given by the
        original annotator).
    yesno_list : ``List[int]``
        List of the affirmation bit for each question answer pairs.
    followup_list : ``List[int]``
        List of the continuation bit for each question answer pairs.
    num_context_answers : ``int``, optional
        How many answers to encode into the passage.
    additional_metadata : ``Dict[str, Any]``, optional
        The constructed ``metadata`` field will by default contain ``original_passage``,
        ``token_offsets``, ``question_tokens``, ``passage_tokens``, and ``answer_texts`` keys. If
        you want any other metadata to be associated with each instance, you can pass that in
        here. This dictionary will get added to the ``metadata`` dictionary we already construct.
    """
    additional_metadata = additional_metadata or {}
    fields: Dict[str, Field] = {}
    # Character offsets of each passage token, used downstream to map predicted
    # token spans back onto ``passage_text``.
    passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
    # This is separate so we can reference it later with a known type.
    passage_field = TextField(passage_tokens, token_indexers)
    fields['passage'] = passage_field
    fields['question'] = ListField([TextField(q_tokens, token_indexers)
                                    for q_tokens in question_list_tokens])
    metadata = {'original_passage': passage_text,
                'token_offsets': passage_offsets,
                'question_tokens': [[token.text for token in question_tokens]
                                    for question_tokens in question_list_tokens],
                'passage_tokens': [token.text for token in passage_tokens], }
    p1_answer_marker_list: List[Field] = []
    p2_answer_marker_list: List[Field] = []
    p3_answer_marker_list: List[Field] = []

    def get_tag(i, i_name):
        # Generate a tag to mark previous answer span in the passage.
        return "<{0:d}_{1:s}>".format(i, i_name)

    def mark_tag(span_start, span_end, passage_tags, prev_answer_distance):
        # A span of (-1, -1) is the "not yet answered" sentinel, so a negative
        # index here means the caller forgot to record the previous answer.
        # Note: index 0 is a legal span boundary (answer at the very start of
        # the passage), so we only reject strictly-negative values.  Using an
        # explicit raise instead of ``assert`` keeps the check active under
        # ``python -O`` and avoids a bare ``except:`` swallowing other errors.
        if span_start < 0 or span_end < 0:
            raise ValueError("Previous {0:d}th answer span should have been updated!"
                             .format(prev_answer_distance))
        # Modify "tags" to mark previous answer span.
        if span_start == span_end:
            passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "")
        else:
            passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "start")
            passage_tags[prev_answer_distance][span_end] = get_tag(prev_answer_distance, "end")
            for passage_index in range(span_start + 1, span_end):
                passage_tags[prev_answer_distance][passage_index] = get_tag(prev_answer_distance,
                                                                            "in")

    if token_span_lists:
        span_start_list: List[Field] = []
        span_end_list: List[Field] = []
        # (-1, -1) sentinels meaning "no previous answer recorded yet" for the
        # 1st, 2nd and 3rd most recent answers in the dialog.
        p1_span_start, p1_span_end, p2_span_start = -1, -1, -1
        p2_span_end, p3_span_start, p3_span_end = -1, -1, -1
        # Looping each <<answers>>.
        for question_index, answer_span_lists in enumerate(token_span_lists):
            span_start, span_end = answer_span_lists[-1]  # Last one is the original answer
            span_start_list.append(IndexField(span_start, passage_field))
            span_end_list.append(IndexField(span_end, passage_field))
            # One BIO-style tag sequence per answer distance; index 0 is unused
            # (distances are 1-based) and kept so indices line up.
            prev_answer_marker_lists = [["O"] * len(passage_tokens), ["O"] * len(passage_tokens),
                                        ["O"] * len(passage_tokens), ["O"] * len(passage_tokens)]
            if question_index > 0 and num_context_answers > 0:
                mark_tag(p1_span_start, p1_span_end, prev_answer_marker_lists, 1)
                if question_index > 1 and num_context_answers > 1:
                    mark_tag(p2_span_start, p2_span_end, prev_answer_marker_lists, 2)
                    if question_index > 2 and num_context_answers > 2:
                        mark_tag(p3_span_start, p3_span_end, prev_answer_marker_lists, 3)
            # Shift the history window: the current answer becomes the most
            # recent previous answer for the next question.
            p3_span_start = p2_span_start
            p3_span_end = p2_span_end
            p2_span_start = p1_span_start
            p2_span_end = p1_span_end
            p1_span_start = span_start
            p1_span_end = span_end
            if num_context_answers > 2:
                p3_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[3],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
            if num_context_answers > 1:
                p2_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[2],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
            if num_context_answers > 0:
                p1_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[1],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
        fields['span_start'] = ListField(span_start_list)
        fields['span_end'] = ListField(span_end_list)
        if num_context_answers > 0:
            fields['p1_answer_marker'] = ListField(p1_answer_marker_list)
            if num_context_answers > 1:
                fields['p2_answer_marker'] = ListField(p2_answer_marker_list)
                if num_context_answers > 2:
                    fields['p3_answer_marker'] = ListField(p3_answer_marker_list)
        fields['yesno_list'] = ListField(
            [LabelField(yesno, label_namespace="yesno_labels") for yesno in yesno_list])
        fields['followup_list'] = ListField([LabelField(followup,
                                                        label_namespace="followup_labels")
                                             for followup in followup_list])
    metadata.update(additional_metadata)
    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)
352
+
353
+
354
def handle_cannot(reference_answers: List[str]):
    """
    Process a list of reference answers.

    If equal or more than half of the reference answers are "CANNOTANSWER",
    take it as gold. Otherwise, return answers that are not "CANNOTANSWER".
    """
    num_cannot = reference_answers.count('CANNOTANSWER')
    num_spans = len(reference_answers) - num_cannot
    if num_cannot >= num_spans:
        return ['CANNOTANSWER']
    return [answer for answer in reference_answers if answer != 'CANNOTANSWER']
0 commit comments