Skip to content

Commit 60e8b58

Browse files
pandukamudithaopenhands-agentbashwara
authored
feat: Add basic support for prompt-toolkit in the CLI (#7709)
Co-authored-by: openhands <[email protected]> Co-authored-by: Bashwara Undupitiya <[email protected]>
1 parent dd03d9a commit 60e8b58

File tree

5 files changed

+308
-50
lines changed

5 files changed

+308
-50
lines changed

openhands/core/cli.py

Lines changed: 98 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import sys
44
from uuid import uuid4
55

6-
from termcolor import colored
6+
from prompt_toolkit import PromptSession, print_formatted_text
7+
from prompt_toolkit.formatted_text import FormattedText
8+
from prompt_toolkit.key_binding import KeyBindings
79

810
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
911
from openhands.core.config import (
@@ -36,24 +38,66 @@
3638
CmdOutputObservation,
3739
FileEditObservation,
3840
)
39-
from openhands.io import read_input, read_task
41+
from openhands.io import read_task
42+
43+
prompt_session = PromptSession()
4044

4145

4246
def display_message(message: str):
43-
print(colored('🤖 ' + message + '\n', 'yellow'))
47+
print_formatted_text(
48+
FormattedText(
49+
[
50+
('ansiyellow', '🤖 '),
51+
('ansiyellow', message),
52+
('', '\n'),
53+
]
54+
)
55+
)
4456

4557

4658
def display_command(command: str):
47-
print('❯ ' + colored(command + '\n', 'green'))
59+
print_formatted_text(
60+
FormattedText(
61+
[
62+
('', '❯ '),
63+
('ansigreen', command),
64+
('', '\n'),
65+
]
66+
)
67+
)
4868

4969

5070
def display_confirmation(confirmation_state: ActionConfirmationStatus):
5171
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
52-
print(colored('✅ ' + confirmation_state + '\n', 'green'))
72+
print_formatted_text(
73+
FormattedText(
74+
[
75+
('ansigreen', '✅ '),
76+
('ansigreen', str(confirmation_state)),
77+
('', '\n'),
78+
]
79+
)
80+
)
5381
elif confirmation_state == ActionConfirmationStatus.REJECTED:
54-
print(colored('❌ ' + confirmation_state + '\n', 'red'))
82+
print_formatted_text(
83+
FormattedText(
84+
[
85+
('ansired', '❌ '),
86+
('ansired', str(confirmation_state)),
87+
('', '\n'),
88+
]
89+
)
90+
)
5591
else:
56-
print(colored('⏳ ' + confirmation_state + '\n', 'yellow'))
92+
print_formatted_text(
93+
FormattedText(
94+
[
95+
('ansiyellow', '⏳ '),
96+
('ansiyellow', str(confirmation_state)),
97+
('', '\n'),
98+
]
99+
)
100+
)
57101

58102

59103
def display_command_output(output: str):
@@ -62,12 +106,19 @@ def display_command_output(output: str):
62106
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
63107
# TODO: clean this up once we clean up terminal output
64108
continue
65-
print(colored(line, 'blue'))
66-
print('\n')
109+
print_formatted_text(FormattedText([('ansiblue', line)]))
110+
print_formatted_text('')
67111

68112

69113
def display_file_edit(event: FileEditAction | FileEditObservation):
70-
print(colored(str(event), 'green'))
114+
print_formatted_text(
115+
FormattedText(
116+
[
117+
('ansigreen', str(event)),
118+
('', '\n'),
119+
]
120+
)
121+
)
71122

72123

73124
def display_event(event: Event, config: AppConfig):
@@ -89,6 +140,41 @@ def display_event(event: Event, config: AppConfig):
89140
display_confirmation(event.confirmation_state)
90141

91142

143+
async def read_prompt_input(multiline=False):
144+
try:
145+
if multiline:
146+
kb = KeyBindings()
147+
148+
@kb.add('c-d')
149+
def _(event):
150+
event.current_buffer.validate_and_handle()
151+
152+
message = await prompt_session.prompt_async(
153+
'Enter your message and press Ctrl+D to finish:\n',
154+
multiline=True,
155+
key_bindings=kb,
156+
)
157+
else:
158+
message = await prompt_session.prompt_async(
159+
'>> ',
160+
)
161+
return message
162+
except KeyboardInterrupt:
163+
return 'exit'
164+
except EOFError:
165+
return 'exit'
166+
167+
168+
async def read_confirmation_input():
169+
try:
170+
confirmation = await prompt_session.prompt_async(
171+
'Confirm action (possible security risk)? (y/n) >> ',
172+
)
173+
return confirmation.lower() == 'y'
174+
except (KeyboardInterrupt, EOFError):
175+
return False
176+
177+
92178
async def main(loop: asyncio.AbstractEventLoop):
93179
"""Runs the agent in CLI mode."""
94180

@@ -122,10 +208,7 @@ async def main(loop: asyncio.AbstractEventLoop):
122208
event_stream = runtime.event_stream
123209

124210
async def prompt_for_next_task():
125-
# Run input() in a thread pool to avoid blocking the event loop
126-
next_message = await loop.run_in_executor(
127-
None, read_input, config.cli_multiline_input
128-
)
211+
next_message = await read_prompt_input(config.cli_multiline_input)
129212
if not next_message.strip():
130213
await prompt_for_next_task()
131214
if next_message == 'exit':
@@ -136,12 +219,6 @@ async def prompt_for_next_task():
136219
action = MessageAction(content=next_message)
137220
event_stream.add_event(action, EventSource.USER)
138221

139-
async def prompt_for_user_confirmation():
140-
user_confirmation = await loop.run_in_executor(
141-
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
142-
)
143-
return user_confirmation.lower() == 'y'
144-
145222
async def on_event_async(event: Event):
146223
display_event(event, config)
147224
if isinstance(event, AgentStateChangedObservation):
@@ -151,7 +228,7 @@ async def on_event_async(event: Event):
151228
]:
152229
await prompt_for_next_task()
153230
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
154-
user_confirmed = await prompt_for_user_confirmation()
231+
user_confirmed = await read_confirmation_input()
155232
if user_confirmed:
156233
event_stream.add_event(
157234
ChangeAgentStateAction(AgentState.USER_CONFIRMED),

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ memory-profiler = "^0.61.0"
7979
daytona-sdk = "0.12.1"
8080
python-json-logger = "^3.2.1"
8181
playwright = "^1.51.0"
82+
prompt-toolkit = "^3.0.50"
8283

8384
[tool.poetry.group.dev.dependencies]
8485
ruff = "0.11.4"

tests/unit/test_cli_basic.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import asyncio
2+
from datetime import datetime
3+
from io import StringIO
4+
from unittest.mock import AsyncMock, Mock, patch
5+
6+
import pytest
7+
from prompt_toolkit.application import create_app_session
8+
from prompt_toolkit.input import create_pipe_input
9+
from prompt_toolkit.output import create_output
10+
11+
from openhands.core.cli import main
12+
from openhands.core.config import AppConfig
13+
from openhands.core.schema import AgentState
14+
from openhands.events.action import MessageAction
15+
from openhands.events.event import EventSource
16+
from openhands.events.observation import AgentStateChangedObservation
17+
18+
19+
class MockEventStream:
20+
def __init__(self):
21+
self._subscribers = {}
22+
self.cur_id = 0
23+
24+
def subscribe(self, subscriber_id, callback, callback_id):
25+
if subscriber_id not in self._subscribers:
26+
self._subscribers[subscriber_id] = {}
27+
self._subscribers[subscriber_id][callback_id] = callback
28+
29+
def unsubscribe(self, subscriber_id, callback_id):
30+
if (
31+
subscriber_id in self._subscribers
32+
and callback_id in self._subscribers[subscriber_id]
33+
):
34+
del self._subscribers[subscriber_id][callback_id]
35+
36+
def add_event(self, event, source):
37+
event._id = self.cur_id
38+
self.cur_id += 1
39+
event._source = source
40+
event._timestamp = datetime.now().isoformat()
41+
42+
for subscriber_id in self._subscribers:
43+
for callback_id, callback in self._subscribers[subscriber_id].items():
44+
callback(event)
45+
46+
47+
@pytest.fixture
48+
def mock_agent():
49+
with patch('openhands.core.cli.create_agent') as mock_create_agent:
50+
mock_agent_instance = AsyncMock()
51+
mock_agent_instance.name = 'test-agent'
52+
mock_agent_instance.llm = AsyncMock()
53+
mock_agent_instance.llm.config = AsyncMock()
54+
mock_agent_instance.llm.config.model = 'test-model'
55+
mock_agent_instance.llm.config.base_url = 'http://test'
56+
mock_agent_instance.llm.config.max_message_chars = 1000
57+
mock_agent_instance.config = AsyncMock()
58+
mock_agent_instance.config.disabled_microagents = []
59+
mock_agent_instance.sandbox_plugins = []
60+
mock_agent_instance.prompt_manager = AsyncMock()
61+
mock_create_agent.return_value = mock_agent_instance
62+
yield mock_agent_instance
63+
64+
65+
@pytest.fixture
66+
def mock_controller():
67+
with patch('openhands.core.cli.create_controller') as mock_create_controller:
68+
mock_controller_instance = AsyncMock()
69+
mock_controller_instance.state.agent_state = None
70+
# Mock run_until_done to finish immediately
71+
mock_controller_instance.run_until_done = AsyncMock(return_value=None)
72+
mock_create_controller.return_value = (mock_controller_instance, None)
73+
yield mock_controller_instance
74+
75+
76+
@pytest.fixture
77+
def mock_config():
78+
with patch('openhands.core.cli.parse_arguments') as mock_parse_args:
79+
args = Mock()
80+
args.file = None
81+
args.task = None
82+
args.directory = None
83+
mock_parse_args.return_value = args
84+
with patch('openhands.core.cli.setup_config_from_args') as mock_setup_config:
85+
mock_config = AppConfig()
86+
mock_config.cli_multiline_input = False
87+
mock_config.security = Mock()
88+
mock_config.security.confirmation_mode = False
89+
mock_config.sandbox = Mock()
90+
mock_config.sandbox.selected_repo = None
91+
mock_setup_config.return_value = mock_config
92+
yield mock_config
93+
94+
95+
@pytest.fixture
96+
def mock_memory():
97+
with patch('openhands.core.cli.create_memory') as mock_create_memory:
98+
mock_memory_instance = AsyncMock()
99+
mock_create_memory.return_value = mock_memory_instance
100+
yield mock_memory_instance
101+
102+
103+
@pytest.fixture
104+
def mock_read_task():
105+
with patch('openhands.core.cli.read_task') as mock_read_task:
106+
mock_read_task.return_value = None
107+
yield mock_read_task
108+
109+
110+
@pytest.fixture
111+
def mock_runtime():
112+
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
113+
mock_runtime_instance = AsyncMock()
114+
115+
mock_event_stream = MockEventStream()
116+
mock_runtime_instance.event_stream = mock_event_stream
117+
118+
mock_runtime_instance.connect = AsyncMock()
119+
120+
# Ensure status_callback is None
121+
mock_runtime_instance.status_callback = None
122+
# Mock get_microagents_from_selected_repo
123+
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
124+
mock_create_runtime.return_value = mock_runtime_instance
125+
yield mock_runtime_instance
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_cli_greeting(
130+
mock_runtime, mock_controller, mock_config, mock_agent, mock_memory, mock_read_task
131+
):
132+
buffer = StringIO()
133+
134+
with create_app_session(
135+
input=create_pipe_input(), output=create_output(stdout=buffer)
136+
):
137+
mock_controller.status_callback = None
138+
139+
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
140+
141+
await asyncio.sleep(0.1)
142+
143+
hello_response = MessageAction(content='Ping')
144+
hello_response._source = EventSource.AGENT
145+
mock_runtime.event_stream.add_event(hello_response, EventSource.AGENT)
146+
147+
state_change = AgentStateChangedObservation(
148+
content='Awaiting user input', agent_state=AgentState.AWAITING_USER_INPUT
149+
)
150+
state_change._source = EventSource.AGENT
151+
mock_runtime.event_stream.add_event(state_change, EventSource.AGENT)
152+
153+
stop_event = AgentStateChangedObservation(
154+
content='Stop', agent_state=AgentState.STOPPED
155+
)
156+
stop_event._source = EventSource.AGENT
157+
mock_runtime.event_stream.add_event(stop_event, EventSource.AGENT)
158+
159+
mock_controller.state.agent_state = AgentState.STOPPED
160+
161+
try:
162+
await asyncio.wait_for(main_task, timeout=1.0)
163+
except asyncio.TimeoutError:
164+
main_task.cancel()
165+
166+
buffer.seek(0)
167+
output = buffer.read()
168+
169+
assert 'Ping' in output

0 commit comments

Comments
 (0)