import re
import csv
from typing import Dict, List, Optional, Set, Tuple, Union

from unidecode import unidecode
from allennlp.data.tokenizers import Token

# == stop words that will be omitted by ContextGenerator
STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before",
              "hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do",
              "them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him",
              "nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are",
              "further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?",
              "be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn",
              "of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your",
              "from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves",
              ":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he",
              "me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then",
              "is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!",
              "again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan",
              "'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so",
              "the", "having", "once"}

NUMBER_CHARACTERS = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-'}
MONTH_NUMBERS = {
        'january': 1,
        'jan': 1,
        'february': 2,
        'feb': 2,
        'march': 3,
        'mar': 3,
        'april': 4,
        'apr': 4,
        'may': 5,
        'june': 6,
        'jun': 6,
        'july': 7,
        'jul': 7,
        'august': 8,
        'aug': 8,
        'september': 9,
        'sep': 9,
        'october': 10,
        'oct': 10,
        'november': 11,
        'nov': 11,
        'december': 12,
        'dec': 12,
        }
ORDER_OF_MAGNITUDE_WORDS = {'hundred': 100, 'thousand': 1000, 'million': 1000000}
NUMBER_WORDS = {
        'zero': 0,
        'one': 1,
        'two': 2,
        'three': 3,
        'four': 4,
        'five': 5,
        'six': 6,
        'seven': 7,
        'eight': 8,
        'nine': 9,
        'ten': 10,
        'first': 1,
        'second': 2,
        'third': 3,
        'fourth': 4,
        'fifth': 5,
        'sixth': 6,
        'seventh': 7,
        'eighth': 8,
        'ninth': 9,
        'tenth': 10,
        **MONTH_NUMBERS,
        }


class TableQuestionContext:
    """
    A barebones implementation similar to
    https://github.com/crazydonkey200/neural-symbolic-machines/blob/master/table/wtq/preprocess.py
    for extracting entities from a question given a table, and typing its columns as
    ``<string>``, ``<date>`` or ``<number>``.
    """
    def __init__(self,
                 cell_values: Set[str],
                 column_type_statistics: List[Dict[str, int]],
                 column_index_to_name: Dict[int, str],
                 question_tokens: List[Token]) -> None:
        self.cell_values = cell_values
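        # Type each column as whichever of 'string', 'number', or 'date' was counted
        # most often among its cells in column_type_statistics.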
        self.column_types = {column_index_to_name[column_index]:
                             max(column_type_statistics[column_index],
                                 key=column_type_statistics[column_index].get)
                             for column_index in column_index_to_name}
        self.question_tokens = question_tokens

    MAX_TOKENS_FOR_NUM_CELL = 1

    @classmethod
    def read_from_lines(cls,
                        lines: List[List[str]],
                        question_tokens: List[Token]) -> 'TableQuestionContext':
        column_index_to_name = {}

        header = lines[0]  # the first line is the header
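        # Lines whose row index is -1 describe the columns themselves; read those
        # first to recover the column names.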
        index = 1
        while lines[index][0] == '-1':
            # column names start with fb:row.row.
            current_line = lines[index]
            column_name_sempre = current_line[2]
            column_index = int(current_line[1])
            column_name = column_name_sempre.replace('fb:row.row.', '')
            column_index_to_name[column_index] = column_name
            index += 1
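        # For every column, tally how many of its cells look like dates, numbers, or
        # plain strings; the majority vote decides the column's type in __init__.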
        column_node_type_info = [{'string' : 0, 'number' : 0, 'date' : 0}
                                 for col in column_index_to_name]
        cell_values = set()
        while index < len(lines):
            curr_line = lines[index]
            column_index = int(curr_line[1])
            node_info = dict(zip(header, curr_line))
            cell_values.add(cls._normalize_string(node_info['content']))
            num_tokens = len(node_info['tokens'].split('|'))
            if node_info['date']:
                column_node_type_info[column_index]['date'] += 1
            # If cell contains too many tokens, then likely not number
            elif node_info['number'] and num_tokens <= cls.MAX_TOKENS_FOR_NUM_CELL:
                column_node_type_info[column_index]['number'] += 1
            elif node_info['content'] != '—':
                column_node_type_info[column_index]['string'] += 1
            index += 1
        return cls(cell_values, column_node_type_info, column_index_to_name, question_tokens)

    @classmethod
    def read_from_file(cls, filename: str, question_tokens: List[Token]) -> 'TableQuestionContext':
        with open(filename, 'r') as file_pointer:
            reader = csv.reader(file_pointer, delimiter='\t', quoting=csv.QUOTE_NONE)
            # obtain column information
            lines = [line for line in reader]
            return cls.read_from_lines(lines, question_tokens)
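
    # Rough usage sketch (the tagged-table filename and question below are made up,
    # not taken from this change):
    #
    #     tokens = [Token(word) for word in "how many players scored".split()]
    #     context = TableQuestionContext.read_from_file("example.tagged", tokens)
    #     entities, numbers = context.get_entities_from_question()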

    def get_entities_from_question(self):
        entity_data = []
        for i, token in enumerate(self.question_tokens):
            token_text = token.text
            if token_text in STOP_WORDS:
                continue
            normalized_token_text = self._normalize_string(token_text)
            if not normalized_token_text:
                continue
            if self._string_in_table(normalized_token_text):
                curr_data = {'value' : normalized_token_text, 'token_start' : i, 'token_end' : i+1}
                entity_data.append(curr_data)

        extracted_numbers = self._get_numbers_from_tokens(self.question_tokens)
        # filter out number entities to avoid repetition
        if extracted_numbers:
            _, number_token_indices = list(zip(*extracted_numbers))
            number_token_text = [self.question_tokens[i].text for i in number_token_indices]
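            # Drop string entities whose text was already extracted as a number, so the
            # same token is not linked twice.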
            expanded_string_entities = []
            for ent in entity_data:
                if ent['value'] not in number_token_text:
                    expanded_string_entities.append(ent)
            expanded_entities = [ent['value'] for ent in
                                 self._expand_entities(self.question_tokens, expanded_string_entities)]
        else:
            expanded_entities = [ent['value'] for ent in
                                 self._expand_entities(self.question_tokens, entity_data)]
        return expanded_entities, extracted_numbers  # TODO(shikhar) Handle conjunctions

    @staticmethod
    def _get_numbers_from_tokens(tokens: List[Token]) -> List[Tuple[str, int]]:
        """
        Finds numbers in the input tokens and returns them as strings. We do some simple heuristic
        number recognition, finding ordinals and cardinals expressed as text ("one", "first",
        etc.), as well as numerals ("7th", "3rd"), months (mapping "july" to 7), and units
        ("1ghz").

        We also handle year ranges expressed as decades or centuries ("1800s" or "1950s"), adding
        the endpoints of the range as possible numbers to generate.

        We return a list of tuples, where each tuple is the (number_string, token_index) for a
        number found in the input tokens.
        """
        numbers = []
        for i, token in enumerate(tokens):
            number: Optional[Union[int, float]] = None
            token_text = token.text
            text = token.text.replace(',', '').lower()
            if text in NUMBER_WORDS:
                number = NUMBER_WORDS[text]

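            # A following magnitude word ('hundred', 'thousand', 'million') scales the
            # number, e.g. "two million" becomes 2 * 1000000.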
            magnitude = 1
            if i < len(tokens) - 1:
                next_token = tokens[i + 1].text.lower()
                if next_token in ORDER_OF_MAGNITUDE_WORDS:
                    magnitude = ORDER_OF_MAGNITUDE_WORDS[next_token]
                    token_text += ' ' + tokens[i + 1].text

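            # Strings like "1950s" or "1800s" denote a range of years; remember that so
            # the end of the range can also be added below.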
            is_range = False
            if len(text) > 1 and text[-1] == 's' and text[-2] == '0':
                is_range = True
                text = text[:-1]

            # We strip out any non-digit characters, to capture things like '7th', or '1ghz'. The
            # way we're doing this could lead to false positives for something like '1e2', but
            # we'll take that risk. It shouldn't be a big deal.
            text = ''.join(text[i] for i, char in enumerate(text) if char in NUMBER_CHARACTERS)

            try:
                # We'll use a check for float(text) to find numbers, because text.isdigit() doesn't
                # catch things like "-3" or "0.07".
                number = float(text)
            except ValueError:
                pass

            if number is not None:
                number = number * magnitude
                if '.' in text:
                    number_string = '%.3f' % number
                else:
                    number_string = '%d' % number
                numbers.append((number_string, i))
                if is_range:
                    # TODO(mattg): both numbers in the range will have the same text, and so the
                    # linking score won't have any way to differentiate them... We should figure
                    # out a better way to handle this.
                    num_zeros = 1
                    while text[-(num_zeros + 1)] == '0':
                        num_zeros += 1
                    numbers.append((str(int(number + 10 ** num_zeros)), i))
        return numbers

    def _string_in_table(self, candidate: str) -> bool:
        for cell_value in self.cell_values:
            if candidate in cell_value:
                return True
        return False

    def _process_conjunction(self, entity_data):
        raise NotImplementedError

    def _expand_entities(self, question, entity_data):
        new_entities = []
        for entity in entity_data:
            # to ensure the same strings are not used over and over
            if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:
                continue
            current_start = entity['token_start']
            current_end = entity['token_end']
            current_token = entity['value']

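            # Greedily grow the entity to the right: keep appending the next question
            # token (underscore-joined) as long as the longer string still matches a cell.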
            while current_end < len(question):
                next_token = question[current_end].text
                next_token_normalized = self._normalize_string(next_token)
                if next_token_normalized == "":
                    current_end += 1
                    continue
                candidate = "%s_%s" % (current_token, next_token_normalized)
                if self._string_in_table(candidate):
                    current_end += 1
                    current_token = candidate
                else:
                    break

            new_entities.append({'token_start' : current_start,
                                 'token_end' : current_end,
                                 'value' : current_token})
        return new_entities

    @staticmethod
    def _normalize_string(string: str) -> str:
        """
        These are the transformation rules used to normalize cells and column names in Sempre. See
        ``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and
        ``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``. We reproduce those
        rules here to normalize and canonicalize cells and columns in the same way so that we can
        match them against constants in logical forms appropriately.
        """
        # Normalization rules from Sempre
        # \u201A -> ,
        string = re.sub("‚", ",", string)
        string = re.sub("„", ",,", string)
        string = re.sub("[·・]", ".", string)
        string = re.sub("…", "...", string)
        string = re.sub("ˆ", "^", string)
        string = re.sub("˜", "~", string)
        string = re.sub("‹", "<", string)
        string = re.sub("›", ">", string)
        string = re.sub("[‘’´`]", "'", string)
        string = re.sub("[“”«»]", "\"", string)
        string = re.sub("[•†‡²³]", "", string)
        string = re.sub("[‐‑–—−]", "-", string)
        # Oddly, some unicode characters get converted to _ instead of being stripped. Not really
        # sure how sempre decides what to do with these... TODO(mattg): can we just get rid of the
        # need for this function somehow? It's causing a whole lot of headaches.
        string = re.sub("[ðø′″€⁄ªΣ]", "_", string)
        # This is such a mess. There isn't just a block of unicode that we can strip out, because
        # sometimes sempre just strips diacritics... We'll try stripping out a few separate
        # blocks, skipping the ones that sempre skips...
        string = re.sub("[\\u0180-\\u0210]", "", string).strip()
        string = re.sub("[\\u0220-\\uFFFF]", "", string).strip()
        string = string.replace("\\n", "_")
        string = re.sub("\\s+", " ", string)
        # Canonicalization rules from Sempre
        string = re.sub("[^\\w]", "_", string)
        string = re.sub("_+", "_", string)
        string = re.sub("_$", "", string)
        return unidecode(string.lower())