|
| 1 | +import sys |
1 | 2 | import time
|
2 | 3 | import argparse
|
3 | 4 |
|
4 | 5 | from scienceworld import ScienceWorldEnv
|
5 | 6 |
|
6 | 7 |
|
| 8 | +prompt_toolkit_available = False |
| 9 | +try: |
| 10 | + # For command line history and autocompletion. |
| 11 | + from prompt_toolkit import prompt |
| 12 | + from prompt_toolkit.completion import WordCompleter |
| 13 | + from prompt_toolkit.history import InMemoryHistory |
| 14 | + prompt_toolkit_available = sys.stdout.isatty() |
| 15 | +except ImportError: |
| 16 | + pass |
| 17 | + |
| 18 | +try: |
| 19 | + # For command line history when prompt_toolkit is not available. |
| 20 | + import readline # noqa: F401 |
| 21 | +except ImportError: |
| 22 | + pass |
| 23 | + |
| 24 | + |
7 | 25 | def userConsole(args):
|
8 | 26 | """ Example user input console, to play through a game. """
|
| 27 | + history = None |
| 28 | + if prompt_toolkit_available: |
| 29 | + history = InMemoryHistory() |
| 30 | + |
9 | 31 | exitCommands = ["quit", "exit"]
|
10 | 32 |
|
11 | 33 | taskIdx = args['task_num']
|
@@ -98,13 +120,22 @@ def userConsole(args):
|
98 | 120 | print("isCompleted: " + str(isCompleted))
|
99 | 121 | # print("info: " + str(info))
|
100 | 122 |
|
101 |
| - print("'help' lists valid action templates, 'objects' lists valid" + |
102 |
| - " objects, 'valid' lists valid action-object combinations (long!). ") |
| 123 | + print("'help' lists valid action templates, 'objects' lists valid objects, use <tab> to list valid actions. ") |
103 | 124 | print("'goals' lists progress on subgoals.")
|
104 | 125 | print("type 'exit' to quit.")
|
105 | 126 |
|
| 127 | + # Select a random action |
| 128 | + valid_actions = env.get_valid_action_object_combinations() |
| 129 | + |
106 | 130 | # Get user input
|
107 |
| - userInputStr = input('> ') |
| 131 | + if prompt_toolkit_available: |
| 132 | + actions_completer = WordCompleter(valid_actions, ignore_case=True, sentence=True) |
| 133 | + userInputStr = prompt('> ', completer=actions_completer, |
| 134 | + history=history, enable_history_search=True) |
| 135 | + else: |
| 136 | + print("Valid Actions: " + str(valid_actions)) |
| 137 | + userInputStr = input('> ') |
| 138 | + |
108 | 139 | # Sanitize input
|
109 | 140 | userInputStr = userInputStr.lower().strip()
|
110 | 141 |
|
|
0 commit comments