
Commit c728951

Integrate new table context in variable free world (#1832)

* All of Shikhar's commits, squashed into one
* added column types to type declaration
* made minimal changes to table question context to pass lint, mypy and doc build
* misc changes missing from previous commits
* undo changes made to table question knowledge graph
* fixed misc issues from rebase
* misc fixes to tests for new context
* addressed PR comments

1 parent 3de6943, commit c728951
20 files changed: +1621, -173 lines
allennlp/semparse/contexts/__init__.py

+1

@@ -1,2 +1,3 @@
 from allennlp.semparse.contexts.table_question_knowledge_graph import TableQuestionKnowledgeGraph
 from allennlp.semparse.contexts.atis_sql_table_context import AtisSqlTableContext
+from allennlp.semparse.contexts.table_question_context import TableQuestionContext
allennlp/semparse/contexts/table_question_context.py (new file)

+307

@@ -0,0 +1,307 @@
1+
import re
2+
import csv
3+
from typing import Dict, List, Set, Tuple, Union
4+
5+
from unidecode import unidecode
6+
from allennlp.data.tokenizers import Token
7+
8+
# == stop words that will be omitted by ContextGenerator
9+
STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before",
10+
"hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do",
11+
"them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him",
12+
"nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are",
13+
"further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?",
14+
"be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn",
15+
"of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your",
16+
"from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves",
17+
":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he",
18+
"me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then",
19+
"is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!",
20+
"again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan",
21+
"'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so",
22+
"the", "having", "once"}
23+
24+
NUMBER_CHARACTERS = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-'}
25+
MONTH_NUMBERS = {
26+
'january': 1,
27+
'jan': 1,
28+
'february': 2,
29+
'feb': 2,
30+
'march': 3,
31+
'mar': 3,
32+
'april': 4,
33+
'apr': 4,
34+
'may': 5,
35+
'june': 6,
36+
'jun': 6,
37+
'july': 7,
38+
'jul': 7,
39+
'august': 8,
40+
'aug': 8,
41+
'september': 9,
42+
'sep': 9,
43+
'october': 10,
44+
'oct': 10,
45+
'november': 11,
46+
'nov': 11,
47+
'december': 12,
48+
'dec': 12,
49+
}
50+
ORDER_OF_MAGNITUDE_WORDS = {'hundred': 100, 'thousand': 1000, 'million': 1000000}
51+
NUMBER_WORDS = {
52+
'zero': 0,
53+
'one': 1,
54+
'two': 2,
55+
'three': 3,
56+
'four': 4,
57+
'five': 5,
58+
'six': 6,
59+
'seven': 7,
60+
'eight': 8,
61+
'nine': 9,
62+
'ten': 10,
63+
'first': 1,
64+
'second': 2,
65+
'third': 3,
66+
'fourth': 4,
67+
'fifth': 5,
68+
'sixth': 6,
69+
'seventh': 7,
70+
'eighth': 8,
71+
'ninth': 9,
72+
'tenth': 10,
73+
**MONTH_NUMBERS,
74+
}
75+
76+
77+
class TableQuestionContext:
    """
    A barebones implementation, similar to
    https://github.com/crazydonkey200/neural-symbolic-machines/blob/master/table/wtq/preprocess.py,
    for extracting entities from a question given a table, and for typing the table's columns
    as <string>, <date>, or <number>.
    """
    def __init__(self,
                 cell_values: Set[str],
                 column_type_statistics: List[Dict[str, int]],
                 column_index_to_name: Dict[int, str],
                 question_tokens: List[Token]) -> None:
        self.cell_values = cell_values
        # Type each column as the most frequently observed cell type in that column.
        self.column_types = {column_index_to_name[column_index]: max(column_type_statistics[column_index],
                                                                     key=column_type_statistics[column_index].get)
                             for column_index in column_index_to_name}
        self.question_tokens = question_tokens

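    # For example, a column whose cells were tallied as {'string': 1, 'number': 5, 'date': 0}
    # gets typed as 'number', since max() with key=dict.get picks the most common type.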
    # Cells with more tokens than this are unlikely to be numbers.
    MAX_TOKENS_FOR_NUM_CELL = 1

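    # Note on the input consumed by ``read_from_lines`` below: the first line is a header;
    # each line whose first field is '-1' declares a column, with its third field holding the
    # 'fb:row.row.<name>' identifier; every remaining line describes one cell, and its
    # 'content', 'tokens', 'date' and 'number' fields (keyed by the header) are used to tally
    # per-column type statistics.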
    @classmethod
    def read_from_lines(cls,
                        lines: List[List[str]],
                        question_tokens: List[Token]) -> 'TableQuestionContext':
        column_index_to_name = {}

        header = lines[0]  # the first line is the header
        index = 1
        while lines[index][0] == '-1':
            # column names start with fb:row.row.
            current_line = lines[index]
            column_name_sempre = current_line[2]
            column_index = int(current_line[1])
            column_name = column_name_sempre.replace('fb:row.row.', '')
            column_index_to_name[column_index] = column_name
            index += 1
        column_node_type_info = [{'string' : 0, 'number' : 0, 'date' : 0}
                                 for col in column_index_to_name]
        cell_values = set()
        while index < len(lines):
            curr_line = lines[index]
            column_index = int(curr_line[1])
            node_info = dict(zip(header, curr_line))
            cell_values.add(cls._normalize_string(node_info['content']))
            num_tokens = len(node_info['tokens'].split('|'))
            if node_info['date']:
                column_node_type_info[column_index]['date'] += 1
            # If the cell contains too many tokens, it is unlikely to be a number.
            elif node_info['number'] and num_tokens <= cls.MAX_TOKENS_FOR_NUM_CELL:
                column_node_type_info[column_index]['number'] += 1
            elif node_info['content'] != '—':
                column_node_type_info[column_index]['string'] += 1
            index += 1
        return cls(cell_values, column_node_type_info, column_index_to_name, question_tokens)

    @classmethod
    def read_from_file(cls, filename: str, question_tokens: List[Token]) -> 'TableQuestionContext':
        with open(filename, 'r') as file_pointer:
            reader = csv.reader(file_pointer, delimiter='\t', quoting=csv.QUOTE_NONE)
            # obtain column information
            lines = [line for line in reader]
            return cls.read_from_lines(lines, question_tokens)

    def get_entities_from_question(self):
        entity_data = []
        for i, token in enumerate(self.question_tokens):
            token_text = token.text
            if token_text in STOP_WORDS:
                continue
            normalized_token_text = self._normalize_string(token_text)
            if not normalized_token_text:
                continue
            if self._string_in_table(normalized_token_text):
                curr_data = {'value' : normalized_token_text, 'token_start' : i, 'token_end' : i + 1}
                entity_data.append(curr_data)

        extracted_numbers = self._get_numbers_from_tokens(self.question_tokens)
        # Filter out string entities that are also numbers, to avoid repetition.
        if extracted_numbers:
            _, number_token_indices = list(zip(*extracted_numbers))
            number_token_text = [self.question_tokens[i].text for i in number_token_indices]
            expanded_string_entities = []
            for ent in entity_data:
                if ent['value'] not in number_token_text:
                    expanded_string_entities.append(ent)
            expanded_entities = [ent['value'] for ent in
                                 self._expand_entities(self.question_tokens, expanded_string_entities)]
        else:
            expanded_entities = [ent['value'] for ent in
                                 self._expand_entities(self.question_tokens, entity_data)]
        return expanded_entities, extracted_numbers  # TODO(shikhar): Handle conjunctions.


    @staticmethod
    def _get_numbers_from_tokens(tokens: List[Token]) -> List[Tuple[str, int]]:
        """
        Finds numbers in the input tokens and returns them as strings.  We do some simple
        heuristic number recognition, finding ordinals and cardinals expressed as text ("one",
        "first", etc.), as well as numerals ("7th", "3rd"), months (mapping "july" to 7), and
        units ("1ghz").

        We also handle year ranges expressed as decades or centuries ("1800s" or "1950s"),
        adding the endpoints of the range as possible numbers to generate.

        We return a list of tuples, where each tuple is the (number_string, token_index) for a
        number found in the input tokens.
        """
        numbers = []
        for i, token in enumerate(tokens):
            number: Union[int, float] = None
            token_text = token.text
            text = token.text.replace(',', '').lower()
            if text in NUMBER_WORDS:
                number = NUMBER_WORDS[text]

            magnitude = 1
            if i < len(tokens) - 1:
                next_token = tokens[i + 1].text.lower()
                if next_token in ORDER_OF_MAGNITUDE_WORDS:
                    magnitude = ORDER_OF_MAGNITUDE_WORDS[next_token]
                    token_text += ' ' + tokens[i + 1].text

            is_range = False
            if len(text) > 1 and text[-1] == 's' and text[-2] == '0':
                is_range = True
                text = text[:-1]

            # We strip out any non-digit characters, to capture things like '7th', or '1ghz'.  The
            # way we're doing this could lead to false positives for something like '1e2', but
            # we'll take that risk.  It shouldn't be a big deal.
            text = ''.join(text[i] for i, char in enumerate(text) if char in NUMBER_CHARACTERS)

            try:
                # We'll use a check for float(text) to find numbers, because text.isdigit() doesn't
                # catch things like "-3" or "0.07".
                number = float(text)
            except ValueError:
                pass

            if number is not None:
                number = number * magnitude
                if '.' in text:
                    number_string = '%.3f' % number
                else:
                    number_string = '%d' % number
                numbers.append((number_string, i))
                if is_range:
                    # TODO(mattg): both numbers in the range will have the same text, and so the
                    # linking score won't have any way to differentiate them...  We should figure
                    # out a better way to handle this.
                    num_zeros = 1
                    while text[-(num_zeros + 1)] == '0':
                        num_zeros += 1
                    numbers.append((str(int(number + 10 ** num_zeros)), i))
        return numbers

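    # As an illustration of the heuristics above: for the tokens of "two thousand students
    # in the 1950s", _get_numbers_from_tokens returns [('2000', 0), ('1950', 5), ('1960', 5)];
    # "two" is scaled by the order-of-magnitude word following it, and the decade "1950s"
    # contributes both endpoints of its range at the same token index.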
    def _string_in_table(self, candidate: str) -> bool:
        for cell_value in self.cell_values:
            if candidate in cell_value:
                return True
        return False

    def _process_conjunction(self, entity_data):
        raise NotImplementedError

    def _expand_entities(self, question, entity_data):
        new_entities = []
        for entity in entity_data:
            # to ensure the same strings are not used over and over
            if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:
                continue
            current_start = entity['token_start']
            current_end = entity['token_end']
            current_token = entity['value']

            while current_end < len(question):
                next_token = question[current_end].text
                next_token_normalized = self._normalize_string(next_token)
                if next_token_normalized == "":
                    current_end += 1
                    continue
                candidate = "%s_%s" % (current_token, next_token_normalized)
                if self._string_in_table(candidate):
                    current_end += 1
                    current_token = candidate
                else:
                    break

            new_entities.append({'token_start' : current_start,
                                 'token_end' : current_end,
                                 'value' : current_token})
        return new_entities

    @staticmethod
    def _normalize_string(string: str) -> str:
        """
        These are the transformation rules used to normalize cell and column names in Sempre.  See
        ``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and
        ``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``.  We reproduce those
        rules here to normalize and canonicalize cells and columns in the same way so that we can
        match them against constants in logical forms appropriately.
        """
        # Normalization rules from Sempre
        # \u201A -> ,
        string = re.sub("‚", ",", string)
        string = re.sub("„", ",,", string)
        string = re.sub("[·・]", ".", string)
        string = re.sub("…", "...", string)
        string = re.sub("ˆ", "^", string)
        string = re.sub("˜", "~", string)
        string = re.sub("‹", "<", string)
        string = re.sub("›", ">", string)
        string = re.sub("[‘’´`]", "'", string)
        string = re.sub("[“”«»]", "\"", string)
        string = re.sub("[•†‡²³]", "", string)
        string = re.sub("[‐‑–—−]", "-", string)
        # Oddly, some unicode characters get converted to _ instead of being stripped.  Not really
        # sure how sempre decides what to do with these...  TODO(mattg): can we just get rid of the
        # need for this function somehow?  It's causing a whole lot of headaches.
        string = re.sub("[ðø′″€⁄ªΣ]", "_", string)
        # This is such a mess.  There isn't just a block of unicode that we can strip out, because
        # sometimes sempre just strips diacritics...  We'll try stripping out a few separate
        # blocks, skipping the ones that sempre skips...
        string = re.sub("[\\u0180-\\u0210]", "", string).strip()
        string = re.sub("[\\u0220-\\uFFFF]", "", string).strip()
        string = string.replace("\\n", "_")
        string = re.sub("\\s+", " ", string)
        # Canonicalization rules from Sempre
        string = re.sub("[^\\w]", "_", string)
        string = re.sub("_+", "_", string)
        string = re.sub("_$", "", string)
        return unidecode(string.lower())
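
A minimal usage sketch for the new context (illustrative only, not part of the commit; the
header field names and the tiny two-column table below are invented, chosen to match how
``read_from_lines`` indexes its input):

    from allennlp.data.tokenizers import Token
    from allennlp.semparse.contexts import TableQuestionContext

    lines = [['row', 'col', 'id', 'content', 'tokens', 'number', 'date'],
             ['-1', '0', 'fb:row.row.name', 'name', 'name', '', ''],
             ['-1', '1', 'fb:row.row.year', 'year', 'year', '', ''],
             ['0', '0', 'fb:cell.turing', 'Turing', 'turing', '', ''],
             ['0', '1', 'fb:cell.1950', '1950', '1950', '1950.0', '']]
    question_tokens = [Token(text) for text in ['when', 'was', 'turing', 'born', '?']]
    context = TableQuestionContext.read_from_lines(lines, question_tokens)

    print(context.column_types)  # {'name': 'string', 'year': 'number'}
    entities, numbers = context.get_entities_from_question()
    print(entities)  # ['turing'] -- matches a normalized cell value
    print(numbers)   # [] -- the question contains no numbers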

allennlp/semparse/contexts/table_question_knowledge_graph.py

+1
@@ -61,6 +61,7 @@
         **MONTH_NUMBERS,
         }
 
+
 class TableQuestionKnowledgeGraph(KnowledgeGraph):
     """
     A ``TableQuestionKnowledgeGraph`` represents the linkable entities in a table and a question

allennlp/semparse/type_declarations/type_declaration.py

+25 -1
@@ -109,7 +109,7 @@ class NamedBasicType(BasicType):
 
     Parameters
     ----------
-    string_rep : str
+    string_rep : ``str``
         String representation of the type.
     """
     def __init__(self, string_rep) -> None:

@@ -127,6 +127,30 @@ def str(self):
         return self._string_rep
 
 
+class MultiMatchNamedBasicType(NamedBasicType):
+    """
+    A ``NamedBasicType`` that matches with any type within a list of ``BasicTypes`` that it takes
+    as an additional argument during instantiation.  We just override the ``matches`` method in
+    ``BasicType`` to match against any of the types given by the list.
+
+    Parameters
+    ----------
+    string_rep : ``str``
+        String representation of the type, passed to the super class.
+    types_to_match : ``List[BasicType]``
+        List of types that this type should match with.
+    """
+    def __init__(self,
+                 string_rep,
+                 types_to_match: List[BasicType]) -> None:
+        super().__init__(string_rep)
+        self._types_to_match = set(types_to_match)
+
+    @overrides
+    def matches(self, other):
+        return super().matches(other) or other in self._types_to_match
+
+
 class PlaceholderType(ComplexType):
     """
     ``PlaceholderType`` is a ``ComplexType`` that involves placeholders, and thus its type
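
A brief sketch of how the new multi-match type can be used (illustrative; these particular
type names are invented, not from the commit):

    from allennlp.semparse.type_declarations.type_declaration import (
            NamedBasicType, MultiMatchNamedBasicType)

    NUMBER_TYPE = NamedBasicType("NUMBER")
    DATE_TYPE = NamedBasicType("DATE")
    # A type that should unify with either of the two basic types above.
    QUANTITY_TYPE = MultiMatchNamedBasicType("QUANTITY", [NUMBER_TYPE, DATE_TYPE])

    assert QUANTITY_TYPE.matches(NUMBER_TYPE)
    assert QUANTITY_TYPE.matches(DATE_TYPE)
    assert not QUANTITY_TYPE.matches(NamedBasicType("STRING"))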
