Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add basic support for prompt-toolkit in the CLI #7709

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 98 additions & 21 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,9 @@
import sys
from uuid import uuid4

from termcolor import colored
from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.key_binding import KeyBindings

import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.core.config import (
@@ -36,24 +38,66 @@
CmdOutputObservation,
FileEditObservation,
)
from openhands.io import read_input, read_task
from openhands.io import read_task

prompt_session = PromptSession()


def display_message(message: str):
print(colored('🤖 ' + message + '\n', 'yellow'))
print_formatted_text(
FormattedText(
[
('ansiyellow', '🤖 '),
('ansiyellow', message),
('', '\n'),
]
)
)


def display_command(command: str):
print('❯ ' + colored(command + '\n', 'green'))
print_formatted_text(
FormattedText(
[
('', '❯ '),
('ansigreen', command),
('', '\n'),
]
)
)


def display_confirmation(confirmation_state: ActionConfirmationStatus):
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
print(colored('✅ ' + confirmation_state + '\n', 'green'))
print_formatted_text(
FormattedText(
[
('ansigreen', '✅ '),
('ansigreen', str(confirmation_state)),
('', '\n'),
]
)
)
elif confirmation_state == ActionConfirmationStatus.REJECTED:
print(colored('❌ ' + confirmation_state + '\n', 'red'))
print_formatted_text(
FormattedText(
[
('ansired', '❌ '),
('ansired', str(confirmation_state)),
('', '\n'),
]
)
)
else:
print(colored('⏳ ' + confirmation_state + '\n', 'yellow'))
print_formatted_text(
FormattedText(
[
('ansiyellow', '⏳ '),
('ansiyellow', str(confirmation_state)),
('', '\n'),
]
)
)


def display_command_output(output: str):
@@ -62,12 +106,19 @@ def display_command_output(output: str):
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
# TODO: clean this up once we clean up terminal output
continue
print(colored(line, 'blue'))
print('\n')
print_formatted_text(FormattedText([('ansiblue', line)]))
print_formatted_text('')


def display_file_edit(event: FileEditAction | FileEditObservation):
print(colored(str(event), 'green'))
print_formatted_text(
FormattedText(
[
('ansigreen', str(event)),
('', '\n'),
]
)
)


def display_event(event: Event, config: AppConfig):
@@ -89,6 +140,41 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)


async def read_prompt_input(multiline=False):
try:
if multiline:
kb = KeyBindings()

@kb.add('c-d')
def _(event):
event.current_buffer.validate_and_handle()

message = await prompt_session.prompt_async(
'Enter your message and press Ctrl+D to finish:\n',
multiline=True,
key_bindings=kb,
)
else:
message = await prompt_session.prompt_async(
'>> ',
)
return message
except KeyboardInterrupt:
return 'exit'
except EOFError:
return 'exit'


async def read_confirmation_input():
try:
confirmation = await prompt_session.prompt_async(
'Confirm action (possible security risk)? (y/n) >> ',
)
return confirmation.lower() == 'y'
except (KeyboardInterrupt, EOFError):
return False


async def main(loop: asyncio.AbstractEventLoop):
"""Runs the agent in CLI mode."""

@@ -122,10 +208,7 @@ async def main(loop: asyncio.AbstractEventLoop):
event_stream = runtime.event_stream

async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
next_message = await loop.run_in_executor(
None, read_input, config.cli_multiline_input
)
next_message = await read_prompt_input(config.cli_multiline_input)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
@@ -136,12 +219,6 @@ async def prompt_for_next_task():
action = MessageAction(content=next_message)
event_stream.add_event(action, EventSource.USER)

async def prompt_for_user_confirmation():
user_confirmation = await loop.run_in_executor(
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
)
return user_confirmation.lower() == 'y'

