Skip to content

Commit cd64d33

Browse files
pandukamudithabashwara
authored andcommitted
Test cases for CLI
- Add basic test case for CLI - Update existing test case for SID check Co-authored-by: Bashwara Undupitiya <[email protected]>
1 parent 657fade commit cd64d33

File tree

2 files changed

+205
-25
lines changed

2 files changed

+205
-25
lines changed

tests/unit/test_cli_basic.py

+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import asyncio
2+
from datetime import datetime
3+
from io import StringIO
4+
from unittest.mock import AsyncMock, Mock, patch
5+
6+
import pytest
7+
from prompt_toolkit.application import create_app_session
8+
from prompt_toolkit.input import create_pipe_input
9+
from prompt_toolkit.output import create_output
10+
11+
from openhands.core.cli import main
12+
from openhands.core.config import AppConfig
13+
from openhands.core.schema import AgentState
14+
from openhands.events.action import MessageAction
15+
from openhands.events.event import EventSource
16+
from openhands.events.observation import AgentStateChangedObservation
17+
18+
19+
class MockEventStream:
20+
def __init__(self):
21+
self._subscribers = {}
22+
self.cur_id = 0
23+
24+
def subscribe(self, subscriber_id, callback, callback_id):
25+
if subscriber_id not in self._subscribers:
26+
self._subscribers[subscriber_id] = {}
27+
self._subscribers[subscriber_id][callback_id] = callback
28+
29+
def unsubscribe(self, subscriber_id, callback_id):
30+
if (
31+
subscriber_id in self._subscribers
32+
and callback_id in self._subscribers[subscriber_id]
33+
):
34+
del self._subscribers[subscriber_id][callback_id]
35+
36+
def add_event(self, event, source):
37+
event._id = self.cur_id
38+
self.cur_id += 1
39+
event._source = source
40+
event._timestamp = datetime.now().isoformat()
41+
42+
for subscriber_id in self._subscribers:
43+
for callback_id, callback in self._subscribers[subscriber_id].items():
44+
callback(event)
45+
46+
47+
@pytest.fixture
48+
def mock_agent():
49+
with patch('openhands.core.cli.create_agent') as mock_create_agent:
50+
mock_agent_instance = AsyncMock()
51+
mock_agent_instance.name = 'test-agent'
52+
mock_agent_instance.llm = AsyncMock()
53+
mock_agent_instance.llm.config = AsyncMock()
54+
mock_agent_instance.llm.config.model = 'test-model'
55+
mock_agent_instance.llm.config.base_url = 'http://test'
56+
mock_agent_instance.llm.config.max_message_chars = 1000
57+
mock_agent_instance.config = AsyncMock()
58+
mock_agent_instance.config.disabled_microagents = []
59+
mock_agent_instance.sandbox_plugins = []
60+
mock_agent_instance.prompt_manager = AsyncMock()
61+
mock_create_agent.return_value = mock_agent_instance
62+
yield mock_agent_instance
63+
64+
65+
@pytest.fixture
66+
def mock_controller():
67+
with patch('openhands.core.cli.create_controller') as mock_create_controller:
68+
mock_controller_instance = AsyncMock()
69+
mock_controller_instance.state.agent_state = None
70+
# Mock run_until_done to finish immediately
71+
mock_controller_instance.run_until_done = AsyncMock(return_value=None)
72+
mock_create_controller.return_value = (mock_controller_instance, None)
73+
yield mock_controller_instance
74+
75+
76+
@pytest.fixture
77+
def mock_config():
78+
with patch('openhands.core.cli.parse_arguments') as mock_parse_args:
79+
args = Mock()
80+
args.file = None
81+
args.task = None
82+
args.directory = None
83+
mock_parse_args.return_value = args
84+
with patch('openhands.core.cli.setup_config_from_args') as mock_setup_config:
85+
mock_config = AppConfig()
86+
mock_config.cli_multiline_input = False
87+
mock_config.security = Mock()
88+
mock_config.security.confirmation_mode = False
89+
mock_config.sandbox = Mock()
90+
mock_config.sandbox.selected_repo = None
91+
mock_setup_config.return_value = mock_config
92+
yield mock_config
93+
94+
95+
@pytest.fixture
96+
def mock_memory():
97+
with patch('openhands.core.cli.create_memory') as mock_create_memory:
98+
mock_memory_instance = AsyncMock()
99+
mock_create_memory.return_value = mock_memory_instance
100+
yield mock_memory_instance
101+
102+
103+
@pytest.fixture
104+
def mock_read_task():
105+
with patch('openhands.core.cli.read_task') as mock_read_task:
106+
mock_read_task.return_value = None
107+
yield mock_read_task
108+
109+
110+
@pytest.fixture
111+
def mock_runtime():
112+
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
113+
mock_runtime_instance = AsyncMock()
114+
115+
mock_event_stream = MockEventStream()
116+
mock_runtime_instance.event_stream = mock_event_stream
117+
118+
mock_runtime_instance.connect = AsyncMock()
119+
120+
# Ensure status_callback is None
121+
mock_runtime_instance.status_callback = None
122+
# Mock get_microagents_from_selected_repo
123+
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
124+
mock_create_runtime.return_value = mock_runtime_instance
125+
yield mock_runtime_instance
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_cli_greeting(
130+
mock_runtime, mock_controller, mock_config, mock_agent, mock_memory, mock_read_task
131+
):
132+
buffer = StringIO()
133+
134+
with create_app_session(
135+
input=create_pipe_input(), output=create_output(stdout=buffer)
136+
):
137+
mock_controller.status_callback = None
138+
139+
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
140+
141+
await asyncio.sleep(0.1)
142+
143+
hello_response = MessageAction(content='Ping')
144+
hello_response._source = EventSource.AGENT
145+
mock_runtime.event_stream.add_event(hello_response, EventSource.AGENT)
146+
147+
state_change = AgentStateChangedObservation(
148+
content='Awaiting user input', agent_state=AgentState.AWAITING_USER_INPUT
149+
)
150+
state_change._source = EventSource.AGENT
151+
mock_runtime.event_stream.add_event(state_change, EventSource.AGENT)
152+
153+
stop_event = AgentStateChangedObservation(
154+
content='Stop', agent_state=AgentState.STOPPED
155+
)
156+
stop_event._source = EventSource.AGENT
157+
mock_runtime.event_stream.add_event(stop_event, EventSource.AGENT)
158+
159+
mock_controller.state.agent_state = AgentState.STOPPED
160+
161+
try:
162+
await asyncio.wait_for(main_task, timeout=1.0)
163+
except asyncio.TimeoutError:
164+
main_task.cancel()
165+
166+
buffer.seek(0)
167+
output = buffer.read()
168+
169+
assert 'Ping' in output

