Skip to content

Commit ca7bd14

Browse files
xingyaowwducphamle2
authored andcommitted
refactor MCP config
1 parent 9694ce9 commit ca7bd14

File tree

3 files changed

+59
-80
lines changed

3 files changed

+59
-80
lines changed

openhands/core/config/mcp_config.py

+29-36
Original file line numberDiff line numberDiff line change
@@ -31,50 +31,43 @@ def validate_servers(self) -> None:
3131
raise ValueError(f'Invalid URL {url}: {str(e)}')
3232

3333

34+
class MCPStdioConfigEntry(BaseModel):
35+
"""Configuration for a single MCP stdio entry.
36+
37+
Attributes:
38+
command: The command to run.
39+
args: List of arguments for the command.
40+
env: Dictionary of environment variables.
41+
"""
42+
43+
command: str
44+
args: list[str] = Field(default_factory=list)
45+
env: dict[str, str] = Field(default_factory=dict)
46+
47+
model_config = {'extra': 'forbid'}
48+
49+
3450
class MCPStdioConfig(BaseModel):
3551
"""Configuration for MCP stdio settings.
3652
3753
Attributes:
38-
commands: List of commands to run.
39-
args: List of arguments for each command.
40-
envs: List of environment variable tuples for each command.
54+
tools: Dictionary of tool configurations, where keys are tool names.
4155
"""
4256

43-
commands: List[str] = Field(default_factory=list)
44-
args: List[List[str]] = Field(default_factory=list)
45-
envs: List[List[tuple[str, str]]] = Field(default_factory=list)
57+
tools: dict[str, MCPStdioConfigEntry] = Field(default_factory=dict)
4658

4759
model_config = {'extra': 'forbid'}
4860

4961
def validate_stdio(self) -> None:
50-
"""Validate that commands, args, and envs are properly configured."""
51-
52-
# Check if number of commands matches number of args lists
53-
if len(self.commands) != len(self.args):
54-
raise ValueError(
55-
f'Number of commands ({len(self.commands)}) does not match '
56-
f'number of args lists ({len(self.args)})'
57-
)
58-
59-
# Check if number of commands matches number of envs lists
60-
if len(self.commands) != len(self.envs):
61-
raise ValueError(
62-
f'Number of commands ({len(self.commands)}) does not match '
63-
f'number of envs lists ({len(self.envs)})'
64-
)
65-
66-
# Validate each environment variable tuple
67-
for i, env_list in enumerate(self.envs):
68-
for j, env_tuple in enumerate(env_list):
69-
if not isinstance(env_tuple, tuple) or len(env_tuple) != 2:
70-
raise ValueError(
71-
f'Environment variable at index {j} for command {i} must be a tuple of (key, value)'
72-
)
73-
key, value = env_tuple
74-
if not isinstance(key, str) or not isinstance(value, str):
75-
raise ValueError(
76-
f'Environment variable key and value at index {j} for command {i} must be strings'
77-
)
62+
"""Validate that tools are properly configured."""
63+
# Tool names validation
64+
for tool_name in self.tools:
65+
if not tool_name.strip():
66+
raise ValueError('Tool names cannot be empty')
67+
if not tool_name.replace('-', '').isalnum():
68+
raise ValueError(
69+
f'Invalid tool name: {tool_name}. Tool names must be alphanumeric (hyphens allowed)'
70+
)
7871

7972

8073
class MCPConfig(BaseModel):
@@ -105,11 +98,11 @@ def from_toml_section(cls, data: dict) -> dict[str, 'MCPConfig']:
10598

10699
try:
107100
# Create SSE config if present
108-
sse_config = MCPSSEConfig(**data.get('mcp-sse', {}))
101+
sse_config = MCPSSEConfig.model_validate(data.get('mcp-sse', {}))
109102
sse_config.validate_servers()
110103

111104
# Create stdio config if present
112-
stdio_config = MCPStdioConfig(**data.get('mcp-stdio', {}))
105+
stdio_config = MCPStdioConfig.model_validate(data.get('mcp-stdio', {}))
113106
stdio_config.validate_stdio()
114107

115108
# Create the main MCP config

openhands/mcp/utils.py

+20-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from typing import List
2-
3-
from openhands.core.config.mcp_config import MCPConfig
1+
from openhands.core.config.mcp_config import MCPConfig, MCPStdioConfigEntry
42
from openhands.core.logger import openhands_logger as logger
53
from openhands.mcp.client import MCPClient
64

@@ -35,12 +33,10 @@ def convert_mcp_clients_to_tools(mcp_clients: list[MCPClient] | None) -> list[di
3533

3634

3735
async def create_mcp_clients(
38-
sse_mcp_server: List[str],
39-
commands: List[str],
40-
args: List[List[str]],
41-
envs: List[List[tuple[str, str]]],
42-
) -> List[MCPClient]:
43-
mcp_clients: List[MCPClient] = []
36+
sse_mcp_server: list[str],
37+
stdio_mcp_tool_configs: dict[str, MCPStdioConfigEntry],
38+
) -> list[MCPClient]:
39+
mcp_clients: list[MCPClient] = []
4440
# Initialize SSE connections
4541
if sse_mcp_server:
4642
for server_url in sse_mcp_server:
@@ -65,21 +61,24 @@ async def create_mcp_clients(
6561
)
6662

