Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 91bfb4c

Browse files
authored
Global grammar values (#1888)
1 parent ffab320 commit 91bfb4c

File tree

5 files changed

+60
-25
lines changed

5 files changed

+60
-25
lines changed

allennlp/data/dataset_readers/dataset_utils/text2sql_utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ def read_dataset_schema(schema_path: str) -> Dict[str, List[TableColumn]]:
155155
mapping table names to their columns and respective types.
156156
This handles columns in an arbitrary order and also allows
157157
either ``{Table, Field}`` or ``{Table, Field} Name`` as headers,
158-
because both appear in the data.
158+
because both appear in the data. It also uppercases table and
159+
column names if they are not already uppercase.
159160
160161
Parameters
161162
----------
@@ -178,7 +179,7 @@ def read_dataset_schema(schema_path: str) -> Dict[str, List[TableColumn]]:
178179
table = data.get("Table Name", None) or data.get("Table")
179180
column = data.get("Field Name", None) or data.get("Field")
180181
is_primary_key = data.get("Primary Key") == "y"
181-
schema[table].append(TableColumn(column, data["Type"], is_primary_key))
182+
schema[table.upper()].append(TableColumn(column.upper(), data["Type"], is_primary_key))
182183

183184
return {**schema}
184185

allennlp/semparse/contexts/text2sql_table_context.py

+19
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@
9292
'">="', '"<="', '">"', '"<"', '"AND"', '"OR"', '"LIKE"']
9393
GRAMMAR_DICTIONARY["unaryop"] = ['"+"', '"-"', '"not"', '"NOT"']
9494