tests/unit/test_cli_sid.py

+36-25
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import asyncio
22
from argparse import Namespace
3+
from io import StringIO
34
from pathlib import Path
45
from unittest.mock import AsyncMock, Mock, patch
56

67
import pytest
8+
from prompt_toolkit.application import create_app_session
9+
from prompt_toolkit.input import create_pipe_input
10+
from prompt_toolkit.output import create_output
711

812
from openhands.core.cli import main
913
from openhands.core.config import AppConfig
@@ -83,34 +87,41 @@ def mock_config(task_file: Path):
8387

8488
@pytest.mark.asyncio
8589
async def test_cli_session_id_output(
86-
mock_runtime, mock_agent, mock_controller, mock_config, capsys
90+
mock_runtime, mock_agent, mock_controller, mock_config
8791
):
8892
# status_callback is set when initializing the runtime
8993
mock_controller.status_callback = None
9094

95+
buffer = StringIO()
96+
9197
# Use input patch just for the exit command
9298
with patch('builtins.input', return_value='exit'):
93-
# Create a task for main
94-
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
95-
96-
# Give it a moment to display the session ID
97-
await asyncio.sleep(0.1)
98-
99-
# Trigger agent state change to STOPPED to end the main loop
100-
event = AgentStateChangedObservation(
101-
content='Stop', agent_state=AgentState.STOPPED
102-
)
103-
event._source = EventSource.AGENT
104-
await mock_runtime.event_stream.add_event(event)
105-
106-
# Wait for main to finish with a timeout
107-
try:
108-
await asyncio.wait_for(main_task, timeout=1.0)
109-
except asyncio.TimeoutError:
110-
main_task.cancel()
111-
112-
# Check the output
113-
captured = capsys.readouterr()
114-
assert 'Session ID:' in captured.out
115-
# Also verify that our task message was processed
116-
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)
99+
with create_app_session(
100+
input=create_pipe_input(), output=create_output(stdout=buffer)
101+
):
102+
# Create a task for main
103+
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
104+
105+
# Give it a moment to display the session ID
106+
await asyncio.sleep(0.1)
107+
108+
# Trigger agent state change to STOPPED to end the main loop
109+
event = AgentStateChangedObservation(
110+
content='Stop', agent_state=AgentState.STOPPED
111+
)
112+
event._source = EventSource.AGENT
113+
await mock_runtime.event_stream.add_event(event)
114+
115+
# Wait for main to finish with a timeout
116+
try:
117+
await asyncio.wait_for(main_task, timeout=1.0)
118+
except asyncio.TimeoutError:
119+
main_task.cancel()
120+
121+
buffer.seek(0)
122+
output = buffer.read()
123+
124+
# Check the output
125+
assert 'Session ID:' in output
126+
# Also verify that our task message was processed
127+
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)

0 commit comments

Comments
 (0)