Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit a4670ad

Browse files
authored
Grammar induction from a python executor (#2281)
* Got a simple LF executor working where all you need is to specify functions * Grammar induction for simple cases now works * Got logical_form_to_action_sequence working for simple cases * Got action_sequence_to_logical_form working for simple cases * Fix (most?) NLVR and WikiTables tests * Renamed Executor to DomainLanguage * Fix world test * Add some documentation, fix pylint * mypy * Added docstrings, fixed docs * Improve documentation, simplify some logic * Removed the old type_declaration code entirely, better handled constants * mypy * PR feedback * mypy... * Actually fix mypy...
1 parent 088f0bb commit a4670ad

27 files changed

+788
-117
lines changed

allennlp/semparse/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,7 @@
2929
# dependency issues. If you want to import semparse stuff from the data code, just use a more
3030
# complete path, like `from allennlp.semparse.worlds import WikiTablesWorld`.
3131
from allennlp.data.tokenizers import Token as _
32-
from allennlp.semparse.worlds.world import ParsingError, World
32+
from allennlp.semparse.domain_languages.domain_language import (DomainLanguage, ParsingError,
33+
ExecutionError, predicate)
34+
from allennlp.semparse.worlds.world import World
3335
from allennlp.semparse.action_space_walker import ActionSpaceWalker
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from allennlp.semparse.domain_languages.domain_language import (DomainLanguage, ParsingError,
2+
ExecutionError, predicate)

allennlp/semparse/domain_languages/domain_language.py

+474
Large diffs are not rendered by default.

allennlp/semparse/worlds/nlvr_box.py renamed to allennlp/semparse/executors/nlvr_box.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import List
22

33
from allennlp.common.util import JsonDict
4-
from allennlp.semparse.worlds.nlvr_object import Object
4+
from allennlp.semparse.executors.nlvr_object import Object
55

66

77
class Box:

allennlp/semparse/executors/nlvr_executor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import logging
55

66
from allennlp.semparse import util as semparse_util
7-
from allennlp.semparse.worlds.world import ExecutionError
8-
from allennlp.semparse.worlds.nlvr_object import Object
9-
from allennlp.semparse.worlds.nlvr_box import Box
7+
from allennlp.semparse.domain_languages.domain_language import ExecutionError
8+
from allennlp.semparse.executors.nlvr_object import Object
9+
from allennlp.semparse.executors.nlvr_box import Box
1010

1111

1212
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -115,7 +115,7 @@ def execute(self, logical_form: str) -> bool:
115115
expression_as_list = semparse_util.lisp_to_nested_expression(logical_form)
116116
# Expression list has an additional level of
117117
# nesting at the top.
118-
result = self._handle_expression(expression_as_list[0])
118+
result = self._handle_expression(expression_as_list)
119119
return result
120120

121121
def _handle_expression(self, expression_list) -> bool:

allennlp/semparse/executors/wikitables_variable_free_executor.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from allennlp.semparse import util as semparse_util
6-
from allennlp.semparse.worlds.world import ExecutionError
6+
from allennlp.semparse.domain_languages.domain_language import ExecutionError
77
from allennlp.semparse.contexts.table_question_knowledge_graph import MONTH_NUMBERS
88
from allennlp.semparse.contexts import TableQuestionContext
99
from allennlp.tools import wikitables_evaluator as evaluator
@@ -90,14 +90,7 @@ def execute(self, logical_form: str) -> Any:
9090
logical_form = f"({logical_form})"
9191
logical_form = logical_form.replace(",", " ")
9292
expression_as_list = semparse_util.lisp_to_nested_expression(logical_form)
93-
# Expression list has an additional level of
94-
# nesting at the top. For example, if the
95-
# logical form is
96-
# "(select all_rows fb:row.row.league)",
97-
# the expression list will be
98-
# [['select', 'all_rows', 'fb:row.row.league']].
99-
# Removing the top most level of nesting.
100-
result = self._handle_expression(expression_as_list[0])
93+
result = self._handle_expression(expression_as_list)
10194
return result
10295

10396
def evaluate_logical_form(self, logical_form: str, target_list: List[str]) -> bool:

allennlp/semparse/type_declarations/nlvr_type_declaration.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,5 @@ def substitute_any_type(self, basic_types: Set[BasicType]) -> List[Type]:
170170
name_mapper.map_name_with_signature(name=num_string, signature=NUM_TYPE, alias=num_string)
171171

172172

173-
COMMON_NAME_MAPPING = name_mapper.common_name_mapping
174-
COMMON_TYPE_SIGNATURE = name_mapper.common_type_signature
173+
COMMON_NAME_MAPPING = name_mapper.name_mapping
174+
COMMON_TYPE_SIGNATURE = name_mapper.type_signatures

allennlp/semparse/type_declarations/type_declaration.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -491,9 +491,9 @@ class NameMapper:
491491
Parameters
492492
----------
493493
language_has_lambda : ``bool`` (optional, default=False)
494-
If your language has lambda functions, the word "lambda" needs to be in the common name
495-
mapping, mapped to the alias "\". NLTK understands this symbol, and it doesn't need a type
496-
signature for it. Setting this flag to True adds the mapping to `common_name_mapping`.
494+
If your language has lambda functions, the word "lambda" needs to be in the name mapping,
495+
mapped to the alias "\". NLTK understands this symbol, and it doesn't need a type signature
496+
for it. Setting this flag to True adds the mapping to `name_mapping`.
497497
alias_prefix : ``str`` (optional, default="F")
498498
The one letter prefix used for all aliases. You do not need to specify it if you have only
499499
instance of this class for you language. If not, you can specify a different prefix for each
@@ -502,10 +502,10 @@ class NameMapper:
502502
def __init__(self,
503503
language_has_lambda: bool = False,
504504
alias_prefix: str = "F") -> None:
505-
self.common_name_mapping: Dict[str, str] = {}
505+
self.name_mapping: Dict[str, str] = {}
506506
if language_has_lambda:
507-
self.common_name_mapping["lambda"] = "\\"
508-
self.common_type_signature: Dict[str, Type] = {}
507+
self.name_mapping["lambda"] = "\\"
508+
self.type_signatures: Dict[str, Type] = {}
509509
assert len(alias_prefix) == 1 and alias_prefix.isalpha(), (f"Invalid alias prefix: {alias_prefix}"
510510
"Needs to be a single upper case character.")
511511
self._alias_prefix = alias_prefix.upper()
@@ -515,26 +515,26 @@ def map_name_with_signature(self,
515515
name: str,
516516
signature: Type,
517517
alias: str = None) -> None:
518-
if name in self.common_name_mapping:
519-
alias = self.common_name_mapping[name]
520-
old_signature = self.common_type_signature[alias]
518+
if name in self.name_mapping:
519+
alias = self.name_mapping[name]
520+
old_signature = self.type_signatures[alias]
521521
if old_signature != signature:
522522
raise RuntimeError(f"{name} already added with signature {old_signature}. "
523523
f"Cannot add it again with {signature}!")
524524
else:
525525
alias = alias or f"{self._alias_prefix}{self._name_counter}"
526526
self._name_counter += 1
527-
self.common_name_mapping[name] = alias
528-
self.common_type_signature[alias] = signature
527+
self.name_mapping[name] = alias
528+
self.type_signatures[alias] = signature
529529

530530
def get_alias(self, name: str) -> str:
531-
if name not in self.common_name_mapping:
531+
if name not in self.name_mapping:
532532
raise RuntimeError(f"Unmapped name: {name}")
533-
return self.common_name_mapping[name]
533+
return self.name_mapping[name]
534534

535535
def get_signature(self, name: str) -> Type:
536536
alias = self.get_alias(name)
537-
return self.common_type_signature[alias]
537+
return self.type_signatures[alias]
538538

539539

540540
def substitute_any_type(type_: Type, basic_types: Set[BasicType]) -> List[Type]:

allennlp/semparse/type_declarations/wikitables_lambda_dcs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -244,5 +244,5 @@ def substitute_any_type(self, basic_types: Set[BasicType]) -> List[Type]:
244244
name_mapper.map_name_with_signature("avg", UNARY_NUM_OP_TYPE)
245245
name_mapper.map_name_with_signature("-", BINARY_NUM_OP_TYPE) # subtraction
246246

247-
COMMON_NAME_MAPPING = name_mapper.common_name_mapping
248-
COMMON_TYPE_SIGNATURE = name_mapper.common_type_signature
247+
COMMON_NAME_MAPPING = name_mapper.name_mapping
248+
COMMON_TYPE_SIGNATURE = name_mapper.type_signatures

allennlp/semparse/type_declarations/wikitables_variable_free.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,13 @@
142142
# <n,<n,<n,d>>>
143143
generic_name_mapper.map_name_with_signature("date", DATE_FUNCTION_TYPE)
144144

145-
COMMON_NAME_MAPPING = generic_name_mapper.common_name_mapping
146-
COMMON_TYPE_SIGNATURE = generic_name_mapper.common_type_signature
147-
STRING_COLUMN_NAME_MAPPING = string_column_name_mapper.common_name_mapping
148-
STRING_COLUMN_TYPE_SIGNATURE = string_column_name_mapper.common_type_signature
149-
NUMBER_COLUMN_NAME_MAPPING = number_column_name_mapper.common_name_mapping
150-
NUMBER_COLUMN_TYPE_SIGNATURE = number_column_name_mapper.common_type_signature
151-
DATE_COLUMN_NAME_MAPPING = date_column_name_mapper.common_name_mapping
152-
DATE_COLUMN_TYPE_SIGNATURE = date_column_name_mapper.common_type_signature
153-
COMPARABLE_COLUMN_NAME_MAPPING = comparable_column_name_mapper.common_name_mapping
154-
COMPARABLE_COLUMN_TYPE_SIGNATURE = comparable_column_name_mapper.common_type_signature
145+
COMMON_NAME_MAPPING = generic_name_mapper.name_mapping
146+
COMMON_TYPE_SIGNATURE = generic_name_mapper.type_signatures
147+
STRING_COLUMN_NAME_MAPPING = string_column_name_mapper.name_mapping
148+
STRING_COLUMN_TYPE_SIGNATURE = string_column_name_mapper.type_signatures
149+
NUMBER_COLUMN_NAME_MAPPING = number_column_name_mapper.name_mapping
150+
NUMBER_COLUMN_TYPE_SIGNATURE = number_column_name_mapper.type_signatures
151+
DATE_COLUMN_NAME_MAPPING = date_column_name_mapper.name_mapping
152+
DATE_COLUMN_TYPE_SIGNATURE = date_column_name_mapper.type_signatures
153+
COMPARABLE_COLUMN_NAME_MAPPING = comparable_column_name_mapper.name_mapping
154+
COMPARABLE_COLUMN_TYPE_SIGNATURE = comparable_column_name_mapper.type_signatures

allennlp/semparse/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def lisp_to_nested_expression(lisp_string: str) -> List:
2020
while token[-1] == ')':
2121
current_expression = stack.pop()
2222
token = token[:-1]
23-
return current_expression
23+
return current_expression[0]

allennlp/semparse/worlds/nlvr_world.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from allennlp.common.util import JsonDict
1212
from allennlp.semparse.type_declarations import nlvr_type_declaration as types
13-
from allennlp.semparse.worlds.nlvr_box import Box
13+
from allennlp.semparse.executors.nlvr_box import Box
1414
from allennlp.semparse.worlds.world import World
1515
from allennlp.semparse.executors import NlvrExecutor
1616

allennlp/semparse/worlds/quarel_world.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def __init__(self,
2727

2828
self._syntax = syntax
2929
self.types = QuarelTypeDeclaration(syntax)
30-
super(QuarelWorld, self).__init__(
31-
global_type_signatures=self.types.name_mapper.common_type_signature,
32-
global_name_mapping=self.types.name_mapper.common_name_mapping)
30+
super().__init__(
31+
global_type_signatures=self.types.name_mapper.type_signatures,
32+
global_name_mapping=self.types.name_mapper.name_mapping)
3333
self.table_graph = table_graph
3434

3535
# Keep map and counter for each entity type encountered (first letter in entity string)
@@ -69,8 +69,8 @@ def _entity_index(self, entity) -> int:
6969
@overrides
7070
def _map_name(self, name: str, keep_mapping: bool = False) -> str:
7171
translated_name = name
72-
if name in self.types.name_mapper.common_name_mapping:
73-
translated_name = self.types.name_mapper.common_name_mapping[name]
72+
if name in self.types.name_mapper.name_mapping:
73+
translated_name = self.types.name_mapper.name_mapping[name]
7474
elif name in self.local_name_mapping:
7575
translated_name = self.local_name_mapping[name]
7676
elif name.startswith("a:"):

allennlp/semparse/worlds/world.py

+2-46
Original file line numberDiff line numberDiff line change
@@ -7,57 +7,13 @@
77
from nltk.sem.logic import ApplicationExpression, Expression, LambdaExpression, BasicType, Type
88

99
from allennlp.semparse.type_declarations import type_declaration as types
10+
from allennlp.semparse.domain_languages.domain_language import ParsingError
11+
from allennlp.semparse.domain_languages.domain_language import nltk_tree_to_logical_form
1012
from allennlp.semparse import util as semparse_util
1113

1214
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
1315

1416

15-
class ParsingError(Exception):
16-
"""
17-
This exception gets raised when there is a parsing error during logical form processing. This
18-
might happen because you're not handling the full set of possible logical forms, for instance,
19-
and having this error provides a consistent way to catch those errors and log how frequently
20-
this occurs.
21-
"""
22-
def __init__(self, message):
23-
super(ParsingError, self).__init__()
24-
self.message = message
25-
26-
def __str__(self):
27-
return repr(self.message)
28-
29-
30-
class ExecutionError(Exception):
31-
"""
32-
This exception gets raised when you're trying to execute a logical form that your executor does
33-
not understand. This may be because your logical form contains a function with an invalid name
34-
or a set of arguments whose types do not match those that the function expects.
35-
"""
36-
def __init__(self, message):
37-
super(ExecutionError, self).__init__()
38-
self.message = message
39-
40-
def __str__(self):
41-
return repr(self.message)
42-
43-
44-
def nltk_tree_to_logical_form(tree: Tree) -> str:
45-
"""
46-
Given an ``nltk.Tree`` representing the syntax tree that generates a logical form, this method
47-
produces the actual (lisp-like) logical form, with all of the non-terminal symbols converted
48-
into the correct number of parentheses.
49-
"""
50-
# nltk.Tree actually inherits from `list`, so you use `len()` to get the number of children.
51-
# We're going to be explicit about checking length, instead of using `if tree:`, just to avoid
52-
# any funny business nltk might have done (e.g., it's really odd if `if tree:` evaluates to
53-
# `False` if there's a single leaf node with no children).
54-
if len(tree) == 0: # pylint: disable=len-as-condition
55-
return tree.label()
56-
if len(tree) == 1:
57-
return tree[0].label()
58-
return '(' + ' '.join(nltk_tree_to_logical_form(child) for child in tree) + ')'
59-
60-
6117
class World:
6218
"""
6319
Base class for defining a world in a new domain. This class defines a method to translate a

allennlp/tests/semparse/domain_languages/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)