async def on_event_async(event: Event):
display_event(event, config)
if isinstance(event, AgentStateChangedObservation):
@@ -151,7 +228,7 @@ async def on_event_async(event: Event):
]:
await prompt_for_next_task()
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
user_confirmed = await prompt_for_user_confirmation()
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -79,6 +79,7 @@ memory-profiler = "^0.61.0"
daytona-sdk = "0.12.1"
python-json-logger = "^3.2.1"
playwright = "^1.51.0"
prompt-toolkit = "^3.0.50"

[tool.poetry.group.dev.dependencies]
ruff = "0.11.4"
169 changes: 169 additions & 0 deletions tests/unit/test_cli_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import asyncio
from datetime import datetime
from io import StringIO
from unittest.mock import AsyncMock, Mock, patch

import pytest
from prompt_toolkit.application import create_app_session
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.output import create_output

from openhands.core.cli import main
from openhands.core.config import AppConfig
from openhands.core.schema import AgentState
from openhands.events.action import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation import AgentStateChangedObservation


class MockEventStream:
def __init__(self):
self._subscribers = {}
self.cur_id = 0

def subscribe(self, subscriber_id, callback, callback_id):
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._subscribers[subscriber_id][callback_id] = callback

def unsubscribe(self, subscriber_id, callback_id):
if (
subscriber_id in self._subscribers
and callback_id in self._subscribers[subscriber_id]
):
del self._subscribers[subscriber_id][callback_id]

def add_event(self, event, source):
event._id = self.cur_id
self.cur_id += 1
event._source = source
event._timestamp = datetime.now().isoformat()

for subscriber_id in self._subscribers:
for callback_id, callback in self._subscribers[subscriber_id].items():
callback(event)


@pytest.fixture
def mock_agent():
with patch('openhands.core.cli.create_agent') as mock_create_agent:
mock_agent_instance = AsyncMock()
mock_agent_instance.name = 'test-agent'
mock_agent_instance.llm = AsyncMock()
mock_agent_instance.llm.config = AsyncMock()
mock_agent_instance.llm.config.model = 'test-model'
mock_agent_instance.llm.config.base_url = 'http://test'
mock_agent_instance.llm.config.max_message_chars = 1000
mock_agent_instance.config = AsyncMock()
mock_agent_instance.config.disabled_microagents = []
mock_agent_instance.sandbox_plugins = []
mock_agent_instance.prompt_manager = AsyncMock()
mock_create_agent.return_value = mock_agent_instance
yield mock_agent_instance


@pytest.fixture
def mock_controller():
with patch('openhands.core.cli.create_controller') as mock_create_controller:
mock_controller_instance = AsyncMock()
mock_controller_instance.state.agent_state = None
# Mock run_until_done to finish immediately
mock_controller_instance.run_until_done = AsyncMock(return_value=None)
mock_create_controller.return_value = (mock_controller_instance, None)
yield mock_controller_instance


@pytest.fixture
def mock_config():
with patch('openhands.core.cli.parse_arguments') as mock_parse_args:
args = Mock()
args.file = None
args.task = None
args.directory = None
mock_parse_args.return_value = args
with patch('openhands.core.cli.setup_config_from_args') as mock_setup_config:
mock_config = AppConfig()
mock_config.cli_multiline_input = False
mock_config.security = Mock()
mock_config.security.confirmation_mode = False
mock_config.sandbox = Mock()
mock_config.sandbox.selected_repo = None
mock_setup_config.return_value = mock_config
yield mock_config


@pytest.fixture
def mock_memory():
with patch('openhands.core.cli.create_memory') as mock_create_memory:
mock_memory_instance = AsyncMock()
mock_create_memory.return_value = mock_memory_instance
yield mock_memory_instance


@pytest.fixture
def mock_read_task():
with patch('openhands.core.cli.read_task') as mock_read_task:
mock_read_task.return_value = None
yield mock_read_task


@pytest.fixture
def mock_runtime():
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
mock_runtime_instance = AsyncMock()