95+
96+
97+
GLOBAL_DATASET_VALUES: Dict[str, List[str]] = {
98+
# These are used to check whether values are present, or as numbers of authors.
99+
"scholar": ["0", "1", "2"],
100+
# 0 is used for "sea level", 750 is a "major" lake, and 150000 is a "major" city.
101+
"geography": ["0", "750", "150000"],
102+
# This defines what an "above average" restaurant is.
103+
"restaurants": ["2.5"]
104+
}
105+
106+
95107
def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
96108
schema: Dict[str, List[TableColumn]]) -> None:
97109
table_names = sorted([f'"{table}"' for table in
@@ -118,3 +130,10 @@ def update_grammar_with_table_values(grammar_dictionary: Dict[str, List[str]],
118130
elif column_has_numeric_type(column):
119131
productions = sorted([f'"{str(result)}"' for result in results], reverse=True)
120132
grammar_dictionary["number"].extend(productions)
133+
134+
135+
def update_grammar_with_global_values(grammar_dictionary: Dict[str, List[str]],
                                      dataset_name: str) -> None:
    """
    Prepends the global, non-variable values of a dataset to the ``"value"``
    rule of the grammar.

    Parameters
    ----------
    grammar_dictionary : ``Dict[str, List[str]]``
        The grammar to update. It must already contain a ``"value"`` entry
        (a ``KeyError`` is raised otherwise). The entry is rebound to a new
        list; the previous list object is not mutated.
    dataset_name : ``str``
        The key used to look up constants in ``GLOBAL_DATASET_VALUES``.
        A name with no entry contributes no values.
    """
    existing_values = grammar_dictionary["value"]
    quoted_values = [f'"{value}"' for value in GLOBAL_DATASET_VALUES.get(dataset_name, [])]
    # Skip constants which are already alternatives of the rule, so that
    # calling this more than once on the same grammar does not produce
    # duplicate productions.
    new_values = [value for value in quoted_values if value not in existing_values]
    # Global values go first, ahead of the existing alternatives.
    grammar_dictionary["value"] = new_values + existing_values

allennlp/semparse/worlds/text2sql_world.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Tuple, Dict
22
from copy import deepcopy
33
from sqlite3 import Cursor
4+
import os
45

56
from parsimonious import Grammar
67

@@ -11,6 +12,7 @@
1112
from allennlp.semparse.contexts.text2sql_table_context import GRAMMAR_DICTIONARY
1213
from allennlp.semparse.contexts.text2sql_table_context import update_grammar_with_table_values
1314
from allennlp.semparse.contexts.text2sql_table_context import update_grammar_with_tables
15+
from allennlp.semparse.contexts.text2sql_table_context import update_grammar_with_global_values
1416

1517
class Text2SqlWorld:
1618
"""
@@ -36,6 +38,7 @@ def __init__(self,
3638
use_prelinked_entities: bool = True) -> None:
3739
self.cursor = cursor
3840
self.schema = read_dataset_schema(schema_path)
41+
self.dataset_name = os.path.basename(schema_path).split("-")[0]
3942
self.use_prelinked_entities = use_prelinked_entities
4043

4144
# NOTE: This base dictionary should not be modified.
@@ -67,24 +70,27 @@ def get_action_sequence_and_all_actions(self,
6770

6871
def _initialize_grammar_dictionary(self, grammar_dictionary: Dict[str, List[str]]) -> Dict[str, List[str]]:
6972
# Add all the table and column names to the grammar.
70-
if self.schema:
71-
update_grammar_with_tables(grammar_dictionary, self.schema)
72-
73-
if self.cursor is not None and not self.use_prelinked_entities:
74-
# Now if we have strings in the table, we need to be able to
75-
# produce them, so we find all of the strings in the tables here
76-
# and create production rules from them. We only do this if
77-
# we haven't pre-linked entities, because if we have, we don't
78-
# need to be able to generate the values - just the placeholder
79-
# symbols which link to them.
80-
grammar_dictionary["number"] = []
81-
grammar_dictionary["string"] = []
82-
83-
update_grammar_with_table_values(grammar_dictionary, self.schema, self.cursor)
84-
else:
85-
# TODO(Mark): The grammar can be tightened here if we don't need to
86-
# produce concrete values.
87-
pass
73+
update_grammar_with_tables(grammar_dictionary, self.schema)
74+
75+
if self.cursor is not None and not self.use_prelinked_entities:
76+
# Now if we have strings in the table, we need to be able to
77+
# produce them, so we find all of the strings in the tables here
78+
# and create production rules from them. We only do this if
79+
# we haven't pre-linked entities, because if we have, we don't
80+
# need to be able to generate the values - just the placeholder
81+
# symbols which link to them.
82+
grammar_dictionary["number"] = []
83+
grammar_dictionary["string"] = []
84+
85+
update_grammar_with_table_values(grammar_dictionary, self.schema, self.cursor)
86+
else:
87+
# TODO(Mark): The grammar can be tightened here if we don't need to
88+
# produce concrete values.
89+
pass
90+
91+
# Finally, update the grammar with global, non-variable values
92+
# found in the dataset, if present.
93+
update_grammar_with_global_values(grammar_dictionary, self.dataset_name)
8894

8995
return grammar_dictionary
9096

allennlp/tests/data/dataset_readers/semantic_parsing/grammar_based_text2sql_test.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def test_reader_can_read_data_with_entity_pre_linking(self):
2424

2525
action_sequence = fields["action_sequence"].field_list
2626
indices = [x.sequence_index for x in action_sequence]
27-
assert indices == [101, 75, 81, 124, 33, 5, 33, 5, 33, 5, 33, 5, 33, 5,
28-
39, 115, 13, 118, 21, 27, 108, 16, 118, 21, 30, 107, 13,
29-
118, 21, 30, 108, 16, 114, 13, 118, 21, 23, 107, 46, 95,
30-
96, 94, 100, 108, 94, 100, 107, 90, 88, 80, 39, 119,
31-
48, 2, 42, 92]
27+
assert indices == [101, 75, 81, 125, 33, 5, 33, 5, 33, 5, 33, 5,
28+
33, 5, 39, 115, 13, 119, 21, 27, 108, 16, 119,
29+
21, 30, 107, 13, 119, 21, 30, 108, 16, 114, 13,
30+
119, 21, 23, 107, 46, 95, 96, 94, 100, 108,
31+
94, 100, 107, 90, 88, 80, 39, 120, 48, 2, 42, 92]
3232

3333
action_fields = fields["valid_actions"].field_list
3434
production_rules = [(x.rule, x.is_global_rule) for x in action_fields]
@@ -146,6 +146,7 @@ def test_reader_can_read_data_with_entity_pre_linking(self):
146146
('unaryop -> ["not"]', True),
147147
('value -> ["\'city_name0\'"]', True),
148148
('value -> ["\'name0\'"]', True),
149+
('value -> ["2.5"]', True),
149150
('value -> ["YEAR(CURDATE())"]', True),
150151
('value -> [boolean]', True),
151152
('value -> [col_ref]', True),

allennlp/tests/semparse/worlds/text2sql_world_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ def test_world_modifies_unconstrained_grammar_correctly(self):
2525
'"RATING"', '"NAME"', '"HOUSE_NUMBER"',
2626
'"FOOD_TYPE"', '"COUNTY"', '"CITY_NAME"']
2727

28+
def test_world_modifies_grammar_with_global_values_for_dataset(self):
    """The world's base grammar should start its ``value`` rule with the dataset's global value."""
    # "2.5" is a global value for the restaurants dataset, so it is
    # expected ahead of the other alternatives of the rule.
    expected_value_rule = ['"2.5"', 'parenval', '"YEAR(CURDATE())"', 'number',
                           'boolean', 'function', 'col_ref', 'string']
    restaurants_world = Text2SqlWorld(self.schema)
    actual_value_rule = restaurants_world.base_grammar_dictionary["value"]
    assert actual_value_rule == expected_value_rule
35+
2836
def test_grammar_from_world_can_parse_statements(self):
2937
world = Text2SqlWorld(self.schema)
3038
sql = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',

0 commit comments

Comments
 (0)