Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 27fab84

Browse files
authored
Add more configuration options for ATIS semantic parser (#1821)
* costs * add keep if unparseable flag * keep unparseable in dev * pylint * add one direction cost * fix test * retrain fixture * docs * read unparseable * add concat based context * fix imports * heuristics * heuristics * unparseable queries * turn off info logging in subprocess * fix flight numbers * fix dates * fix tokenization * heuristics * more epochs * reverse the productions, predict the leftmost nonterminal first * left first model 44 * add helper functions for numbers * fix global rules test * fix tests * clean up * grammar statelet test * dates * rename * pylint * remove debug tests * retrain fixture * fix action sequence test * change text2sql test to also left first * add another test extraction test * experiment config * user numeric nonterminals * docs
1 parent dc66c8f commit 27fab84

File tree

19 files changed

+786
-387
lines changed

19 files changed

+786
-387
lines changed

allennlp/data/dataset_readers/semantic_parsing/atis.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
1717

1818
from allennlp.semparse.worlds.atis_world import AtisWorld
19+
from allennlp.semparse.contexts.atis_sql_table_context import NUMERIC_NONTERMINALS
1920

2021
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
2122

23+
END_OF_UTTERANCE_TOKEN = "@@EOU@@"
24+
2225
def _lazy_parse(text: str):
2326
for interaction in text.split("\n"):
2427
if interaction:
@@ -63,18 +66,22 @@ class AtisDatasetReader(DatasetReader):
6366
database_file: ``str``, optional
6467
The directory to find the sqlite database file. We query the sqlite database to find the strings
6568
that are allowed.
69+
num_turns_to_concatenate: ``int``, optional
70+
The number of utterances to concatenate as the conversation context.
6671
"""
6772
def __init__(self,
6873
token_indexers: Dict[str, TokenIndexer] = None,
74+
keep_if_unparseable: bool = False,
6975
lazy: bool = False,
7076
tokenizer: Tokenizer = None,
71-
database_file: str = None) -> None:
77+
database_file: str = None,
78+
num_turns_to_concatenate: int = 1) -> None:
7279
super().__init__(lazy)
80+
self._keep_if_unparseable = keep_if_unparseable
7381
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
7482
self._tokenizer = tokenizer or WordTokenizer(SpacyWordSplitter())
7583
self._database_file = database_file
76-
# TODO(kevin): Add a keep_unparseable_utterances flag so that during validation, we do not skip queries that
77-
# cannot be parsed.
84+
self._num_turns_to_concatenate = num_turns_to_concatenate
7885

7986
@overrides
8087
def _read(self, file_path: str):
@@ -108,6 +115,9 @@ def text_to_instance(self, # type: ignore
108115
sql_query_labels: ``List[str]``, optional
109116
The SQL queries that are given as labels during training or validation.
110117
"""
118+
if self._num_turns_to_concatenate:
119+
utterances[-1] = f' {END_OF_UTTERANCE_TOKEN} '.join(utterances[-self._num_turns_to_concatenate:])
120+
111121
utterance = utterances[-1]
112122
action_sequence: List[str] = []
113123

@@ -149,21 +159,21 @@ def text_to_instance(self, # type: ignore
149159

150160
if sql_query_labels != None:
151161
fields['sql_queries'] = MetadataField(sql_query_labels)
152-
if action_sequence:
162+
if action_sequence and not self._keep_if_unparseable:
153163
for production_rule in action_sequence:
154164
index_fields.append(IndexField(action_map[production_rule], action_field))
155-
156165
action_sequence_field = ListField(index_fields)
157166
fields['target_action_sequence'] = action_sequence_field
158-
else:
159-
# If we are given a SQL query, but we are unable to parse it, then we will skip it.
167+
elif not self._keep_if_unparseable:
168+
# If we are given a SQL query, but we are unable to parse it, and we do not specify explicitly
169+
# to keep it, then we will skip it.
160170
return None
161171

162172
return Instance(fields)
163173

164174
@staticmethod
165175
def _is_global_rule(nonterminal: str) -> bool:
166-
if nonterminal in ['number', 'time_range_start', 'time_range_end']:
176+
if nonterminal in NUMERIC_NONTERMINALS:
167177
return False
168178
elif nonterminal.endswith('string'):
169179
return False

allennlp/models/semantic_parsing/atis/atis_semantic_parser.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from allennlp.modules import Attention, Seq2SeqEncoder, TextFieldEmbedder, Embedding
1515
from allennlp.nn import util
1616
from allennlp.semparse.worlds import AtisWorld
17+
from allennlp.semparse.contexts.atis_sql_table_context import NUMERIC_NONTERMINALS
1718
from allennlp.semparse.contexts.sql_context_utils import action_sequence_to_sql
1819
from allennlp.state_machines.states import GrammarBasedState
1920
from allennlp.state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction
@@ -326,7 +327,8 @@ def _get_type_vector(worlds: List[AtisWorld],
326327
for batch_index, world in enumerate(worlds):
327328
types = []
328329
entities = [('number', entity)
329-
if 'number' or 'time_range' in entity
330+
if any([entity.startswith(numeric_nonterminal)
331+
for numeric_nonterminal in NUMERIC_NONTERMINALS])
330332
else ('string', entity)
331333
for entity in world.entities]
332334

@@ -475,8 +477,7 @@ def _create_grammar_state(self,
475477

476478
return GrammarStatelet(['statement'],
477479
translated_valid_actions,
478-
self.is_nonterminal,
479-
reverse_productions=False)
480+
self.is_nonterminal)
480481

481482
@overrides
482483
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

allennlp/semparse/contexts/atis_sql_table_context.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@
3232
GRAMMAR_DICTIONARY = {}
3333
GRAMMAR_DICTIONARY['statement'] = ['query ws ";" ws']
3434
GRAMMAR_DICTIONARY['query'] = ['(ws "(" ws "SELECT" ws distinct ws select_results ws '
35+
'"FROM" ws table_refs ws where_clause ws group_by_clause ws ")" ws)',
36+
'(ws "(" ws "SELECT" ws distinct ws select_results ws '
3537
'"FROM" ws table_refs ws where_clause ws ")" ws)',
3638
'(ws "SELECT" ws distinct ws select_results ws '
3739
'"FROM" ws table_refs ws where_clause ws)']
3840
GRAMMAR_DICTIONARY['select_results'] = ['col_refs', 'agg']
39-
GRAMMAR_DICTIONARY['agg'] = ['agg_func ws "(" ws col_ref ws ")"']
41+
GRAMMAR_DICTIONARY['agg'] = ['( agg_func ws "(" ws col_ref ws ")" )', '(agg_func ws "(" ws col ws ")" )']
4042
GRAMMAR_DICTIONARY['agg_func'] = ['"MIN"', '"min"', '"MAX"', '"max"', '"COUNT"', '"count"']
4143
GRAMMAR_DICTIONARY['col_refs'] = ['(col_ref ws "," ws col_refs)', '(col_ref)']
4244
GRAMMAR_DICTIONARY['table_refs'] = ['(table_name ws "," ws table_refs)', '(table_name)']
4345
GRAMMAR_DICTIONARY['where_clause'] = ['("WHERE" ws "(" ws conditions ws ")" ws)', '("WHERE" ws conditions ws)']
46+
GRAMMAR_DICTIONARY['group_by_clause'] = ['("GROUP" ws "BY" ws col_ref)']
4447
GRAMMAR_DICTIONARY['conditions'] = ['(condition ws conj ws conditions)',
4548
'(condition ws conj ws "(" ws conditions ws ")")',
4649
'("(" ws conditions ws ")" ws conj ws conditions)',
@@ -71,6 +74,10 @@
7174
KEYWORDS = ['"SELECT"', '"FROM"', '"MIN"', '"MAX"', '"COUNT"', '"WHERE"', '"NOT"', '"IN"', '"LIKE"',
7275
'"IS"', '"BETWEEN"', '"AND"', '"ALL"', '"ANY"', '"NULL"', '"OR"', '"DISTINCT"']
7376

77+
NUMERIC_NONTERMINALS = ['number', 'time_range_start', 'time_range_end',
78+
'fare_round_trip_cost', 'fare_one_direction_cost',
79+
'flight_number', 'day_number', 'month_number', 'year_number']
80+
7481
class AtisSqlTableContext:
7582
"""
7683
An ``AtisSqlTableContext`` represents the SQL context with a grammar of SQL and the valid actions
@@ -123,11 +130,14 @@ def create_grammar_dict_and_strings(self) -> Tuple[Dict[str, List[str]], List[Tu
123130
grammar_dictionary['table_name'] = \
124131
sorted([f'"{table}"'
125132
for table in list(self.all_tables.keys())], reverse=True)
126-
grammar_dictionary['col_ref'] = ['"*"']
133+
grammar_dictionary['col_ref'] = ['"*"', 'agg']
134+
all_columns = []
127135
for table, columns in self.all_tables.items():
128136
grammar_dictionary['col_ref'].extend([f'("{table}" ws "." ws "{column}")'
129137
for column in columns])
138+
all_columns.extend(columns)
130139
grammar_dictionary['col_ref'] = sorted(grammar_dictionary['col_ref'], reverse=True)
140+
grammar_dictionary['col'] = sorted([f'"{column}"' for column in all_columns], reverse=True)
131141

132142
biexprs = []
133143
if self.tables_with_strings:
@@ -138,6 +148,9 @@ def create_grammar_dict_and_strings(self) -> Tuple[Dict[str, List[str]], List[Tu
138148
self.cursor.execute(f'SELECT DISTINCT {table} . {column} FROM {table}')
139149
results = self.cursor.fetchall()
140150

151+
# Almost all the query values are in the database, we hardcode the rare case here.
152+
if table == 'flight' and column == 'airline_code':
153+
results.append(('EA',))
141154
strings_list.extend([(format_action(f"{table}_{column}_string",
142155
str(row[0]),
143156
is_string=not 'number' in column,

allennlp/semparse/contexts/atis_tables.py

+85-16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
APPROX_WORDS = ['about', 'around', 'approximately']
1616
WORDS_PRECEDING_TIME = ['at', 'between', 'to', 'before', 'after']
1717

18+
1819
def pm_map_match_to_query_value(match: str):
1920
if len(match.rstrip('pm')) < 3: # This will match something like ``5pm``.
2021
if match.startswith('12'):
@@ -82,12 +83,13 @@ def get_date_from_utterance(tokenized_utterance: List[Token],
8283
it is 1993 so we do the same here. If there is no mention of the month or day then
8384
we do not return any dates from the utterance.
8485
"""
86+
8587
dates = []
88+
8689
utterance = ' '.join([token.text for token in tokenized_utterance])
8790
year_result = re.findall(r'199[0-4]', utterance)
8891
if year_result:
8992
year = int(year_result[0])
90-
9193
trigrams = ngrams([token.text for token in tokenized_utterance], 3)
9294
for month, tens, digit in trigrams:
9395
# This will match something like ``september twenty first``.
@@ -107,6 +109,20 @@ def get_date_from_utterance(tokenized_utterance: List[Token],
107109
except ValueError:
108110
print('invalid month day')
109111

112+
fivegrams = ngrams([token.text for token in tokenized_utterance], 5)
113+
for tens, digit, _, year_match, month in fivegrams:
114+
# This will match something like ``twenty first of 1993 july``.
115+
day = ' '.join([tens, digit])
116+
if month in MONTH_NUMBERS and day in DAY_NUMBERS and year_match.isdigit():
117+
try:
118+
dates.append(datetime(int(year_match), MONTH_NUMBERS[month], DAY_NUMBERS[day]))
119+
except ValueError:
120+
print('invalid month day')
121+
if month in MONTH_NUMBERS and digit in DAY_NUMBERS and year_match.isdigit():
122+
try:
123+
dates.append(datetime(int(year_match), MONTH_NUMBERS[month], DAY_NUMBERS[digit]))
124+
except ValueError:
125+
print('invalid month day')
110126
return dates
111127

112128
def get_numbers_from_utterance(utterance: str, tokenized_utterance: List[Token]) -> Dict[str, List[int]]:
@@ -189,6 +205,35 @@ def get_time_range_end_from_utterance(utterance: str, # pylint: disable=unused-a
189205

190206
return time_range_end_linking_dict
191207

208+
def get_costs_from_utterance(utterance: str, # pylint: disable=unused-argument
209+
tokenized_utterance: List[Token]) -> Dict[str, List[int]]:
210+
dollars_indices = {index for index, token in enumerate(tokenized_utterance)
211+
if token.text == 'dollars' or token.text == 'dollar'}
212+
213+
costs_linking_dict: Dict[str, List[int]] = defaultdict(list)
214+
for token_index, token in enumerate(tokenized_utterance):
215+
if token_index + 1 in dollars_indices and token.text.isdigit():
216+
costs_linking_dict[token.text].append(token_index)
217+
return costs_linking_dict
218+
219+
def get_flight_numbers_from_utterance(utterance: str, # pylint: disable=unused-argument
220+
tokenized_utterance: List[Token]) -> Dict[str, List[int]]:
221+
indices_words_preceding_flight_number = {index for index, token in enumerate(tokenized_utterance)
222+
if token.text in {'flight', 'number'}
223+
or token.text.upper() in AIRLINE_CODE_LIST
224+
or token.text.lower() in AIRLINE_CODES.keys()}
225+
226+
indices_words_succeeding_flight_number = {index for index, token in enumerate(tokenized_utterance)
227+
if token.text == 'flight'}
228+
229+
flight_numbers_linking_dict: Dict[str, List[int]] = defaultdict(list)
230+
for token_index, token in enumerate(tokenized_utterance):
231+
if token.text.isdigit():
232+
if token_index - 1 in indices_words_preceding_flight_number:
233+
flight_numbers_linking_dict[token.text].append(token_index)
234+
if token_index + 1 in indices_words_succeeding_flight_number:
235+
flight_numbers_linking_dict[token.text].append(token_index)
236+
return flight_numbers_linking_dict
192237

193238
def digit_to_query_time(digit: str) -> List[int]:
194239
"""
@@ -303,6 +348,7 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
303348
'mgm': ['MG'],
304349
'midwest': ['YX'],
305350
'nation': ['NX'],
351+
'nationair': ['NX'],
306352
'northeast': ['2V'],
307353
'northwest': ['NW'],
308354
'ontario': ['GX'],
@@ -384,11 +430,14 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
384430
GROUND_SERVICE = {'air taxi': ['AIR TAXI OPERATION'],
385431
'car': ['RENTAL CAR'],
386432
'limo': ['LIMOUSINE'],
433+
'limousine': ['LIMOUSINE'],
387434
'rapid': ['RAPID TRANSIT'],
388435
'rental': ['RENTAL CAR'],
389436
'taxi': ['TAXI']}
390437

391-
MISC_STR = {"every day" : ["DAILY"]}
438+
MISC_STR = {"every day" : ["DAILY"],
439+
"saint petersburg": ["ST. PETERSBURG"],
440+
"saint louis": ["ST. LOUIS"]}
392441

393442
DAY_NUMBERS = {'first': 1,
394443
'second': 2,
@@ -424,18 +473,27 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
424473

425474

426475
MISC_TIME_TRIGGERS = {'lunch': ['1400'],
427-
'noon': ['1200']}
476+
'noon': ['1200'],
477+
'early evening': ['1800', '2000'],
478+
'morning': ['0', '1200'],
479+
'night': ['1800', '2400']}
428480

429481
TIME_RANGE_START_DICT = {'morning': ['0'],
482+
'mornings': ['1200'],
430483
'afternoon': ['1200'],
484+
'afternoons': ['1200'],
485+
'after noon': ['1200'],
431486
'late afternoon': ['1600'],
432487
'evening': ['1800'],
433488
'late evening': ['2000']}
434489

435490
TIME_RANGE_END_DICT = {'early morning': ['800'],
436-
'morning': ['1200'],
491+
'morning': ['1200', '800'],
492+
'mornings': ['1200', '800'],
437493
'early afternoon': ['1400'],
438494
'afternoon': ['1800'],
495+
'afternoons': ['1800'],
496+
'after noon': ['1800'],
439497
'evening': ['2200']}
440498

441499
ALL_TABLES = {'aircraft': ['aircraft_code', 'aircraft_description', 'capacity',
@@ -477,18 +535,18 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
477535

478536
TABLES_WITH_STRINGS = {'airline' : ['airline_code', 'airline_name'],
479537
'city' : ['city_name', 'state_code', 'city_code'],
480-
'fare' : ['round_trip_required', 'fare_basis_code'],
481-
'flight' : ['airline_code', 'flight_days', 'flight_number'],
538+
'fare' : ['round_trip_required', 'fare_basis_code', 'restriction_code'],
539+
'flight' : ['airline_code', 'flight_days'],
482540
'flight_stop' : ['stop_airport'],
483-
'airport' : ['airport_code'],
484-
'state' : ['state_name'],
485-
'fare_basis' : ['fare_basis_code', 'class_type', 'economy'],
486-
'class_of_service' : ['booking_class'],
487-
'aircraft' : ['basic_type', 'manufacturer'],
541+
'airport' : ['airport_code', 'airport_name'],
542+
'state' : ['state_name', 'state_code'],
543+
'fare_basis' : ['fare_basis_code', 'class_type', 'economy', 'booking_class'],
544+
'class_of_service' : ['booking_class', 'class_description'],
545+
'aircraft' : ['basic_type', 'manufacturer', 'aircraft_code', 'propulsion'],
488546
'restriction' : ['restriction_code'],
489547
'ground_service' : ['transport_type'],
490-
'days' : ['day_name'],
491-
'food_service': ['meal_description']}
548+
'days' : ['day_name', 'days_code'],
549+
'food_service': ['meal_description', 'compartment']}
492550

493551
DAY_OF_WEEK = ['MONDAY', 'TUESDAY', 'WEDNESDAY', 'THURSDAY', 'FRIDAY', 'SATURDAY', 'SUNDAY']
494552

@@ -518,7 +576,10 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
518576
'charlotte': ['CLT'],
519577
'dallas': ['DFW'],
520578
'detroit': ['DTW'],
579+
'houston': ['IAH'],
521580
'la guardia': ['LGA'],
581+
'love field': ['DAL'],
582+
'los angeles': ['LAX'],
522583
'oakland': ['OAK'],
523584
'philadelphia': ['PHL'],
524585
'pittsburgh': ['PIT'],
@@ -537,7 +598,7 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
537598
'OK', 'DL', '9E', 'QD', 'LH', 'XJ', 'MG',
538599
'YX', 'NX', '2V', 'NW', 'RP', 'AT', 'SN',
539600
'OO', 'WN', 'TG', 'FF', '9N', 'TW', 'RZ',
540-
'UA', 'US', 'OE']
601+
'UA', 'US', 'OE', 'EA']
541602
CITIES = ['NASHVILLE', 'BOSTON', 'BURBANK', 'BALTIMORE', 'CHICAGO', 'CLEVELAND',
542603
'CHARLOTTE', 'COLUMBUS', 'CINCINNATI', 'DENVER', 'DALLAS', 'DETROIT',
543604
'FORT WORTH', 'HOUSTON', 'WESTCHESTER COUNTY', 'INDIANAPOLIS', 'NEWARK',
@@ -551,7 +612,12 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
551612
'MATL', 'MMEM', 'MMIA', 'MMKC', 'MMKE', 'MMSP', 'NNYC', 'OOAK', 'OONT', 'OORL',
552613
'PPHL', 'PPHX', 'PPIT', 'SMSP', 'SSAN', 'SSEA', 'SSFO', 'SSJC', 'SSLC', 'SSTL',
553614
'STPA', 'TSEA', 'TTPA', 'WWAS', 'YYMQ', 'YYTO']
554-
CLASS = ['COACH', 'BUSINESS', 'FIRST', 'THRIST', 'STANDARD', 'SHUTTLE']
615+
616+
CLASS = ['COACH', 'BUSINESS', 'FIRST', 'THRIFT', 'STANDARD', 'SHUTTLE']
617+
618+
AIRCRAFT_MANUFACTURERS = ['BOEING', 'MCDONNELL DOUGLAS', 'FOKKER']
619+
620+
AIRCRAFT_BASIC_CODES = ['DC9', '737', '767', '747', 'DC10', '757', 'MD80']
555621

556622
DAY_OF_WEEK_INDEX = {idx : [day] for idx, day in enumerate(DAY_OF_WEEK)}
557623

@@ -560,7 +626,10 @@ def convert_to_string_list_value_dict(trigger_dict: Dict[str, int]) -> Dict[str,
560626
FARE_BASIS_CODE, CLASS,
561627
AIRLINE_CODE_LIST, DAY_OF_WEEK,
562628
CITY_CODE_LIST, MEALS,
563-
RESTRICT_CODES]
629+
RESTRICT_CODES,
630+
AIRCRAFT_MANUFACTURERS,
631+
AIRCRAFT_BASIC_CODES]
632+
564633
TRIGGER_DICTS = [CITY_AIRPORT_CODES,
565634
AIRLINE_CODES,
566635
CITY_CODES,

0 commit comments

Comments
 (0)