From 7459cbdffad9b6b40e4cb1a09ff16b5c94831361 Mon Sep 17 00:00:00 2001 From: pansu <34263218+pansusu@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:44:25 +0800 Subject: [PATCH] feature: add support for custom model provider --- src/mcp_agent/core/agent_types.py | 2 +- src/mcp_agent/core/direct_decorators.py | 10 +++++++++- src/mcp_agent/core/direct_factory.py | 23 +++++++++++++++++------ src/mcp_agent/core/fastagent.py | 13 ++++++++++++- src/mcp_agent/llm/model_factory.py | 18 +++++++++++++----- 5 files changed, 52 insertions(+), 14 deletions(-) diff --git a/src/mcp_agent/core/agent_types.py b/src/mcp_agent/core/agent_types.py index 248d7f8e..1aff224c 100644 --- a/src/mcp_agent/core/agent_types.py +++ b/src/mcp_agent/core/agent_types.py @@ -34,7 +34,7 @@ class AgentConfig: default_request_params: RequestParams | None = None human_input: bool = False agent_type: str = AgentType.BASIC.value - + provider: str | None = None def __post_init__(self) -> None: """Ensure default_request_params exists with proper history setting""" diff --git a/src/mcp_agent/core/direct_decorators.py b/src/mcp_agent/core/direct_decorators.py index a86ae98d..21b38d1b 100644 --- a/src/mcp_agent/core/direct_decorators.py +++ b/src/mcp_agent/core/direct_decorators.py @@ -90,6 +90,7 @@ def _decorator_impl( use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, + provider: Optional[str] = None, **extra_kwargs, ) -> Callable[[AgentCallable[P, R]], DecoratedAgentProtocol[P, R]]: """ @@ -132,6 +133,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: model=model, use_history=use_history, human_input=human_input, + provider=provider, ) # Update request params if provided @@ -174,6 +176,7 @@ def agent( use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, + provider: Optional[str] = None, ) -> Callable[[AgentCallable[P, R]], DecoratedAgentProtocol[P, R]]: """ Decorator to create and register a standard 
agent with type-safe signature. @@ -203,6 +206,7 @@ def agent( use_history=use_history, request_params=request_params, human_input=human_input, + provider=provider, ) @@ -218,6 +222,7 @@ def orchestrator( human_input: bool = False, plan_type: Literal["full", "iterative"] = "full", max_iterations: int = 30, + provider: Optional[str] = None, ) -> Callable[[AgentCallable[P, R]], DecoratedOrchestratorProtocol[P, R]]: """ Decorator to create and register an orchestrator agent with type-safe signature. @@ -232,7 +237,7 @@ def orchestrator( human_input: Whether to enable human input capabilities plan_type: Planning approach - "full" or "iterative" max_iterations: Maximum number of planning iterations - + provider: Provider name Returns: A decorator that registers the orchestrator with proper type annotations """ @@ -259,6 +264,7 @@ def orchestrator( child_agents=agents, plan_type=plan_type, max_iterations=max_iterations, + provider=provider, ), ) @@ -274,6 +280,7 @@ def router( use_history: bool = False, request_params: RequestParams | None = None, human_input: bool = False, + provider: Optional[str] = None, ) -> Callable[[AgentCallable[P, R]], DecoratedRouterProtocol[P, R]]: """ Decorator to create and register a router agent with type-safe signature. 
@@ -308,6 +315,7 @@ def router( request_params=request_params, human_input=human_input, router_agents=agents, + provider=provider, ), ) diff --git a/src/mcp_agent/core/direct_factory.py b/src/mcp_agent/core/direct_factory.py index 4f11aa50..4daa3150 100644 --- a/src/mcp_agent/core/direct_factory.py +++ b/src/mcp_agent/core/direct_factory.py @@ -28,7 +28,7 @@ T = TypeVar("T") # For generic types # Type for model factory functions -ModelFactoryFn = Callable[[Optional[str], Optional[RequestParams]], Callable[[], Any]] +ModelFactoryFn = Callable[[Optional[str], Optional[RequestParams], Optional[str]], Callable[[], Any]] logger = get_logger(__name__) @@ -54,6 +54,8 @@ def get_model_factory( request_params: Optional[RequestParams] = None, default_model: Optional[str] = None, cli_model: Optional[str] = None, + provider: Optional[str] = None, + cli_provider: Optional[str] = None, ) -> Callable: """ Get model factory using specified or default model. @@ -71,14 +73,23 @@ """ # Config has lowest precedence model_spec = default_model or context.config.default_model - + provider_spec = None + # Command line override has next precedence if cli_model: model_spec = cli_model + + # Provider from command line has next precedence + if cli_provider: + provider_spec = cli_provider # Model from decorator has highest precedence if model: model_spec = model + + # Provider from decorator has highest precedence + if provider: + provider_spec = provider # Update or create request_params with the final model choice if request_params: @@ -87,7 +98,7 @@ request_params = RequestParams(model=model_spec) # Let model factory handle the model string parsing and setup - return ModelFactory.create_factory(model_spec, request_params=request_params) + return ModelFactory.create_factory(model_spec, request_params=request_params, provider_name=provider_spec) async def create_agents_by_type( @@ -147,7 +158,7 @@ def model_factory_func(model=None,
request_params=None): await agent.initialize() # Attach LLM to the agent - llm_factory = model_factory_func(model=config.model) + llm_factory = model_factory_func(model=config.model, provider=config.provider) await agent.attach_llm(llm_factory, request_params=config.default_request_params) result_agents[name] = agent @@ -180,7 +191,7 @@ def model_factory_func(model=None, request_params=None): await orchestrator.initialize() # Attach LLM to the orchestrator - llm_factory = model_factory_func(model=config.model) + llm_factory = model_factory_func(model=config.model, provider=config.provider) await orchestrator.attach_llm( llm_factory, request_params=config.default_request_params ) @@ -241,7 +252,7 @@ def model_factory_func(model=None, request_params=None): await router.initialize() # Attach LLM to the router - llm_factory = model_factory_func(model=config.model) + llm_factory = model_factory_func(model=config.model, provider=config.provider) await router.attach_llm(llm_factory, request_params=config.default_request_params) result_agents[name] = router diff --git a/src/mcp_agent/core/fastagent.py b/src/mcp_agent/core/fastagent.py index 7d0db64b..2cfa0764 100644 --- a/src/mcp_agent/core/fastagent.py +++ b/src/mcp_agent/core/fastagent.py @@ -129,6 +129,11 @@ def __init__( default="0.0.0.0", help="Host address to bind to when running as a server with SSE transport", ) + parser.add_argument( + "--provider", + default=None, + help="Provider to use for custom models", + ) if ignore_unknown_args: known_args, _ = parser.parse_known_args() @@ -241,12 +246,14 @@ async def run(self): validate_workflow_references(self.agents) # Get a model factory function - def model_factory_func(model=None, request_params=None): + def model_factory_func(model=None, request_params=None, provider=None): return get_model_factory( self.context, model=model, + provider=provider, request_params=request_params, cli_model=self.args.model if hasattr(self, "args") else None, +
cli_provider=self.args.provider if hasattr(self, "args") else None, ) # Create all agents in dependency order @@ -445,6 +452,10 @@ async def start_server( self.args.model = None if hasattr(original_args, "model"): self.args.model = original_args.model + + self.args.provider = None + if hasattr(original_args, "provider"): + self.args.provider = original_args.provider # Run the application, which will detect the server flag and start server mode async with self.run(): diff --git a/src/mcp_agent/llm/model_factory.py b/src/mcp_agent/llm/model_factory.py index 40b45fc0..027921c7 100644 --- a/src/mcp_agent/llm/model_factory.py +++ b/src/mcp_agent/llm/model_factory.py @@ -137,16 +137,24 @@ class ModelFactory: } @classmethod - def parse_model_string(cls, model_string: str) -> ModelConfig: + def parse_model_string(cls, model_string: str, provider_name: Optional[str] = None) -> ModelConfig: """Parse a model string into a ModelConfig object""" + + provider = None + reasoning_effort = None + + # Support for custom models + if provider_name is not None: + provider = cls.PROVIDER_MAP.get(provider_name.lower()) + if provider: + return ModelConfig(provider=provider, model_name=model_string, reasoning_effort=None) + # Check if model string is an alias model_string = cls.MODEL_ALIASES.get(model_string, model_string) parts = model_string.split(".") # Start with all parts as the model name model_parts = parts.copy() - provider = None - reasoning_effort = None # Check last part for reasoning effort if len(parts) > 1 and parts[-1].lower() in cls.EFFORT_MAP: @@ -175,7 +183,7 @@ def parse_model_string(cls, model_string: str) -> ModelConfig: @classmethod def create_factory( - cls, model_string: str, request_params: Optional[RequestParams] = None + cls, model_string: str, request_params: Optional[RequestParams] = None, provider_name: Optional[str] = None ) -> Callable[..., AugmentedLLMProtocol]: """ Creates a factory function that follows the attach_llm protocol.
@@ -188,7 +196,7 @@ def create_factory( A callable that takes an agent parameter and returns an LLM instance """ # Parse configuration up front - config = cls.parse_model_string(model_string) + config = cls.parse_model_string(model_string, provider_name) if config.model_name in cls.MODEL_SPECIFIC_CLASSES: llm_class = cls.MODEL_SPECIFIC_CLASSES[config.model_name] else: