|
1 | 1 | from typing import List, Dict, Tuple, Set, Callable
|
2 |
| -from copy import deepcopy |
| 2 | +from copy import copy |
3 | 3 | import numpy
|
4 | 4 | from nltk import ngrams
|
5 | 5 |
|
6 | 6 | from parsimonious.grammar import Grammar
|
| 7 | +from parsimonious.expressions import Expression, OneOf, Sequence, Literal |
7 | 8 |
|
8 | 9 | from allennlp.semparse.contexts.atis_tables import * # pylint: disable=wildcard-import,unused-wildcard-import
|
9 | 10 | from allennlp.semparse.contexts.atis_sql_table_context import AtisSqlTableContext, KEYWORDS
|
10 |
| -from allennlp.semparse.contexts.sql_context_utils import SqlVisitor, format_action |
| 11 | +from allennlp.semparse.contexts.sql_context_utils import SqlVisitor, format_action, initialize_valid_actions |
11 | 12 |
|
12 | 13 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
|
13 | 14 |
|
@@ -67,14 +68,125 @@ def __init__(self,
|
67 | 68 | self.linked_entities = self._get_linked_entities()
|
68 | 69 | self.dates = self._get_dates()
|
69 | 70 |
|
70 |
| - self.valid_actions: Dict[str, List[str]] = self._update_valid_actions() |
71 | 71 | entities, linking_scores = self._flatten_entities()
|
72 | 72 | # This has shape (num_entities, num_utterance_tokens).
|
73 | 73 | self.linking_scores: numpy.ndarray = linking_scores
|
74 | 74 | self.entities: List[str] = entities
|
75 |
| - self.grammar_dictionary = self.update_grammar_dictionary() |
76 |
| - self.grammar_string: str = self.get_grammar_string() |
77 |
| - self.grammar_with_context: Grammar = Grammar(self.grammar_string) |
| 75 | + self.grammar: Grammar = self._update_grammar() |
| 76 | + self.valid_actions = initialize_valid_actions(self.grammar, |
| 77 | + KEYWORDS) |
| 78 | + |
def _update_grammar(self):
    """
    We create a new ``Grammar`` object from the one in ``AtisSqlTableContext``, that also
    has the new entities that are extracted from the utterance. Stitching together the expressions
    to form the grammar is a little tedious here, but it is worth it because we don't have to create
    a new grammar from scratch. Creating a new grammar is expensive because we have many production
    rules that have all database values in the column on the right hand side. We update the expressions
    bottom up, since the higher level expressions may refer to the lower level ones. For example, the
    ternary expression will refer to the start and end times.
    """
    # This will give us a shallow copy, but that's OK because everything
    # inside is treated as immutable here: we replace entries wholesale
    # rather than mutating the shared sub-expressions.
    new_grammar = copy(AtisWorld.sql_table_context.grammar)

    # Replace the ``number`` rule with the numbers linked in this utterance.
    numbers = self._get_numeric_database_values('number')
    number_literals = [Literal(number) for number in numbers]
    new_grammar['number'] = OneOf(*number_literals, name='number')
    self._update_expression_reference(new_grammar, 'pos_value', 'number')

    time_range_start = self._get_numeric_database_values('time_range_start')
    time_range_start_literals = [Literal(time) for time in time_range_start]
    new_grammar['time_range_start'] = OneOf(*time_range_start_literals, name='time_range_start')

    time_range_end = self._get_numeric_database_values('time_range_end')
    time_range_end_literals = [Literal(time) for time in time_range_end]
    new_grammar['time_range_end'] = OneOf(*time_range_end_literals, name='time_range_end')

    # ``col_ref BETWEEN start AND end`` plus the two negated spellings
    # ('NOT' and 'not').  Order matters for PEG ordered choice, so we keep
    # the original order: plain, 'NOT', 'not'.
    ternary_expressions = []
    for negation in ([], [Literal('NOT')], [Literal('not')]):
        members = ([new_grammar['col_ref']] +
                   negation +
                   [Literal('BETWEEN'),
                    new_grammar['time_range_start'],
                    Literal('AND'),
                    new_grammar['time_range_end']])
        ternary_expressions.append(self._get_sequence_with_spacing(new_grammar, members))

    new_grammar['ternaryexpr'] = OneOf(*ternary_expressions, name='ternaryexpr')
    self._update_expression_reference(new_grammar, 'condition', 'ternaryexpr')

    if self.dates:
        new_binary_expressions = []
        # NOTE(review): only the first date's year gets a rule — this assumes
        # all dates extracted from one interaction share a year; confirm.
        year_binary_expression = self._get_sequence_with_spacing(new_grammar,
                                                                 [Literal('date_day'),
                                                                  Literal('.'),
                                                                  Literal('year'),
                                                                  new_grammar['binaryop'],
                                                                  Literal(f'{self.dates[0].year}')])
        new_binary_expressions.append(year_binary_expression)
        for date in self.dates:
            month_binary_expression = self._get_sequence_with_spacing(new_grammar,
                                                                      [Literal('date_day'),
                                                                       Literal('.'),
                                                                       Literal('month_number'),
                                                                       new_grammar['binaryop'],
                                                                       Literal(f'{date.month}')])

            day_binary_expression = self._get_sequence_with_spacing(new_grammar,
                                                                    [Literal('date_day'),
                                                                     Literal('.'),
                                                                     Literal('day_number'),
                                                                     new_grammar['binaryop'],
                                                                     Literal(f'{date.day}')])
            new_binary_expressions.extend([month_binary_expression,
                                           day_binary_expression])

        new_grammar['biexpr'].members = new_grammar['biexpr'].members + tuple(new_binary_expressions)
    return new_grammar
| 159 | + |
| 160 | + def _get_numeric_database_values(self, |
| 161 | + nonterminal: str) -> List[str]: |
| 162 | + return sorted([value[1] for key, value in self.linked_entities['number'].items() |
| 163 | + if value[0] == nonterminal], reverse=True) |
| 164 | + |
def _update_expression_reference(self, # pylint: disable=no-self-use
                                 grammar: Grammar,
                                 parent_expression_nonterminal: str,
                                 child_expression_nonterminal: str) -> None:
    """
    When we add a new expression, there may be other expressions that refer to
    it, and we need to update those to point to the new expression.
    """
    parent = grammar[parent_expression_nonterminal]
    replacement = grammar[child_expression_nonterminal]
    updated_members = []
    for member in parent.members:
        if member.name == child_expression_nonterminal:
            updated_members.append(replacement)
        else:
            updated_members.append(member)
    parent.members = updated_members
| 177 | + |
def _get_sequence_with_spacing(self, # pylint: disable=no-self-use
                               new_grammar,
                               expressions: List[Expression],
                               name: str = '') -> Sequence:
    """
    This is a helper method for generating sequences, since we often want a list of expressions
    with whitespaces between them.
    """
    # Interleave a whitespace expression after every element, so the grammar
    # accepts the tokens separated by ``ws``.
    interleaved: List[Expression] = []
    for expression in expressions:
        interleaved.append(expression)
        interleaved.append(new_grammar['ws'])
    return Sequence(*interleaved, name=name)
78 | 190 |
|
def get_valid_actions(self) -> Dict[str, List[str]]:
    """Return this world's ``valid_actions`` dictionary, mapping each nonterminal to its production rules."""
    return self.valid_actions
|
@@ -165,79 +277,14 @@ def _get_linked_entities(self) -> Dict[str, Dict[str, Tuple[str, str, List[int]]
|
165 | 277 | entity_linking_scores['string'] = string_linking_scores
|
166 | 278 | return entity_linking_scores
|
167 | 279 |
|
168 |
| - def _update_valid_actions(self) -> Dict[str, List[str]]: |
169 |
| - valid_actions = deepcopy(self.sql_table_context.get_valid_actions()) |
170 |
| - valid_actions['time_range_start'] = [] |
171 |
| - valid_actions['time_range_end'] = [] |
172 |
| - for action, value in self.linked_entities['number'].items(): |
173 |
| - valid_actions[value[0]].append(action) |
174 |
| - |
175 |
| - for date in self.dates: |
176 |
| - for biexpr_rule in [f'biexpr -> ["date_day", ".", "year", binaryop, "{date.year}"]', |
177 |
| - f'biexpr -> ["date_day", ".", "month_number", binaryop, "{date.month}"]', |
178 |
| - f'biexpr -> ["date_day", ".", "day_number", binaryop, "{date.day}"]']: |
179 |
| - if biexpr_rule not in valid_actions: |
180 |
| - valid_actions['biexpr'].append(biexpr_rule) |
181 |
| - |
182 |
| - valid_actions['ternaryexpr'] = \ |
183 |
| - ['ternaryexpr -> [col_ref, "BETWEEN", time_range_start, "AND", time_range_end]', |
184 |
| - 'ternaryexpr -> [col_ref, "NOT", "BETWEEN", time_range_start, "AND", time_range_end]'] |
185 |
| - |
186 |
| - return valid_actions |
187 |
| - |
def _get_dates(self):
    """Collect every date extracted from each of the tokenized utterances, in order."""
    return [date
            for tokenized_utterance in self.tokenized_utterances
            for date in get_date_from_utterance(tokenized_utterance)]
|
193 | 285 |
|
194 |
| - def update_grammar_dictionary(self) -> Dict[str, List[str]]: |
195 |
| - """ |
196 |
| - We modify the ``grammar_dictionary`` with additional constraints |
197 |
| - we want for the ATIS dataset. We then add numbers to the grammar dictionary. The strings in the |
198 |
| - database are already added in by the ``SqlTableContext``. |
199 |
| - """ |
200 |
| - self.grammar_dictionary = deepcopy(self.sql_table_context.get_grammar_dictionary()) |
201 |
| - if self.dates: |
202 |
| - year_binary_expression = f'("date_day" ws "." ws "year" ws binaryop ws "{self.dates[0].year}")' |
203 |
| - self.grammar_dictionary['biexpr'].append(year_binary_expression) |
204 |
| - |
205 |
| - for date in self.dates: |
206 |
| - month_day_binary_expressions = \ |
207 |
| - [f'("date_day" ws "." ws "month_number" ws binaryop ws "{date.month}")', |
208 |
| - f'("date_day" ws "." ws "day_number" ws binaryop ws "{date.day}")'] |
209 |
| - self.grammar_dictionary['biexpr'].extend(month_day_binary_expressions) |
210 |
| - |
211 |
| - |
212 |
| - self.grammar_dictionary['ternaryexpr'] = \ |
213 |
| - ['(col_ref ws "not" ws "BETWEEN" ws time_range_start ws "AND" ws time_range_end ws)', |
214 |
| - '(col_ref ws "NOT" ws "BETWEEN" ws time_range_start ws "AND" ws time_range_end ws)', |
215 |
| - '(col_ref ws "BETWEEN" ws time_range_start ws "AND" ws time_range_end ws)'] |
216 |
| - |
217 |
| - # We need to add the numbers, starting, ending time ranges to the grammar. |
218 |
| - numbers = sorted([value[1] for key, value in self.linked_entities['number'].items() |
219 |
| - if value[0] == 'number'], reverse=True) |
220 |
| - self.grammar_dictionary['number'] = [f'"{number}"' for number in numbers] |
221 |
| - |
222 |
| - time_range_start = sorted([value[1] for key, value in self.linked_entities['number'].items() |
223 |
| - if value[0] == 'time_range_start'], reverse=True) |
224 |
| - self.grammar_dictionary['time_range_start'] = [f'"{time}"' for time in time_range_start] |
225 |
| - |
226 |
| - time_range_end = sorted([value[1] for key, value in self.linked_entities['number'].items() |
227 |
| - if value[0] == 'time_range_end'], reverse=True) |
228 |
| - self.grammar_dictionary['time_range_end'] = [f'"{time}"' for time in time_range_end] |
229 |
| - return self.grammar_dictionary |
230 |
| - |
231 |
| - def get_grammar_string(self) -> str: |
232 |
| - """ |
233 |
| - Generate a string that can be used to instantiate a ``Grammar`` object. The string is a sequence |
234 |
| - of rules that define the grammar. |
235 |
| - """ |
236 |
| - return '\n'.join([f"{nonterminal} = {' / '.join(right_hand_side)}" |
237 |
| - for nonterminal, right_hand_side in self.grammar_dictionary.items()]) |
238 |
| - |
def get_action_sequence(self, query: str) -> List[str]:
    """
    Parse ``query`` with this world's utterance-specific grammar and return the
    resulting sequence of production rules.  An empty query yields an empty
    list (previously the method fell through and returned ``None``, which
    contradicted the annotated return type).
    """
    if not query:
        return []
    # Only build the visitor when there is actually something to parse.
    sql_visitor = SqlVisitor(self.grammar, keywords_to_uppercase=KEYWORDS)
    return sql_visitor.parse(query)
|
@@ -277,6 +324,5 @@ def __eq__(self, other):
|
277 | 324 | if isinstance(self, other.__class__):
|
278 | 325 | return all([self.valid_actions == other.valid_actions,
|
279 | 326 | numpy.array_equal(self.linking_scores, other.linking_scores),
|
280 |
| - self.utterances == other.utterances, |
281 |
| - self.grammar_string == other.grammar_string]) |
| 327 | + self.utterances == other.utterances]) |
282 | 328 | return False
|
0 commit comments