Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit a7da2ab

Browse files
authored
Fast grammar generation (#1852)
* 1 grammar object * test, pylint * fix typo * pylint * remove pos from spaCy model * private method, decrease max_decoding_steps * test helper methods
1 parent d8b13e0 commit a7da2ab

File tree

5 files changed

+150
-79
lines changed

5 files changed

+150
-79
lines changed

allennlp/data/dataset_readers/semantic_parsing/atis.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self,
7171
database_file: str = None) -> None:
7272
super().__init__(lazy)
7373
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
74-
self._tokenizer = tokenizer or WordTokenizer(SpacyWordSplitter(pos_tags=True))
74+
self._tokenizer = tokenizer or WordTokenizer(SpacyWordSplitter())
7575
self._database_file = database_file
7676
# TODO(kevin): Add a keep_unparseable_utterances flag so that during validation, we do not skip queries that
7777
# cannot be parsed.

allennlp/semparse/worlds/atis_world.py

+120-74
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import List, Dict, Tuple, Set, Callable
2-
from copy import deepcopy
2+
from copy import copy
33
import numpy
44
from nltk import ngrams
55

66
from parsimonious.grammar import Grammar
7+
from parsimonious.expressions import Expression, OneOf, Sequence, Literal
78

89
from allennlp.semparse.contexts.atis_tables import * # pylint: disable=wildcard-import,unused-wildcard-import
910
from allennlp.semparse.contexts.atis_sql_table_context import AtisSqlTableContext, KEYWORDS
10-
from allennlp.semparse.contexts.sql_context_utils import SqlVisitor, format_action
11+
from allennlp.semparse.contexts.sql_context_utils import SqlVisitor, format_action, initialize_valid_actions
1112

1213
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
1314