6763
# Initialize stdio connections
68-
if commands:
69-
for i, (command, command_args, command_envs) in enumerate(
70-
zip(commands, args, envs)
71-
):
64+
if stdio_mcp_tool_configs:
65+
for name, tool in stdio_mcp_tool_configs.items():
7266
logger.info(
73-
f'Initializing MCP agent for {command} with stdio connection...'
67+
f'Initializing MCP tool [{name}] for [{tool.command}] with stdio connection...'
7468
)
75-
7669
client = MCPClient()
7770
try:
78-
await client.connect_stdio(command, command_args, command_envs)
71+
await client.connect_stdio(
72+
tool.command,
73+
tool.args,
74+
[(k, v) for k, v in tool.env.items()],
75+
)
7976
mcp_clients.append(client)
80-
logger.info(f'Connected to MCP server via stdio with command {command}')
77+
logger.info(
78+
f'Connected to MCP server via stdio with command {tool.command}'
79+
)
8180
except Exception as e:
82-
logger.error(f'Failed to connect with command {command}: {str(e)}')
81+
logger.error(f'Failed to connect with command {tool.command}: {str(e)}')
8382
# Don't raise the exception, just log it and continue
8483
# Make sure to disconnect the client to clean up resources
8584
try:
@@ -88,7 +87,6 @@ async def create_mcp_clients(
8887
logger.error(
8988
f'Error during disconnect after failed connection: {str(disconnect_error)}'
9089
)
91-
9290
return mcp_clients
9391

9492

@@ -103,11 +101,10 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
103101
mcp_tools = []
104102

105103
try:
104+
logger.debug(f'Creating MCP clients with config: {mcp_config}')
106105
mcp_clients = await create_mcp_clients(
107106
mcp_config.sse.mcp_servers,
108-
mcp_config.stdio.commands,
109-
mcp_config.stdio.args,
110-
mcp_config.stdio.envs,
107+
mcp_config.stdio.tools,
111108
)
112109

113110
if not mcp_clients:
@@ -126,4 +123,5 @@ async def fetch_mcp_tools_from_config(mcp_config: MCPConfig) -> list[dict]:
126123
except Exception as disconnect_error:
127124
logger.error(f'Error disconnecting MCP client: {str(disconnect_error)}')
128125

126+
logger.debug(f'MCP tools: {mcp_tools}')
129127
return mcp_tools

openhands/runtime/action_execution_server.py

+10-22
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import traceback
1818
from contextlib import asynccontextmanager
1919
from pathlib import Path
20-
from typing import Optional
2120
from zipfile import ZipFile
2221

2322
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
@@ -33,6 +32,7 @@
3332
from starlette.exceptions import HTTPException as StarletteHTTPException
3433
from uvicorn import run
3534

35+
from openhands.core.config.mcp_config import MCPStdioConfigEntry
3636
from openhands.core.exceptions import BrowserUnavailableException
3737
from openhands.core.logger import openhands_logger as logger
3838
from openhands.events.action import (
@@ -73,10 +73,8 @@
7373

7474
class ActionRequest(BaseModel):
7575
action: dict
76-
sse_mcp_config: Optional[list[str]] = None
77-
stdio_mcp_config: Optional[
78-
tuple[list[str], list[list[str]], list[list[tuple[str, str]]]]
79-
] = None
76+
sse_mcp_config: list[str] | None = None
77+
stdio_mcp_tool_configs: dict[str, MCPStdioConfigEntry] | None = None
8078

8179

8280
ROOT_GID = 0
@@ -192,20 +190,18 @@ def __init__(
192190
)
193191
self.memory_monitor.start_monitoring()
194192
self.sse_mcp_servers: list[str] = []
195-
self.stdio_mcp_config: tuple[
196-
list[str], list[list[str]], list[list[tuple[str, str]]]
197-
] = ([], [], [])
193+
self.stdio_mcp_tool_configs: dict[str, MCPStdioConfigEntry] = {}
198194

199195
@property
200196
def initial_cwd(self):
201197
return self._initial_cwd
202198

203199
def process_request(self, action_request: ActionRequest):
204-
# update the sse_mcp_servers and stdio_mcp_config to prepare for MCP action if needed
200+
# update the sse_mcp_servers and stdio_mcp_tool_configs to prepare for MCP action if needed
205201
if action_request.sse_mcp_config:
206202
self.sse_mcp_servers = action_request.sse_mcp_config
207-
if action_request.stdio_mcp_config:
208-
self.stdio_mcp_config = action_request.stdio_mcp_config
203+
if action_request.stdio_mcp_tool_configs:
204+
self.stdio_mcp_tool_configs = action_request.stdio_mcp_tool_configs
209205

210206
async def _init_browser_async(self):
211207
"""Initialize the browser asynchronously."""
@@ -526,21 +522,13 @@ async def browse_interactive(self, action: BrowseInteractiveAction) -> Observati
526522
return await browse(action, self.browser)
527523

528524
async def call_tool_mcp(self, action: McpAction) -> Observation:
529-
commands: list[str] = []
530-
args: list[list[str]] = []
531-
envs: list[list[tuple[str, str]]] = []
532-
if self.stdio_mcp_config:
533-
commands = self.stdio_mcp_config[0]
534-
if len(self.stdio_mcp_config) > 1:
535-
args = self.stdio_mcp_config[1]
536-
if len(self.stdio_mcp_config) > 2:
537-
envs = self.stdio_mcp_config[2]
538-
if not self.sse_mcp_servers and not commands:
525+
if not self.sse_mcp_servers and not self.stdio_mcp_tool_configs:
539526
raise ValueError('No MCP servers or stdio MCP config found')
540527

541528
logger.warning(f'SSE MCP servers: {self.sse_mcp_servers}')
542529
mcp_clients = await create_mcp_clients(
543-
self.sse_mcp_servers, commands, args, envs
530+
self.sse_mcp_servers,
531+
self.stdio_mcp_tool_configs,
544532
)
545533
logger.warn(f'MCP action received: {action}')
546534
# Find the MCP agent that has the matching tool name

0 commit comments

Comments
 (0)