feature: add support for custom model provider #77

Open · wants to merge 1 commit into main
2 changes: 1 addition & 1 deletion src/mcp_agent/core/agent_types.py
@@ -34,7 +34,7 @@ class AgentConfig:
     default_request_params: RequestParams | None = None
     human_input: bool = False
     agent_type: str = AgentType.BASIC.value
-
+    provider: str | None = None
     def __post_init__(self) -> None:
         """Ensure default_request_params exists with proper history setting"""
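
The new field defaults to None, so existing AgentConfig call sites are unaffected. A minimal sketch of what the field carries (the agent name and model value are illustrative, and the assumption that AgentConfig also exposes name and model fields is based on their use in direct_factory.py below):

    from mcp_agent.core.agent_types import AgentConfig

    # Default: no provider, behavior is unchanged.
    config = AgentConfig(name="summarizer")
    assert config.provider is None

    # Opt in: pair a custom model string with an explicit provider.
    config = AgentConfig(
        name="summarizer",
        model="my-finetuned-model",  # illustrative custom model name
        provider="openai",           # assumed key in ModelFactory.PROVIDER_MAP
    )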
10 changes: 9 additions & 1 deletion src/mcp_agent/core/direct_decorators.py
@@ -90,6 +90,7 @@ def _decorator_impl(
     use_history: bool = True,
     request_params: RequestParams | None = None,
     human_input: bool = False,
+    provider: Optional[str] = None,
     **extra_kwargs,
 ) -> Callable[[AgentCallable[P, R]], DecoratedAgentProtocol[P, R]]:
     """
@@ -132,6 +133,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
             model=model,
             use_history=use_history,
             human_input=human_input,
+            provider=provider,
         )

         # Update request params if provided
@@ -174,6 +176,7 @@ def agent(
     use_history: bool = True,
     request_params: RequestParams | None = None,
     human_input: bool = False,
+    provider: Optional[str] = None,
 ) -> Callable[[AgentCallable[P, R]], DecoratedAgentProtocol[P, R]]:
     """
     Decorator to create and register a standard agent with type-safe signature.
@@ -203,6 +206,7 @@ def agent(
         use_history=use_history,
         request_params=request_params,
         human_input=human_input,
+        provider=provider,
     )


@@ -218,6 +222,7 @@ def orchestrator(
     human_input: bool = False,
     plan_type: Literal["full", "iterative"] = "full",
     max_iterations: int = 30,
+    provider: Optional[str] = None,
 ) -> Callable[[AgentCallable[P, R]], DecoratedOrchestratorProtocol[P, R]]:
     """
     Decorator to create and register an orchestrator agent with type-safe signature.
@@ -232,7 +237,7 @@ def orchestrator(
         human_input: Whether to enable human input capabilities
         plan_type: Planning approach - "full" or "iterative"
         max_iterations: Maximum number of planning iterations
-
+        provider: Provider name to use for custom models
     Returns:
         A decorator that registers the orchestrator with proper type annotations
     """
@@ -259,6 +264,7 @@ def orchestrator(
            child_agents=agents,
            plan_type=plan_type,
            max_iterations=max_iterations,
+           provider=provider,
        ),
    )

@@ -274,6 +280,7 @@ def router(
     use_history: bool = False,
     request_params: RequestParams | None = None,
     human_input: bool = False,
+    provider: Optional[str] = None,
 ) -> Callable[[AgentCallable[P, R]], DecoratedRouterProtocol[P, R]]:
     """
     Decorator to create and register a router agent with type-safe signature.
@@ -308,6 +315,7 @@ def router(
            request_params=request_params,
            human_input=human_input,
            router_agents=agents,
+           provider=provider,
        ),
    )
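
Taken together, each decorator now threads provider through to AgentConfig. A sketch of the decorator-level usage this enables (the agent name, instruction, and model string are illustrative; "openai" is assumed to be a key in ModelFactory.PROVIDER_MAP):

    from mcp_agent.core.fastagent import FastAgent

    fast = FastAgent("provider-example")

    # Route a custom model name through an explicit provider instead of
    # encoding the provider as a "provider.model" prefix in the model string.
    @fast.agent(
        name="summarizer",
        instruction="Summarize the input text.",
        model="my-finetuned-model",
        provider="openai",
    )
    async def main():
        async with fast.run() as agent:
            await agent.send("Summarize: MCP lets agents call external tools.")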
23 changes: 17 additions & 6 deletions src/mcp_agent/core/direct_factory.py
@@ -28,7 +28,7 @@
 T = TypeVar("T")  # For generic types

 # Type for model factory functions
-ModelFactoryFn = Callable[[Optional[str], Optional[RequestParams]], Callable[[], Any]]
+ModelFactoryFn = Callable[[Optional[str], Optional[RequestParams], Optional[str]], Callable[[], Any]]


 logger = get_logger(__name__)
@@ -54,6 +54,8 @@ def get_model_factory(
     request_params: Optional[RequestParams] = None,
     default_model: Optional[str] = None,
     cli_model: Optional[str] = None,
+    provider: Optional[str] = None,
+    cli_provider: Optional[str] = None,
 ) -> Callable:
     """
     Get model factory using specified or default model.
@@ -71,14 +73,23 @@
     """
     # Config has lowest precedence
     model_spec = default_model or context.config.default_model
+
+    provider_spec = None

     # Command line override has next precedence
     if cli_model:
         model_spec = cli_model

+    # Provider from command line has next precedence
+    if cli_provider:
+        provider_spec = cli_provider
+
     # Model from decorator has highest precedence
     if model:
         model_spec = model

+    # Provider from decorator has highest precedence
+    if provider:
+        provider_spec = provider

     # Update or create request_params with the final model choice
     if request_params:
@@ -87,7 +98,7 @@
         request_params = RequestParams(model=model_spec)

     # Let model factory handle the model string parsing and setup
-    return ModelFactory.create_factory(model_spec, request_params=request_params)
+    return ModelFactory.create_factory(model_spec, request_params=request_params, provider_name=provider_spec)


 async def create_agents_by_type(
@@ -147,7 +158,7 @@ def model_factory_func(model=None, request_params=None):
             await agent.initialize()

             # Attach LLM to the agent
-            llm_factory = model_factory_func(model=config.model)
+            llm_factory = model_factory_func(model=config.model, provider=config.provider)
             await agent.attach_llm(llm_factory, request_params=config.default_request_params)
             result_agents[name] = agent

@@ -180,7 +191,7 @@ def model_factory_func(model=None, request_params=None):
             await orchestrator.initialize()

             # Attach LLM to the orchestrator
-            llm_factory = model_factory_func(model=config.model)
+            llm_factory = model_factory_func(model=config.model, provider=config.provider)
             await orchestrator.attach_llm(
                 llm_factory, request_params=config.default_request_params
             )
@@ -241,7 +252,7 @@ def model_factory_func(model=None, request_params=None):
             await router.initialize()

             # Attach LLM to the router
-            llm_factory = model_factory_func(model=config.model)
+            llm_factory = model_factory_func(model=config.model, provider=config.provider)
             await router.attach_llm(llm_factory, request_params=config.default_request_params)
             result_agents[name] = router
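
Net effect: get_model_factory resolves the model and the provider independently, each with decorator argument over CLI flag over config default. A hypothetical resolution trace (all values illustrative):

    # config default:  default_model = "haiku"     -> model_spec = "haiku"
    # CLI:             --model gpt-4o              -> model_spec = "gpt-4o"
    # decorator:       model="my-finetuned-model"  -> model_spec = "my-finetuned-model"
    # CLI:             --provider generic          -> provider_spec = "generic"
    # decorator:       provider="openai"           -> provider_spec = "openai"
    factory = get_model_factory(
        context,
        model="my-finetuned-model",
        provider="openai",
        cli_model="gpt-4o",
        cli_provider="generic",
    )
    # ...which ends in:
    # ModelFactory.create_factory("my-finetuned-model",
    #                             request_params=..., provider_name="openai")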
13 changes: 12 additions & 1 deletion src/mcp_agent/core/fastagent.py
@@ -129,6 +129,11 @@ def __init__(
             default="0.0.0.0",
             help="Host address to bind to when running as a server with SSE transport",
         )
+        parser.add_argument(
+            "--provider",
+            default=None,
+            help="Provider to use for custom models",
+        )

         if ignore_unknown_args:
             known_args, _ = parser.parse_known_args()
@@ -241,12 +246,14 @@ async def run(self):
         validate_workflow_references(self.agents)

         # Get a model factory function
-        def model_factory_func(model=None, request_params=None):
+        def model_factory_func(model=None, request_params=None, provider=None):
             return get_model_factory(
                 self.context,
                 model=model,
+                provider=provider,
                 request_params=request_params,
                 cli_model=self.args.model if hasattr(self, "args") else None,
+                cli_provider=self.args.provider if hasattr(self, "args") else None,
             )

         # Create all agents in dependency order
@@ -445,6 +452,10 @@ async def start_server(
         self.args.model = None
         if hasattr(original_args, "model"):
             self.args.model = original_args.model
+
+        self.args.provider = None
+        if hasattr(original_args, "provider"):
+            self.args.provider = original_args.provider

         # Run the application, which will detect the server flag and start server mode
         async with self.run():
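
A runnable stand-in for how the flag surfaces through argparse (this parser is a minimal mock of the one FastAgent builds in __init__; the --model help text is assumed):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=None, help="Override the default model")
    parser.add_argument(
        "--provider",
        default=None,
        help="Provider to use for custom models",
    )

    # e.g. a launch such as: uv run agent.py --model my-finetuned-model --provider openai
    args = parser.parse_args(["--model", "my-finetuned-model", "--provider", "openai"])
    assert (args.model, args.provider) == ("my-finetuned-model", "openai")

These values then reach get_model_factory as cli_model and cli_provider via model_factory_func above.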
18 changes: 13 additions & 5 deletions src/mcp_agent/llm/model_factory.py
@@ -137,16 +137,24 @@ class ModelFactory:
     }

     @classmethod
-    def parse_model_string(cls, model_string: str) -> ModelConfig:
+    def parse_model_string(cls, model_string: str, provider_name: Optional[str] = None) -> ModelConfig:
         """Parse a model string into a ModelConfig object"""
+
+        provider = None
+        reasoning_effort = None
+
+        # Support for custom models
+        if provider_name is not None:
+            provider = cls.PROVIDER_MAP.get(provider_name.lower())
+            if provider:
+                return ModelConfig(provider=provider, model_name=model_string, reasoning_effort=None)
+
         # Check if model string is an alias
         model_string = cls.MODEL_ALIASES.get(model_string, model_string)
         parts = model_string.split(".")

         # Start with all parts as the model name
         model_parts = parts.copy()
-        provider = None
-        reasoning_effort = None

         # Check last part for reasoning effort
         if len(parts) > 1 and parts[-1].lower() in cls.EFFORT_MAP:
@@ -175,7 +183,7 @@ def parse_model_string(cls, model_string: str) -> ModelConfig:

     @classmethod
     def create_factory(
-        cls, model_string: str, request_params: Optional[RequestParams] = None
+        cls, model_string: str, request_params: Optional[RequestParams] = None, provider_name: Optional[str] = None
     ) -> Callable[..., AugmentedLLMProtocol]:
         """
         Creates a factory function that follows the attach_llm protocol.
@@ -188,7 +196,7 @@ def create_factory(
         A callable that takes an agent parameter and returns an LLM instance
         """
         # Parse configuration up front
-        config = cls.parse_model_string(model_string)
+        config = cls.parse_model_string(model_string, provider_name)
         if config.model_name in cls.MODEL_SPECIFIC_CLASSES:
             llm_class = cls.MODEL_SPECIFIC_CLASSES[config.model_name]
         else:
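
A sketch of the resulting parse behavior (model names are illustrative, and Provider.OPENAI stands in for whatever enum value PROVIDER_MAP maps "openai" to; the map's contents are not shown in this diff):

    # Explicit provider: the model string is taken verbatim, skipping alias
    # lookup and the dotted "provider.model.effort" parsing.
    config = ModelFactory.parse_model_string("my-finetuned-model", provider_name="openai")
    # -> ModelConfig(provider=Provider.OPENAI,
    #                model_name="my-finetuned-model", reasoning_effort=None)

    # No provider given: the existing dotted-string convention still applies.
    config = ModelFactory.parse_model_string("openai.gpt-4o.high")

One consequence worth noting: an unrecognized provider_name falls through silently to the dotted-string path, since PROVIDER_MAP.get returns None rather than raising.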