@@ -67,14 +68,125 @@ def __init__(self,
6768
self.linked_entities = self._get_linked_entities()
6869
self.dates = self._get_dates()
6970

70-
self.valid_actions: Dict[str, List[str]] = self._update_valid_actions()
7171
entities, linking_scores = self._flatten_entities()
7272
# This has shape (num_entities, num_utterance_tokens).
7373
self.linking_scores: numpy.ndarray = linking_scores
7474
self.entities: List[str] = entities
75-
self.grammar_dictionary = self.update_grammar_dictionary()
76-
self.grammar_string: str = self.get_grammar_string()
77-
self.grammar_with_context: Grammar = Grammar(self.grammar_string)
75+
self.grammar: Grammar = self._update_grammar()
76+
self.valid_actions = initialize_valid_actions(self.grammar,
77+
KEYWORDS)
78+
79+
def _update_grammar(self):
80+
"""
81+
We create a new ``Grammar`` object from the one in ``AtisSqlTableContext``, that also
82+
has the new entities that are extracted from the utterance. Stitching together the expressions
83+
to form the grammar is a little tedious here, but it is worth it because we don't have to create
84+
a new grammar from scratch. Creating a new grammar is expensive because we have many production
85+
rules that have all database values in the column on the right hand side. We update the expressions
86+
bottom up, since the higher level expressions may refer to the lower level ones. For example, the
87+
ternary expression will refer to the start and end times.
88+
"""
89+
90+
# This will give us a shallow copy, but that's OK because everything
91+
# inside is immutable so we get a new copy of it.
92+
new_grammar = copy(AtisWorld.sql_table_context.grammar)
93+
94+
numbers = self._get_numeric_database_values('number')
95+
number_literals = [Literal(number) for number in numbers]
96+
new_grammar['number'] = OneOf(*number_literals, name='number')
97+
self._update_expression_reference(new_grammar, 'pos_value', 'number')
98+
99+
time_range_start = self._get_numeric_database_values('time_range_start')
100+
time_range_start_literals = [Literal(time) for time in time_range_start]
101+
new_grammar['time_range_start'] = OneOf(*time_range_start_literals, name='time_range_start')
102+
103+
time_range_end = self._get_numeric_database_values('time_range_end')
104+
time_range_end_literals = [Literal(time) for time in time_range_end]
105+
new_grammar['time_range_end'] = OneOf(*time_range_end_literals, name='time_range_end')
106+
107+
ternary_expressions = [self._get_sequence_with_spacing(new_grammar,
108+
[new_grammar['col_ref'],
109+
Literal('BETWEEN'),
110+
new_grammar['time_range_start'],
111+
Literal(f'AND'),
112+
new_grammar['time_range_end']]),
113+
self._get_sequence_with_spacing(new_grammar,
114+
[new_grammar['col_ref'],
115+
Literal('NOT'),
116+
Literal('BETWEEN'),
117+
new_grammar['time_range_start'],
118+
Literal(f'AND'),
119+
new_grammar['time_range_end']]),
120+
self._get_sequence_with_spacing(new_grammar,
121+
[new_grammar['col_ref'],
122+
Literal('not'),
123+
Literal('BETWEEN'),
124+
new_grammar['time_range_start'],
125+
Literal(f'AND'),
126+
new_grammar['time_range_end']])]
127+
128+
new_grammar['ternaryexpr'] = OneOf(*ternary_expressions, name='ternaryexpr')
129+
self._update_expression_reference(new_grammar, 'condition', 'ternaryexpr')
130+
131+
if self.dates:
132+
new_binary_expressions = []
133+
year_binary_expression = self._get_sequence_with_spacing(new_grammar,
134+
[Literal('date_day'),
135+
Literal('.'),
136+
Literal('year'),
137+
new_grammar['binaryop'],
138+
Literal(f'{self.dates[0].year}')])
139+
new_binary_expressions.append(year_binary_expression)
140+
for date in self.dates:
141+
month_binary_expression = self._get_sequence_with_spacing(new_grammar,
142+
[Literal('date_day'),
143+
Literal('.'),
144+
Literal('month_number'),
145+
new_grammar['binaryop'],
146+
Literal(f'{date.month}')])
147+
148+
day_binary_expression = self._get_sequence_with_spacing(new_grammar,
149+
[Literal('date_day'),
150+
Literal('.'),
151+
Literal('day_number'),
152+
new_grammar['binaryop'],
153+
Literal(f'{date.day}')])
154+
new_binary_expressions.extend([month_binary_expression,
155+
day_binary_expression])
156+
157+
new_grammar['biexpr'].members = new_grammar['biexpr'].members + tuple(new_binary_expressions)
158+
return new_grammar
159+
160+
def _get_numeric_database_values(self,
161+
nonterminal: str) -> List[str]:
162+
return sorted([value[1] for key, value in self.linked_entities['number'].items()
163+
if value[0] == nonterminal], reverse=True)
164+
165+
def _update_expression_reference(self, # pylint: disable=no-self-use
166+
grammar: Grammar,
167+
parent_expression_nonterminal: str,
168+
child_expression_nonterminal: str) -> None:
169+
"""
170+
When we add a new expression, there may be other expressions that refer to
171+
it, and we need to update those to point to the new expression.
172+
"""
173+
grammar[parent_expression_nonterminal].members = \
174+
[member if member.name != child_expression_nonterminal
175+
else grammar[child_expression_nonterminal]
176+
for member in grammar[parent_expression_nonterminal].members]
177+
178+
def _get_sequence_with_spacing(self, # pylint: disable=no-self-use
179+
new_grammar,
180+
expressions: List[Expression],
181+
name: str = '') -> Sequence:
182+
"""
183+
This is a helper method for generating sequences, since we often want a list of expressions
184+
with whitespaces between them.
185+
"""
186+
expressions = [subexpression
187+
for expression in expressions
188+
for subexpression in (expression, new_grammar['ws'])]
189+
return Sequence(*expressions, name=name)
78190

79191
def get_valid_actions(self) -> Dict[str, List[str]]:
80192
return self.valid_actions
@@ -165,79 +277,14 @@ def _get_linked_entities(self) -> Dict[str, Dict[str, Tuple[str, str, List[int]]
165277
entity_linking_scores['string'] = string_linking_scores
166278
return entity_linking_scores
167279

168-
def _update_valid_actions(self) -> Dict[str, List[str]]:
169-
valid_actions = deepcopy(self.sql_table_context.get_valid_actions())
170-
valid_actions['time_range_start'] = []
171-
valid_actions['time_range_end'] = []
172-
for action, value in self.linked_entities['number'].items():
173-
valid_actions[value[0]].append(action)
174-
175-
for date in self.dates:
176-
for biexpr_rule in [f'biexpr -> ["date_day", ".", "year", binaryop, "{date.year}"]',
177-
f'biexpr -> ["date_day", ".", "month_number", binaryop, "{date.month}"]',
178-
f'biexpr -> ["date_day", ".", "day_number", binaryop, "{date.day}"]']:
179-
if biexpr_rule not in valid_actions:
180-
valid_actions['biexpr'].append(biexpr_rule)
181-
182-
valid_actions['ternaryexpr'] = \
183-
['ternaryexpr -> [col_ref, "BETWEEN", time_range_start, "AND", time_range_end]',
184-
'ternaryexpr -> [col_ref, "NOT", "BETWEEN", time_range_start, "AND", time_range_end]']
185-
186-
return valid_actions
187-
188280
def _get_dates(self):
189281
dates = []
190282
for tokenized_utterance in self.tokenized_utterances:
191283
dates.extend(get_date_from_utterance(tokenized_utterance))
192284
return dates
193285

194-
def update_grammar_dictionary(self) -> Dict[str, List[str]]:
195-
"""
196-
We modify the ``grammar_dictionary`` with additional constraints
197-
we want for the ATIS dataset. We then add numbers to the grammar dictionary. The strings in the
198-
database are already added in by the ``SqlTableContext``.
199-
"""
200-
self.grammar_dictionary = deepcopy(self.sql_table_context.get_grammar_dictionary())
201-
if self.dates:
202-
year_binary_expression = f'("date_day" ws "." ws "year" ws binaryop ws "{self.dates[0].year}")'
203-
self.grammar_dictionary['biexpr'].append(year_binary_expression)
204-
205-
for date in self.dates:
206-
month_day_binary_expressions = \
207-
[f'("date_day" ws "." ws "month_number" ws binaryop ws "{date.month}")',
208-
f'("date_day" ws "." ws "day_number" ws binaryop ws "{date.day}")']
209-
self.grammar_dictionary['biexpr'].extend(month_day_binary_expressions)
210-
211-
212-
self.grammar_dictionary['ternaryexpr'] = \
213-
['(col_ref ws "not" ws "BETWEEN" ws time_range_start ws "AND" ws time_range_end ws)',
214-
'(col_ref ws "NOT" ws "BETWEEN" ws time_range_start ws "AND" ws time_range_end ws)',
215-
'(col_ref ws "BETWEEN" ws time_range_start ws "AND" ws time_range_end ws)']
216-
217-
# We need to add the numbers, starting, ending time ranges to the grammar.
218-
numbers = sorted([value[1] for key, value in self.linked_entities['number'].items()
219-
if value[0] == 'number'], reverse=True)
220-
self.grammar_dictionary['number'] = [f'"{number}"' for number in numbers]
221-
222-
time_range_start = sorted([value[1] for key, value in self.linked_entities['number'].items()
223-
if value[0] == 'time_range_start'], reverse=True)
224-
self.grammar_dictionary['time_range_start'] = [f'"{time}"' for time in time_range_start]
225-
226-
time_range_end = sorted([value[1] for key, value in self.linked_entities['number'].items()
227-
if value[0] == 'time_range_end'], reverse=True)
228-
self.grammar_dictionary['time_range_end'] = [f'"{time}"' for time in time_range_end]
229-
return self.grammar_dictionary
230-
231-
def get_grammar_string(self) -> str:
232-
"""
233-
Generate a string that can be used to instantiate a ``Grammar`` object. The string is a sequence
234-
of rules that define the grammar.
235-
"""
236-
return '\n'.join([f"{nonterminal} = {' / '.join(right_hand_side)}"
237-
for nonterminal, right_hand_side in self.grammar_dictionary.items()])
238-
239286
def get_action_sequence(self, query: str) -> List[str]:
240-
sql_visitor = SqlVisitor(self.grammar_with_context, keywords_to_uppercase=KEYWORDS)
287+
sql_visitor = SqlVisitor(self.grammar, keywords_to_uppercase=KEYWORDS)
241288
if query:
242289
action_sequence = sql_visitor.parse(query)
243290
return action_sequence
@@ -277,6 +324,5 @@ def __eq__(self, other):
277324
if isinstance(self, other.__class__):
278325
return all([self.valid_actions == other.valid_actions,
279326
numpy.array_equal(self.linking_scores, other.linking_scores),
280-
self.utterances == other.utterances,
281-
self.grammar_string == other.grammar_string])
327+
self.utterances == other.utterances])
282328
return False

allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test_atis_read_from_file(self):
3232
assert isinstance(instance.fields['world'].as_tensor({}), AtisWorld)
3333

3434
world = instance.fields['world'].metadata
35-
assert world.valid_actions['number'] == \
36-
['number -> ["1"]',
37-
'number -> ["0"]']
35+
assert set(world.valid_actions['number']) == \
36+
{'number -> ["1"]',
37+
'number -> ["0"]'}
3838

3939
assert world.linked_entities['string']['airport_airport_code_string -> ["\'DTW\'"]'][2] == \
4040
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] # ``detroit`` -> ``DTW``