mock_event_stream = MockEventStream()
mock_runtime_instance.event_stream = mock_event_stream

mock_runtime_instance.connect = AsyncMock()

# Ensure status_callback is None
mock_runtime_instance.status_callback = None
# Mock get_microagents_from_selected_repo
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
mock_create_runtime.return_value = mock_runtime_instance
yield mock_runtime_instance


@pytest.mark.asyncio
async def test_cli_greeting(
mock_runtime, mock_controller, mock_config, mock_agent, mock_memory, mock_read_task
):
buffer = StringIO()

with create_app_session(
input=create_pipe_input(), output=create_output(stdout=buffer)
):
mock_controller.status_callback = None

main_task = asyncio.create_task(main(asyncio.get_event_loop()))

await asyncio.sleep(0.1)

hello_response = MessageAction(content='Ping')
hello_response._source = EventSource.AGENT
mock_runtime.event_stream.add_event(hello_response, EventSource.AGENT)

state_change = AgentStateChangedObservation(
content='Awaiting user input', agent_state=AgentState.AWAITING_USER_INPUT
)
state_change._source = EventSource.AGENT
mock_runtime.event_stream.add_event(state_change, EventSource.AGENT)

stop_event = AgentStateChangedObservation(
content='Stop', agent_state=AgentState.STOPPED
)
stop_event._source = EventSource.AGENT
mock_runtime.event_stream.add_event(stop_event, EventSource.AGENT)

mock_controller.state.agent_state = AgentState.STOPPED

try:
await asyncio.wait_for(main_task, timeout=1.0)
except asyncio.TimeoutError:
main_task.cancel()

buffer.seek(0)
output = buffer.read()

assert 'Ping' in output
61 changes: 36 additions & 25 deletions tests/unit/test_cli_sid.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import asyncio
from argparse import Namespace
from io import StringIO
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import pytest
from prompt_toolkit.application import create_app_session
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.output import create_output

from openhands.core.cli import main
from openhands.core.config import AppConfig
@@ -83,34 +87,41 @@ def mock_config(task_file: Path):

@pytest.mark.asyncio
async def test_cli_session_id_output(
mock_runtime, mock_agent, mock_controller, mock_config, capsys
mock_runtime, mock_agent, mock_controller, mock_config
):
# status_callback is set when initializing the runtime
mock_controller.status_callback = None

buffer = StringIO()

# Use input patch just for the exit command
with patch('builtins.input', return_value='exit'):
# Create a task for main
main_task = asyncio.create_task(main(asyncio.get_event_loop()))

# Give it a moment to display the session ID
await asyncio.sleep(0.1)

# Trigger agent state change to STOPPED to end the main loop
event = AgentStateChangedObservation(
content='Stop', agent_state=AgentState.STOPPED
)
event._source = EventSource.AGENT
await mock_runtime.event_stream.add_event(event)

# Wait for main to finish with a timeout
try:
await asyncio.wait_for(main_task, timeout=1.0)
except asyncio.TimeoutError:
main_task.cancel()

# Check the output
captured = capsys.readouterr()
assert 'Session ID:' in captured.out
# Also verify that our task message was processed
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)
with create_app_session(
input=create_pipe_input(), output=create_output(stdout=buffer)
):
# Create a task for main
main_task = asyncio.create_task(main(asyncio.get_event_loop()))

# Give it a moment to display the session ID
await asyncio.sleep(0.1)

# Trigger agent state change to STOPPED to end the main loop
event = AgentStateChangedObservation(
content='Stop', agent_state=AgentState.STOPPED
)
event._source = EventSource.AGENT
await mock_runtime.event_stream.add_event(event)

# Wait for main to finish with a timeout
try:
await asyncio.wait_for(main_task, timeout=1.0)
except asyncio.TimeoutError:
main_task.cancel()

buffer.seek(0)
output = buffer.read()

# Check the output
assert 'Session ID:' in output
# Also verify that our task message was processed
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)