
Commit 4c99f8e

Text2sql reader (#1738)
- Moves all semantic parsing dataset readers into their own folder.
- Adds a dataset reader for the text2sql baseline which can read any of the 8 datasets. I also refactored the SQL utils a bit to read from my new directory format, for which I added a script in the previous PR. This includes functionality to de-duplicate the questions in a given dataset, not just the SQL.

This PR looks massive, but I only added `template_text2sql.py` and modified `text2sql_utils.py`; all the rest is just moving folders around and adding deprecation warnings.
1 parent 8867f2f

21 files changed: +1082 -892 lines
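The headline functional change is question de-duplication in the text2sql utilities. As a rough sketch of how the refactored `process_sql_data` might be driven (the dataset path and the structure of the loaded JSON are assumptions based on the `text2sql_utils.py` diff below, not part of this commit):

```python
import json

from allennlp.data.dataset_readers.dataset_utils.text2sql_utils import process_sql_data

# Hypothetical path: any of the 8 datasets reformatted by
# scripts/reformat_text2sql_data.py (added in the previous PR) is assumed to
# load as a list of examples with "sentences", "sql" and "variables" keys.
with open("data/text2sql/geography.json") as data_file:
    data = json.load(data_file)

# use_all_queries=False (the default) drops duplicate questions for a given
# SQL query; use_all_sql=False keeps only the first of the semantically
# equivalent SQL queries listed for each sentence.
for sql_data in process_sql_data(data, use_all_sql=False, use_all_queries=False):
    print(sql_data.text, sql_data.variable_tags, sql_data.sql)
```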

allennlp/data/dataset_readers/__init__.py (+2 -3)
@@ -7,15 +7,13 @@
 """
 
 # pylint: disable=line-too-long
-from allennlp.data.dataset_readers.atis import AtisDatasetReader
 from allennlp.data.dataset_readers.ccgbank import CcgBankDatasetReader
 from allennlp.data.dataset_readers.conll2003 import Conll2003DatasetReader
 from allennlp.data.dataset_readers.conll2000 import Conll2000DatasetReader
 from allennlp.data.dataset_readers.ontonotes_ner import OntonotesNamedEntityRecognition
 from allennlp.data.dataset_readers.coreference_resolution import ConllCorefReader, WinobiasReader
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
 from allennlp.data.dataset_readers.language_modeling import LanguageModelingReader
-from allennlp.data.dataset_readers.nlvr import NlvrDatasetReader
 from allennlp.data.dataset_readers.penn_tree_bank import PennTreeBankConstituencySpanDatasetReader
 from allennlp.data.dataset_readers.reading_comprehension import SquadReader, TriviaQaReader, QuACReader
 from allennlp.data.dataset_readers.semantic_role_labeling import SrlReader
@@ -25,5 +23,6 @@
 from allennlp.data.dataset_readers.universal_dependencies import UniversalDependenciesDatasetReader
 from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import (
         StanfordSentimentTreeBankDatasetReader)
-from allennlp.data.dataset_readers.wikitables import WikiTablesDatasetReader
 from allennlp.data.dataset_readers.quora_paraphrase import QuoraParaphraseDatasetReader
+from allennlp.data.dataset_readers.semantic_parsing import (
+        WikiTablesDatasetReader, AtisDatasetReader, NlvrDatasetReader, TemplateText2SqlDatasetReader)
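The package-root import surface is unchanged by the move, so downstream code that imports the readers from `allennlp.data.dataset_readers` should keep working. A minimal sketch:

```python
# These imports now resolve through the new semantic_parsing subpackage,
# as wired up in the __init__.py diff above.
from allennlp.data.dataset_readers import (
    AtisDatasetReader,
    NlvrDatasetReader,
    TemplateText2SqlDatasetReader,
    WikiTablesDatasetReader,
)
```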

allennlp/data/dataset_readers/atis.py (+5 -156)
@@ -1,157 +1,6 @@
-import json
-from typing import Dict, List
-import logging
+# pylint: disable=unused-import
+import warnings
+from allennlp.data.dataset_readers.semantic_parsing.atis import AtisDatasetReader
 
-from overrides import overrides
-from parsimonious.exceptions import ParseError
-
-from allennlp.common.file_utils import cached_path
-from allennlp.data.dataset_readers.dataset_reader import DatasetReader
-from allennlp.data.fields import Field, ArrayField, ListField, IndexField, \
-        ProductionRuleField, TextField, MetadataField
-from allennlp.data.instance import Instance
-from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
-from allennlp.data.tokenizers import Tokenizer, WordTokenizer
-from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
-
-from allennlp.semparse.worlds.atis_world import AtisWorld
-
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
-
-def _lazy_parse(text: str):
-    for interaction in text.split("\n"):
-        if interaction:
-            yield json.loads(interaction)
-
-@DatasetReader.register("atis")
-class AtisDatasetReader(DatasetReader):
-    # pylint: disable=line-too-long
-    """
-    This ``DatasetReader`` takes json files and converts them into ``Instances`` for the
-    ``AtisSemanticParser``.
-
-    Each line in the file is a JSON object that represent an interaction in the ATIS dataset
-    that has the following keys and values:
-    ```
-    "id": The original filepath in the LDC corpus
-    "interaction": <list where each element represents a turn in the interaction>
-    "scenario": A code that refers to the scenario that served as the prompt for this interaction
-    "ut_date": Date of the interaction
-    "zc09_path": Path that was used in the original paper `Learning Context-Dependent Mappings from
-    Sentences to Logical Form
-    <https://www.semanticscholar.org/paper/Learning-Context-Dependent-Mappings-from-Sentences-Zettlemoyer-Collins/44a8fcee0741139fa15862dc4b6ce1e11444878f>'_ by Zettlemoyer and Collins (ACL/IJCNLP 2009)
-    ```
-
-    Each element in the ``interaction`` list has the following keys and values:
-    ```
-    "utterance": Natural language input
-    "sql": A list of SQL queries that the utterance maps to, it could be multiple SQL queries
-    or none at all.
-    ```
-
-    Parameters
-    ----------
-    token_indexers : ``Dict[str, TokenIndexer]``, optional
-        Token indexers for the utterances. Will default to ``{"tokens": SingleIdTokenIndexer()}``.
-    lazy : ``bool`` (optional, default=False)
-        Passed to ``DatasetReader``. If this is ``True``, training will start sooner, but will
-        take longer per batch.
-    tokenizer : ``Tokenizer``, optional
-        Tokenizer to use for the utterances. Will default to ``WordTokenizer()`` with Spacy's tagger
-        enabled.
-    database_directory : ``str``, optional
-        The directory to find the sqlite database file. We query the sqlite database to find the strings
-        that are allowed.
-    """
-    def __init__(self,
-                 token_indexers: Dict[str, TokenIndexer] = None,
-                 lazy: bool = False,
-                 tokenizer: Tokenizer = None,
-                 database_directory: str = None) -> None:
-        super().__init__(lazy)
-        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
-        self._tokenizer = tokenizer or WordTokenizer(SpacyWordSplitter(pos_tags=True))
-        self._database_directory = database_directory
-
-    @overrides
-    def _read(self, file_path: str):
-        # if `file_path` is a URL, redirect to the cache
-        file_path = cached_path(file_path)
-
-        with open(file_path) as atis_file:
-            logger.info("Reading ATIS instances from dataset at : %s", file_path)
-            for line in _lazy_parse(atis_file.read()):
-                utterances = []
-                for current_interaction in line['interaction']:
-                    if not current_interaction['utterance']:
-                        continue
-                    utterances.append(current_interaction['utterance'])
-                    instance = self.text_to_instance(utterances, current_interaction['sql'])
-                    if not instance:
-                        continue
-                    yield instance
-
-    @overrides
-    def text_to_instance(self,  # type: ignore
-                         utterances: List[str],
-                         sql_query: str = None) -> Instance:
-        # pylint: disable=arguments-differ
-        """
-        Parameters
-        ----------
-        utterances: ``List[str]``, required.
-            List of utterances in the interaction, the last element is the current utterance.
-        sql_query: ``str``, optional
-            The SQL query, given as label during training or validation.
-        """
-        utterance = utterances[-1]
-        action_sequence: List[str] = []
-
-        if not utterance:
-            return None
-
-        world = AtisWorld(utterances=utterances,
-                          database_directory=self._database_directory)
-
-        if sql_query:
-            try:
-                action_sequence = world.get_action_sequence(sql_query)
-            except ParseError:
-                logger.debug(f'Parsing error')
-
-        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
-        utterance_field = TextField(tokenized_utterance, self._token_indexers)
-
-        production_rule_fields: List[Field] = []
-
-        for production_rule in world.all_possible_actions():
-            lhs, _ = production_rule.split(' ->')
-            is_global_rule = not lhs in ['number', 'string']
-            # The whitespaces are not semantically meaningful, so we filter them out.
-            production_rule = ' '.join([token for token in production_rule.split(' ') if token != 'ws'])
-            field = ProductionRuleField(production_rule, is_global_rule)
-            production_rule_fields.append(field)
-
-        action_field = ListField(production_rule_fields)
-        action_map = {action.rule: i  # type: ignore
-                      for i, action in enumerate(action_field.field_list)}
-        index_fields: List[Field] = []
-        world_field = MetadataField(world)
-        fields = {'utterance' : utterance_field,
-                  'actions' : action_field,
-                  'world' : world_field,
-                  'linking_scores' : ArrayField(world.linking_scores)}
-
-        if sql_query:
-            if action_sequence:
-                for production_rule in action_sequence:
-                    index_fields.append(IndexField(action_map[production_rule], action_field))
-
-                action_sequence_field: List[Field] = []
-                action_sequence_field.append(ListField(index_fields))
-                fields['target_action_sequence'] = ListField(action_sequence_field)
-            else:
-                # If we are given a SQL query, but we are unable to parse it, then we will skip it.
-                return None
-
-        return Instance(fields)
+
+warnings.warn("allennlp.data.dataset_readers.atis.* has been moved."
+              "Please use allennlp.data.dataset_reader.semantic_parsing.atis.*", FutureWarning)

allennlp/data/dataset_readers/dataset_utils/text2sql_utils.py (+49 -30)
@@ -3,7 +3,7 @@
 Utility functions for reading the standardised text2sql datasets presented in
 `"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
 """
-from typing import List, Dict, NamedTuple, Iterable
+from typing import List, Dict, NamedTuple, Iterable, Tuple, Set
 
 from allennlp.common import JsonDict
 
@@ -19,6 +19,10 @@ class SqlData(NamedTuple):
     text_with_variables : ``List[str]``
         The tokens in the text of the query with variables
        mapped to table names/abstract variables.
+    variable_tags : ``List[str]``
+        Labels for each word in ``text`` which correspond to
+        which variable in the sql the token is linked to. "O"
+        is used to denote no tag.
     sql : ``List[str]``
         The tokens in the SQL query which corresponds to the text.
     text_variables : ``Dict[str, str]``
@@ -28,24 +32,28 @@ class SqlData(NamedTuple):
     """
     text: List[str]
     text_with_variables: List[str]
+    variable_tags: List[str]
     sql: List[str]
     text_variables: Dict[str, str]
     sql_variables: Dict[str, str]
 
 
 def replace_variables(sentence: List[str],
-                      sentence_variables: Dict[str, str]) -> List[str]:
+                      sentence_variables: Dict[str, str]) -> Tuple[List[str], List[str]]:
     """
     Replaces abstract variables in text with their concrete counterparts.
     """
     tokens = []
+    tags = []
     for token in sentence:
         if token not in sentence_variables:
            tokens.append(token)
+            tags.append("O")
         else:
             for word in sentence_variables[token].split():
                 tokens.append(word)
-    return tokens
+                tags.append(token)
+    return tokens, tags
 
 def clean_and_split_sql(sql: str) -> List[str]:
     """
@@ -63,10 +71,11 @@ def clean_and_split_sql(sql: str) -> List[str]:
     return sql_tokens
 
 
-def process_sql_data_blob(data: JsonDict,
-                          use_all_sql: bool = False) -> Iterable[SqlData]:
+def process_sql_data(data: List[JsonDict],
+                     use_all_sql: bool = False,
+                     use_all_queries: bool = False) -> Iterable[SqlData]:
     """
-    A utility function for reading in text2sql data blobs. The blob is
+    A utility function for reading in text2sql data. The blob is
     the result of loading the json from a file produced by the script
     ``scripts/reformat_text2sql_data.py``.
 
@@ -76,32 +85,42 @@ def process_sql_data_blob(data: JsonDict,
     use_all_sql : ``bool``, optional (default = False)
         Whether to use all of the sql queries which have identical semantics,
         or whether to just use the first one.
+    use_all_queries : ``bool``, (default = False)
+        Whether or not to enforce query sentence uniqueness. If false,
+        duplicated queries will occur in the dataset as separate instances,
+        as for a given SQL query, not only are there multiple queries with
+        the same template, but there are also duplicate queries.
     """
-    # TODO(Mark): currently this does not filter for duplicate _sentences_
-    # which have the same sql query. Really it should, because these instances
-    # are literally identical, so just magnify errors etc. However, doing this
-    # would make it really hard to compare to previous work. Sad times.
-    for sent_info in data['sentences']:
-        # Loop over the different sql statements with "equivalent" semantics
-        for sql in data["sql"]:
-            sql_variables = {}
-            for variable in data['variables']:
-                sql_variables[variable['name']] = variable['example']
+    for example in data:
+        seen_sentences: Set[str] = set()
+        for sent_info in example['sentences']:
+            # Loop over the different sql statements with "equivalent" semantics
+            for sql in example["sql"]:
+                text_with_variables = sent_info['text'].strip().split()
+                text_vars = sent_info['variables']
 
-            text_with_variables = sent_info['text'].strip().split()
-            text_vars = sent_info['variables']
+                query_tokens, tags = replace_variables(text_with_variables, text_vars)
+                if not use_all_queries:
+                    key = " ".join(query_tokens)
+                    if key in seen_sentences:
+                        continue
+                    else:
+                        seen_sentences.add(key)
 
-            query_tokens = replace_variables(text_with_variables, text_vars)
-            sql_tokens = clean_and_split_sql(sql)
+                sql_tokens = clean_and_split_sql(sql)
+                sql_variables = {}
+                for variable in example['variables']:
+                    sql_variables[variable['name']] = variable['example']
 
-            sql_data = SqlData(text=query_tokens,
-                               text_with_variables=text_with_variables,
-                               sql=sql_tokens,
-                               text_variables=text_vars,
-                               sql_variables=sql_variables)
-            yield sql_data
+                sql_data = SqlData(text=query_tokens,
+                                   text_with_variables=text_with_variables,
+                                   variable_tags=tags,
+                                   sql=sql_tokens,
+                                   text_variables=text_vars,
+                                   sql_variables=sql_variables)
+                yield sql_data
 
-        # Some questions might have multiple equivalent SQL statements.
-        # By default, we just use the first one. TODO(Mark): Use the shortest?
-        if not use_all_sql:
-            break
+            # Some questions might have multiple equivalent SQL statements.
+            # By default, we just use the first one. TODO(Mark): Use the shortest?
+            if not use_all_sql:
+                break
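The new `variable_tags` come from `replace_variables`, which now returns one tag per output token alongside the de-abstracted text. A small illustration of the behaviour in the diff above (the sentence and the variable name `city_name0` are made up for this example, not taken from the commit):

```python
from allennlp.data.dataset_readers.dataset_utils.text2sql_utils import replace_variables

tokens, tags = replace_variables(
    ["how", "many", "flights", "go", "to", "city_name0"],
    {"city_name0": "new york"},
)
# tokens == ["how", "many", "flights", "go", "to", "new", "york"]
# tags   == ["O", "O", "O", "O", "O", "city_name0", "city_name0"]
```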