allennlp/tests/fixtures/semantic_parsing/atis/experiment.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"decoder_beam_search": {
2626
"beam_size": 5
2727
},
28-
"max_decoding_steps": 100,
28+
"max_decoding_steps": 10,
2929
"input_attention": {"type": "dot_product"},
3030
"dropout": 0.5,
3131
"database_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db"

allennlp/tests/semparse/worlds/atis_world_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# pylint: disable=too-many-lines
2+
from datetime import datetime
23
import json
34

5+
from parsimonious.expressions import Literal, Sequence
6+
47
from allennlp.common.file_utils import cached_path
58
from allennlp.semparse.contexts.atis_tables import * # pylint: disable=wildcard-import,unused-wildcard-import
69
from allennlp.common.testing import AllenNlpTestCase
@@ -734,3 +737,25 @@ def test_time_extraction(self): # pylint: disable=no-self-use
734737
pm_times = [pm_map_match_to_query_value(string)
735738
for string in ['12pm', '1pm', '830pm', '1230pm', '115pm']]
736739
assert pm_times == [[1200], [1300], [2030], [1230], [1315]]
740+
741+
def test_atis_helper_methods(self): # pylint: disable=no-self-use
742+
world = AtisWorld([("what is the earliest flight in morning "
743+
"1993 june fourth from boston to pittsburgh")])
744+
assert world.dates == [datetime(1993, 6, 4, 0, 0)]
745+
assert world._get_numeric_database_values('time_range_end') == ['1200'] # pylint: disable=protected-access
746+
assert world._get_sequence_with_spacing(world.grammar, # pylint: disable=protected-access
747+
[world.grammar['col_ref'],
748+
Literal('BETWEEN'),
749+
world.grammar['time_range_start'],
750+
Literal(f'AND'),
751+
world.grammar['time_range_end']]) == \
752+
Sequence(world.grammar['col_ref'],
753+
world.grammar['ws'],
754+
Literal('BETWEEN'),
755+
world.grammar['ws'],
756+
world.grammar['time_range_start'],
757+
world.grammar['ws'],
758+
Literal(f'AND'),
759+
world.grammar['ws'],
760+
world.grammar['time_range_end'],
761+
world.grammar['ws'])

0 commit comments

Comments (0)