diff --git a/examples/data-analysis/fastagent.config.yaml b/examples/data-analysis/fastagent.config.yaml index c4acd25d..6cb17659 100644 --- a/examples/data-analysis/fastagent.config.yaml +++ b/examples/data-analysis/fastagent.config.yaml @@ -1,4 +1,4 @@ -default_model: sonnet +default_model: azure-gpt-4o # on windows, adjust the mount point to be the full path e.g. x:/temp/data-analysis/mount-point:/mnt/data/ diff --git a/examples/data-analysis/mount-point/attrition_age_distribution.png b/examples/data-analysis/mount-point/attrition_age_distribution.png new file mode 100644 index 00000000..c59205d6 Binary files /dev/null and b/examples/data-analysis/mount-point/attrition_age_distribution.png differ diff --git a/examples/data-analysis/mount-point/attrition_by_jobrole.png b/examples/data-analysis/mount-point/attrition_by_jobrole.png new file mode 100644 index 00000000..5f1ad34c Binary files /dev/null and b/examples/data-analysis/mount-point/attrition_by_jobrole.png differ diff --git a/examples/data-analysis/mount-point/attrition_department_gender.png b/examples/data-analysis/mount-point/attrition_department_gender.png new file mode 100644 index 00000000..c595e07b Binary files /dev/null and b/examples/data-analysis/mount-point/attrition_department_gender.png differ diff --git a/examples/data-analysis/mount-point/attrition_distribution.png b/examples/data-analysis/mount-point/attrition_distribution.png new file mode 100644 index 00000000..8812b864 Binary files /dev/null and b/examples/data-analysis/mount-point/attrition_distribution.png differ diff --git a/examples/data-analysis/mount-point/attrition_income_distribution.png b/examples/data-analysis/mount-point/attrition_income_distribution.png new file mode 100644 index 00000000..2ebc0ac3 Binary files /dev/null and b/examples/data-analysis/mount-point/attrition_income_distribution.png differ diff --git a/examples/data-analysis/mount-point/attrition_overtime.png 
b/examples/data-analysis/mount-point/attrition_overtime.png new file mode 100644 index 00000000..02e74963 Binary files /dev/null and b/examples/data-analysis/mount-point/attrition_overtime.png differ diff --git a/examples/data-analysis/mount-point/income_vs_attrition.png b/examples/data-analysis/mount-point/income_vs_attrition.png new file mode 100644 index 00000000..19f79739 Binary files /dev/null and b/examples/data-analysis/mount-point/income_vs_attrition.png differ diff --git a/examples/researcher/fastagent.config.yaml b/examples/researcher/fastagent.config.yaml index ba1f6f7f..ed049c8c 100644 --- a/examples/researcher/fastagent.config.yaml +++ b/examples/researcher/fastagent.config.yaml @@ -3,7 +3,7 @@ # Examples in comments below - check/change the paths. # # - +default_model: azure-gpt-4o execution_engine: asyncio logger: type: console diff --git a/examples/workflows/fastagent.config.yaml b/examples/workflows/fastagent.config.yaml index 5006c57f..3c3f023c 100644 --- a/examples/workflows/fastagent.config.yaml +++ b/examples/workflows/fastagent.config.yaml @@ -2,7 +2,7 @@ # Examples in comments below - check/change the paths. # # - +default_model: azure-gpt-4o execution_engine: asyncio logger: type: file diff --git a/fastagent.config.yaml b/fastagent.config.yaml index 5006c57f..3c3f023c 100644 --- a/fastagent.config.yaml +++ b/fastagent.config.yaml @@ -2,7 +2,7 @@ # Examples in comments below - check/change the paths. 
# # - +default_model: azure-gpt-4o execution_engine: asyncio logger: type: file diff --git a/schema/mcp-agent.config.schema.json b/schema/mcp-agent.config.schema.json index be607ca2..3bd32d87 100644 --- a/schema/mcp-agent.config.schema.json +++ b/schema/mcp-agent.config.schema.json @@ -524,6 +524,60 @@ "title": "OpenAISettings", "type": "object" }, + "AzureOpenAISettings": { + "additionalProperties": true, + "description": "Settings for using AzureOpenAI models in the fast-agent application.", + "properties": { + "api_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Api Key" + }, + "reasoning_effort": { + "default": "medium", + "enum": [ + "low", + "medium", + "high" + ], + "title": "Reasoning Effort", + "type": "string" + }, + "base_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Base Url" + }, + "api_version": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "API Version" + } + }, + "title": "AzureOpenAISettings", + "type": "object" + }, "OpenTelemetrySettings": { "description": "OTEL settings for the fast-agent application.", "properties": { @@ -663,7 +717,7 @@ "type": "null" } ], - "default": "haiku", + "default": "azure-gpt-4o", "title": "Default Model" }, "temporal": { @@ -722,6 +776,18 @@ "default": null, "description": "Settings for using OpenAI models in the fast-agent application" }, + "azureopenai": { + "anyOf": [ + { + "$ref": "#/$defs/AzureOpenAISettings" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Settings for using AzureOpenAI models in the fast-agent application" + }, "deepseek": { "anyOf": [ { diff --git a/src/mcp_agent/cli/commands/setup.py b/src/mcp_agent/cli/commands/setup.py index 1545cc01..b0c106c9 100644 --- a/src/mcp_agent/cli/commands/setup.py +++ b/src/mcp_agent/cli/commands/setup.py @@ -17,10 +17,11 @@ # Accepts aliases for Anthropic Models: 
haiku, haiku3, sonnet, sonnet35, opus, opus3 # and OpenAI Models: gpt-4o-mini, gpt-4o, o1, o1-mini, o3-mini # -# If not specified, defaults to "haiku". +# If not specified, defaults to "azure-gpt-4o". # Can be overriden with a command line switch --model=, or within the Agent constructor. -default_model: haiku +# default_model: gpt-4o +default_model: azure-gpt-4o # Logging and Console Configuration: logger: @@ -59,6 +60,10 @@ openai: api_key: +azureopenai: + api_key: + base_url: + api_version: anthropic: api_key: @@ -215,7 +220,7 @@ def init( if "fastagent.secrets.yaml" in created: console.print("\n[yellow]Important:[/yellow] Remember to:") console.print( - "1. Add your API keys to fastagent.secrets.yaml or set OPENAI_API_KEY and ANTHROPIC_API_KEY environment variables" + "1. Add your API keys to fastagent.secrets.yaml or set OPENAI_API_KEY, AZURE_OPENAI_API_KEY and ANTHROPIC_API_KEY environment variables" ) console.print( "2. Keep fastagent.secrets.yaml secure and never commit it to version control" diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py index cfc80261..14974c19 100644 --- a/src/mcp_agent/config.py +++ b/src/mcp_agent/config.py @@ -112,6 +112,19 @@ class AnthropicSettings(BaseModel): model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) +# class BedrockAnthropicSettings(BaseModel): +# """ +# Settings for using Anthropic models in the fast-agent application. +# """ + +# aws_access_key: str | None = None +# aws_secret_key: str | None = None +# # aws_session_token: str | None = None + +# base_url: str | None = None + +# model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + class OpenAISettings(BaseModel): """ @@ -125,6 +138,19 @@ class OpenAISettings(BaseModel): model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) +class AzureOpenAISettings(BaseModel): + """ + Settings for using Azure OpenAI models in the MCP Agent application. 
+ """ + + api_key: str | None = None + reasoning_effort: Literal["low", "medium", "high"] = "medium" + + base_url: str | None = None + api_version: str | None = None + + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) class DeepSeekSettings(BaseModel): """ @@ -256,7 +282,7 @@ class Settings(BaseSettings): execution_engine: Literal["asyncio", "temporal"] = "asyncio" """Execution engine for the fast-agent application""" - default_model: str | None = "haiku" + default_model: str | None = "azure-gpt-4o" """ Default model for agents. Format is provider.model_name., for example openai.o3-mini.low Aliases are provided for common models e.g. sonnet, haiku, gpt-4o, o3-mini etc. @@ -267,12 +293,18 @@ class Settings(BaseSettings): anthropic: AnthropicSettings | None = None """Settings for using Anthropic models in the fast-agent application""" + # bedrockanthropic: BedrockAnthropicSettings | None = None + # """Settings for using Anthropic models in the fast-agent application""" + otel: OpenTelemetrySettings | None = OpenTelemetrySettings() """OpenTelemetry logging settings for the fast-agent application""" openai: OpenAISettings | None = None """Settings for using OpenAI models in the fast-agent application""" + azureopenai: AzureOpenAISettings | None = None + """Settings for using Azure OpenAI models in the MCP Agent application""" + deepseek: DeepSeekSettings | None = None """Settings for using DeepSeek models in the fast-agent application""" diff --git a/src/mcp_agent/llm/model_factory.py b/src/mcp_agent/llm/model_factory.py index fdc12442..83faa435 100644 --- a/src/mcp_agent/llm/model_factory.py +++ b/src/mcp_agent/llm/model_factory.py @@ -8,9 +8,11 @@ from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM from mcp_agent.llm.augmented_llm_playback import PlaybackLLM from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM +# from mcp_agent.llm.providers.augmented_llm_bedrockanthropic import 
BedrockAnthropicAugmentedLLM from mcp_agent.llm.providers.augmented_llm_deepseek import DeepSeekAugmentedLLM from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM +from mcp_agent.llm.providers.augmented_llm_azureopenai import AzureOpenAIAugmentedLLM from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM from mcp_agent.mcp.interfaces import AugmentedLLMProtocol @@ -20,7 +22,9 @@ # Type alias for LLM classes LLMClass = Union[ Type[AnthropicAugmentedLLM], + # Type[BedrockAnthropicAugmentedLLM], Type[OpenAIAugmentedLLM], + Type[AzureOpenAIAugmentedLLM], Type[PassthroughLLM], Type[PlaybackLLM], Type[DeepSeekAugmentedLLM], @@ -32,7 +36,9 @@ class Provider(Enum): """Supported LLM providers""" ANTHROPIC = auto() + # BEDROCKANTHROPIC = auto() OPENAI = auto() + AZUREOPENAI = auto() FAST_AGENT = auto() DEEPSEEK = auto() GENERIC = auto() @@ -62,7 +68,9 @@ class ModelFactory: # Mapping of provider strings to enum values PROVIDER_MAP = { "anthropic": Provider.ANTHROPIC, + # "bedrockanthropic": Provider.BEDROCKANTHROPIC, "openai": Provider.OPENAI, + "azureopenai": Provider.AZUREOPENAI, "fast-agent": Provider.FAST_AGENT, "deepseek": Provider.DEEPSEEK, "generic": Provider.GENERIC, @@ -89,6 +97,12 @@ class ModelFactory: "o1": Provider.OPENAI, "o1-preview": Provider.OPENAI, "o3-mini": Provider.OPENAI, + "azure-gpt-4o": Provider.AZUREOPENAI, + "azure-gpt-4o-mini": Provider.AZUREOPENAI, + "azure-o1-mini": Provider.AZUREOPENAI, + "azure-o1": Provider.AZUREOPENAI, + "azure-o1-preview": Provider.AZUREOPENAI, + "azure-o3-mini": Provider.AZUREOPENAI, "claude-3-haiku-20240307": Provider.ANTHROPIC, "claude-3-5-haiku-20241022": Provider.ANTHROPIC, "claude-3-5-haiku-latest": Provider.ANTHROPIC, @@ -99,6 +113,16 @@ class ModelFactory: "claude-3-7-sonnet-latest": Provider.ANTHROPIC, "claude-3-opus-20240229": Provider.ANTHROPIC, "claude-3-opus-latest": 
Provider.ANTHROPIC, + # "anthropic.claude-3-haiku-20240307-v1:0": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-5-haiku-20241022-v1:0": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-5-haiku-latest": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-5-sonnet-20240620-v1:0": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-5-sonnet-20241022-v2:0": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-5-sonnet-latest": Provider.BEDROCKANTHROPIC, + # "us.anthropic.claude-3-7-sonnet-20250219-v1:0": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-7-sonnet-latest": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-opus-20240229": Provider.BEDROCKANTHROPIC, + # "anthropic.claude-3-opus-latest": Provider.BEDROCKANTHROPIC, "deepseek-chat": Provider.DEEPSEEK, # "deepseek-reasoner": Provider.DEEPSEEK, reinstate on release } @@ -119,8 +143,10 @@ class ModelFactory: # Mapping of providers to their LLM classes PROVIDER_CLASSES: Dict[Provider, LLMClass] = { - Provider.ANTHROPIC: AnthropicAugmentedLLM, + # Provider.ANTHROPIC: AnthropicAugmentedLLM, + # Provider.BEDROCKANTHROPIC: BedrockAnthropicAugmentedLLM, Provider.OPENAI: OpenAIAugmentedLLM, + Provider.AZUREOPENAI: AzureOpenAIAugmentedLLM, Provider.FAST_AGENT: PassthroughLLM, Provider.DEEPSEEK: DeepSeekAugmentedLLM, Provider.GENERIC: GenericAugmentedLLM, diff --git a/src/mcp_agent/llm/providers/augmented_llm_azureopenai.py b/src/mcp_agent/llm/providers/augmented_llm_azureopenai.py new file mode 100644 index 00000000..a091c126 --- /dev/null +++ b/src/mcp_agent/llm/providers/augmented_llm_azureopenai.py @@ -0,0 +1,437 @@ +import os +from typing import List, Tuple, Type + +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, +) +from openai import AuthenticationError, OpenAI, AzureOpenAI + +# from openai.types.beta.chat import +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletionMessageParam, + 
ChatCompletionSystemMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from pydantic_core import from_json +from rich.text import Text + +from mcp_agent.core.exceptions import ProviderKeyError +from mcp_agent.core.prompt import Prompt +from mcp_agent.llm.augmented_llm import ( + AugmentedLLM, + ModelT, + RequestParams, +) +from mcp_agent.llm.providers.multipart_converter_openai import OpenAIConverter +from mcp_agent.llm.providers.sampling_converter_openai import ( + OpenAISamplingConverter, +) +from mcp_agent.logging.logger import get_logger +from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart + +_logger = get_logger(__name__) + +DEFAULT_AZURE_OPENAI_MODEL = "azure-gpt-4o" +DEFAULT_REASONING_EFFORT = "medium" + + +class AzureOpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]): + """ + The basic building block of agentic systems is an LLM enhanced with augmentations + such as retrieval, tools, and memory provided from a collection of MCP servers. + This implementation uses OpenAI's ChatCompletion as the LLM. 
+ """ + + def __init__(self, provider_name: str = "AzureOpenAI", *args, **kwargs) -> None: + # Set type_converter before calling super().__init__ + if "type_converter" not in kwargs: + kwargs["type_converter"] = OpenAISamplingConverter + + super().__init__(*args, **kwargs) + + self.provider = provider_name + # Initialize logger with name if available + self.logger = get_logger(f"{__name__}.{self.name}" if self.name else __name__) + + # Set up reasoning-related attributes + self._reasoning_effort = kwargs.get("reasoning_effort", None) + if self.context and self.context.config and self.context.config.openai: + if self._reasoning_effort is None and hasattr( + self.context.config.openai, "reasoning_effort" + ): + self._reasoning_effort = self.context.config.openai.reasoning_effort + + # Determine if we're using a reasoning model + chosen_model = self.default_request_params.model if self.default_request_params else None + self._reasoning = chosen_model and ( + chosen_model.startswith("azure-o3") or chosen_model.startswith("azure-o1") + ) + if self._reasoning: + self.logger.info( + f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort" + ) + + def _initialize_default_params(self, kwargs: dict) -> RequestParams: + """Initialize AzureOpenAI-specific default parameters""" + chosen_model = kwargs.get("model", DEFAULT_AZURE_OPENAI_MODEL) + + return RequestParams( + model=chosen_model, + systemPrompt=self.instruction, + parallel_tool_calls=True, + max_iterations=10, + use_history=True, + ) + + def _api_key(self) -> str: + config = self.context.config + api_key = None + + if hasattr(config, "azureopenai") and config.azureopenai: + api_key = config.azureopenai.api_key + if api_key == "": + api_key = None + + if api_key is None: + api_key = os.getenv("AZURE_OPENAI_API_KEY") + + if not api_key: + raise ProviderKeyError( + "AzureOpenAI API key not configured", + "The AzureOpenAI API key is required but not set.\n" + "Add it to your configuration 
file under azureopenai.api_key\n" + "Or set the AZURE_OPENAI_API_KEY environment variable", + ) + return api_key + + def _base_url(self) -> str: + return self.context.config.azureopenai.base_url if self.context.config.azureopenai else None + + def _api_version(self) -> str: + return self.context.config.azureopenai.api_version if self.context.config.azureopenai else None + + async def generate_internal( + self, + message, + request_params: RequestParams | None = None, + ) -> List[TextContent | ImageContent | EmbeddedResource]: + """ + Process a query using an LLM and available tools. + The default implementation uses OpenAI's ChatCompletion as the LLM. + Override this method to use a different LLM. + """ + + try: + azureopenai_client = AzureOpenAI(api_key=self._api_key(), azure_endpoint=self._base_url(), api_version=self._api_version()) + messages: List[ChatCompletionMessageParam] = [] + params = self.get_request_params(request_params) + except AuthenticationError as e: + raise ProviderKeyError( + "Invalid Azure OpenAI API key", + "The configured Azure OpenAI API key was rejected.\n" + "Please check that your API key is valid and not expired.", + ) from e + + system_prompt = self.instruction or params.systemPrompt + if system_prompt: + messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt)) + + # Always include prompt messages, but only include conversation history + # if use_history is True + messages.extend(self.history.get(include_history=params.use_history)) + + if isinstance(message, str): + messages.append(ChatCompletionUserMessageParam(role="user", content=message)) + elif isinstance(message, list): + messages.extend(message) + else: + messages.append(message) + + response = await self.aggregator.list_tools() + available_tools: List[ChatCompletionToolParam] | None = [ + ChatCompletionToolParam( + type="function", + function={ + "name": tool.name, + "description": tool.description, + "parameters": tool.inputSchema, + # TODO: 
saqadri - determine if we should specify "strict" to True by default + }, + ) + for tool in response.tools + ] + if not available_tools: + available_tools = None # deepseek does not allow empty array + + responses: List[TextContent | ImageContent | EmbeddedResource] = [] + model = self.default_request_params.model + + # we do NOT send stop sequences as this causes errors with multimodal processing + for i in range(params.max_iterations): + arguments = { + "model": model or "azure-gpt-4o", + "messages": messages, + "tools": available_tools, + } + if self._reasoning: + arguments = { + **arguments, + "max_completion_tokens": params.maxTokens, + "reasoning_effort": self._reasoning_effort, + } + else: + arguments = {**arguments, "max_tokens": params.maxTokens} + if available_tools: + arguments["parallel_tool_calls"] = params.parallel_tool_calls + + if params.metadata: + arguments = {**arguments, **params.metadata} + + self.logger.debug(f"{arguments}") + self._log_chat_progress(self.chat_turn(), model=model) + + executor_result = await self.executor.execute( + azureopenai_client.chat.completions.create, **arguments + ) + + response = executor_result[0] + + self.logger.debug( + "Azure OpenAI ChatCompletion response:", + data=response, + ) + + if isinstance(response, AuthenticationError): + raise ProviderKeyError( + "Invalid Azure OpenAI API key", + "The configured Azure OpenAI API key was rejected.\n" + "Please check that your API key is valid and not expired.", + ) from response + elif isinstance(response, BaseException): + self.logger.error(f"Error: {response}") + break + + if not response.choices or len(response.choices) == 0: + # No response from the model, we're done + break + + choice = response.choices[0] + message = choice.message + # prep for image/audio gen models + if message.content: + responses.append(TextContent(type="text", text=message.content)) + + converted_message = self.convert_message_to_message_param(message, name=self.name) + 
messages.append(converted_message) + message_text = converted_message.content + if choice.finish_reason in ["tool_calls", "function_call"] and message.tool_calls: + if message_text: + await self.show_assistant_message( + message_text, + message.tool_calls[ + 0 + ].function.name, # TODO support displaying multiple tool calls + ) + else: + await self.show_assistant_message( + Text( + "the assistant requested tool calls", + style="dim green italic", + ), + message.tool_calls[0].function.name, + ) + + tool_results = [] + for tool_call in message.tool_calls: + self.show_tool_call( + available_tools, + tool_call.function.name, + tool_call.function.arguments, + ) + tool_call_request = CallToolRequest( + method="tools/call", + params=CallToolRequestParams( + name=tool_call.function.name, + arguments=from_json(tool_call.function.arguments, allow_partial=True), + ), + ) + result = await self.call_tool(tool_call_request, tool_call.id) + self.show_oai_tool_result(str(result)) + + tool_results.append((tool_call.id, result)) + responses.extend(result.content) + messages.extend(OpenAIConverter.convert_function_results_to_openai(tool_results)) + + self.logger.debug( + f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}" + ) + elif choice.finish_reason == "length": + # We have reached the max tokens limit + self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'length'") + if request_params and request_params.maxTokens is not None: + message_text = Text( + f"the assistant has reached the maximum token limit ({request_params.maxTokens})", + style="dim green italic", + ) + else: + message_text = Text( + "the assistant has reached the maximum token limit", + style="dim green italic", + ) + + await self.show_assistant_message(message_text) + # TODO: saqadri - would be useful to return the reason for stopping to the caller + break + elif choice.finish_reason == "content_filter": + # The response was filtered by the content filter + 
self.logger.debug( + f"Iteration {i}: Stopping because finish_reason is 'content_filter'" + ) + # TODO: saqadri - would be useful to return the reason for stopping to the caller + break + elif choice.finish_reason == "stop": + self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop'") + if message_text: + await self.show_assistant_message(message_text, "") + break + + # Only save the new conversation messages to history if use_history is true + # Keep the prompt messages separate + if params.use_history: + # Get current prompt messages + prompt_messages = self.history.get(include_history=False) + + # Calculate new conversation messages (excluding prompts) + new_messages = messages[len(prompt_messages) :] + + # Update conversation history + self.history.set(new_messages) + + self._log_chat_finished(model=model) + + return responses + + async def _apply_prompt_provider_specific( + self, + multipart_messages: List["PromptMessageMultipart"], + request_params: RequestParams | None = None, + ) -> PromptMessageMultipart: + # TODO -- this is very similar to Anthropic (just the converter class changes). 
+ # TODO -- potential refactor to base class, standardize Converter interface + # Check the last message role + last_message = multipart_messages[-1] + + # Add all previous messages to history (or all messages if last is from assistant) + messages_to_add = ( + multipart_messages[:-1] if last_message.role == "user" else multipart_messages + ) + converted = [] + for msg in messages_to_add: + converted.append(OpenAIConverter.convert_to_openai(msg)) + self.history.extend(converted, is_prompt=True) + + if last_message.role == "user": + # For user messages: Generate response to the last one + self.logger.debug("Last message in prompt is from user, generating assistant response") + message_param = OpenAIConverter.convert_to_openai(last_message) + responses: List[ + TextContent | ImageContent | EmbeddedResource + ] = await self.generate_internal( + message_param, + request_params, + ) + return Prompt.assistant(*responses) + else: + # For assistant messages: Return the last message content as text + self.logger.debug("Last message in prompt is from assistant, returning it directly") + return last_message + + async def structured( + self, + prompt: List[PromptMessageMultipart], + model: Type[ModelT], + request_params: RequestParams | None = None, + ) -> Tuple[ModelT | None, PromptMessageMultipart]: + """ + Apply the prompt and return the result as a Pydantic model. + + Uses OpenAI's beta parse feature when compatible, falling back to standard + JSON parsing if the beta feature fails or is unavailable. 
+ + Args: + prompt: List of messages to process + model: Pydantic model to parse the response into + request_params: Optional request parameters + + Returns: + The parsed response as a Pydantic model, or None if parsing fails + """ + + if not "AzureOpenAI" == self.provider: + return await super().structured(prompt, model, request_params) + + logger = get_logger(__name__) + + # First try to use OpenAI's beta.chat.completions.parse feature + try: + # Convert the multipart messages to OpenAI format + messages = [] + for msg in prompt: + messages.append(OpenAIConverter.convert_to_openai(msg)) + + # Add system prompt if available and not already present + if self.instruction and not any(m.get("role") == "system" for m in messages): + system_msg = ChatCompletionSystemMessageParam( + role="system", content=self.instruction + ) + messages.insert(0, system_msg) + model_name = self.default_request_params.model + self.show_user_message(prompt[-1].first_text(), model_name, self.chat_turn()) + # Use the beta parse feature + try: + azureopenai_client = AzureOpenAI(api_key=self._api_key(), azure_endpoint=self._base_url(), api_version=self._api_version()) + model_name = self.default_request_params.model + + logger.debug( + f"Using OpenAI beta parse with model {model_name} for structured output" + ) + response = await self.executor.execute( + azureopenai_client.beta.chat.completions.parse, + model=model_name, + messages=messages, + response_format=model, + ) + + if response and isinstance(response[0], BaseException): + raise response[0] + parsed_result = response[0].choices[0].message + await self.show_assistant_message(parsed_result.content) + logger.debug("Successfully used OpenAI beta parse feature for structured output") + return parsed_result.parsed, Prompt.assistant(parsed_result.content) + + except (ImportError, AttributeError, NotImplementedError) as e: + # Beta feature not available, log and continue to fallback + logger.debug(f"OpenAI beta parse feature not available: 
{str(e)}") + # Continue to fallback + + except Exception as e: + logger.debug(f"OpenAI beta parse failed: {str(e)}, falling back to standard method") + # Continue to standard method as fallback + + # Fallback to standard method (inheriting from base class) + return await super().structured(prompt, model, request_params) + + async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest): + return request + + async def post_tool_call( + self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult + ): + return result diff --git a/src/mcp_agent/llm/providers/augmented_llm_bedrockanthropic.py b/src/mcp_agent/llm/providers/augmented_llm_bedrockanthropic.py new file mode 100644 index 00000000..1fa6ecc0 --- /dev/null +++ b/src/mcp_agent/llm/providers/augmented_llm_bedrockanthropic.py @@ -0,0 +1,407 @@ +import os +from typing import TYPE_CHECKING, List + +from mcp.types import EmbeddedResource, ImageContent, TextContent + +from mcp_agent.core.prompt import Prompt +from mcp_agent.llm.providers.multipart_converter_anthropic import ( + AnthropicConverter, +) +from mcp_agent.llm.providers.sampling_converter_anthropic import ( + AnthropicSamplingConverter, +) +from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart + +if TYPE_CHECKING: + from mcp import ListToolsResult + + +from anthropic import AuthenticationError, AnthropicBedrock +from anthropic.types import ( + Message, + MessageParam, + TextBlock, + TextBlockParam, + ToolParam, + ToolUseBlockParam, + Usage, +) +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, +) +from rich.text import Text + +from mcp_agent.core.exceptions import ProviderKeyError +from mcp_agent.llm.augmented_llm import ( + AugmentedLLM, + RequestParams, +) +from mcp_agent.logging.logger import get_logger + +DEFAULT_BEDROCK_ANTHROPIC_MODEL = "anthropic.claude-3-7-sonnet-20250219" + + +class BedrockAnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]): + """ + The basic 
building block of agentic systems is an LLM enhanced with augmentations + such as retrieval, tools, and memory provided from a collection of MCP servers. + Our current models can actively use these capabilities—generating their own search queries, + selecting appropriate tools, and determining what information to retain. + """ + + def __init__(self, *args, **kwargs) -> None: + self.provider = "BedrockAnthropic" + # Initialize logger - keep it simple without name reference + self.logger = get_logger(__name__) + + # Now call super().__init__ + super().__init__(*args, type_converter=AnthropicSamplingConverter, **kwargs) + + def _initialize_default_params(self, kwargs: dict) -> RequestParams: + """Initialize Anthropic-specific default parameters""" + return RequestParams( + model=kwargs.get("model", DEFAULT_BEDROCK_ANTHROPIC_MODEL), + maxTokens=4096, # default haiku3 + systemPrompt=self.instruction, + parallel_tool_calls=True, + max_iterations=10, + use_history=True, + ) + + def _base_url(self) -> str | None: + assert self.context.config + return self.context.config.bedrockanthropic.base_url if self.context.config.bedrockanthropic else None + + async def generate_internal( + self, + message_param, + request_params: RequestParams | None = None, + ) -> list[TextContent | ImageContent | EmbeddedResource]: + """ + Process a query using an LLM and available tools. + Override this method to use a different LLM. 
+ """ + + aws_access_key = self._aws_access_key(self.context.config) + aws_secret_key = self._aws_secret_key(self.context.config) + # aws_session_token = self._aws_session_token(self.context.config) + base_url = self._base_url() + if base_url and base_url.endswith("/v1"): + base_url = base_url.rstrip("/v1") + + try: + bedrockanthropic = AnthropicBedrock(aws_access_key=aws_access_key, aws_secret_key=aws_secret_key, base_url=base_url) + messages: List[MessageParam] = [] + params = self.get_request_params(request_params) + except AuthenticationError as e: + raise ProviderKeyError( + "Invalid Bedrock Setup", + "The configured Bedrock was rejected.\nPlease check that your keys are valid and not expired.", + ) from e + + # Always include prompt messages, but only include conversation history + # if use_history is True + messages.extend(self.history.get(include_history=params.use_history)) + + messages.append(message_param) + + tool_list: ListToolsResult = await self.aggregator.list_tools() + available_tools: List[ToolParam] = [ + ToolParam( + name=tool.name, + description=tool.description or "", + input_schema=tool.inputSchema, + ) + for tool in tool_list.tools + ] + + responses: List[TextContent | ImageContent | EmbeddedResource] = [] + + model = self.default_request_params.model + + for i in range(params.max_iterations): + self._log_chat_progress(self.chat_turn(), model=model) + arguments = { + "model": model, + "messages": messages, + "system": self.instruction or params.systemPrompt, + "stop_sequences": params.stopSequences, + "tools": available_tools, + } + + if params.maxTokens is not None: + arguments["max_tokens"] = params.maxTokens + + if params.metadata: + arguments = {**arguments, **params.metadata} + + self.logger.debug(f"{arguments}") + + executor_result = await self.executor.execute(bedrockanthropic.messages.create, **arguments) + + response = executor_result[0] + + if isinstance(response, AuthenticationError): + raise ProviderKeyError( + "Invalid Bedrock 
Setup", + "The configured Bedrock was rejected.\nPlease check that your keys are valid and not expired.", + ) from response + elif isinstance(response, BaseException): + error_details = str(response) + self.logger.error(f"Error: {error_details}", data=executor_result) + + # Try to extract more useful information for API errors + if hasattr(response, "status_code") and hasattr(response, "response"): + try: + error_json = response.response.json() + error_details = f"Error code: {response.status_code} - {error_json}" + except: # noqa: E722 + error_details = f"Error code: {response.status_code} - {str(response)}" + + # Convert other errors to text response + error_message = f"Error during generation: {error_details}" + response = Message( + id="error", # Required field + model="error", # Required field + role="assistant", + type="message", + content=[TextBlock(type="text", text=error_message)], + stop_reason="end_turn", # Must be one of the allowed values + usage=Usage(input_tokens=0, output_tokens=0), # Required field + ) + + self.logger.debug( + f"{model} response:", + data=response, + ) + + response_as_message = self.convert_message_to_message_param(response) + messages.append(response_as_message) + if response.content[0].type == "text": + responses.append(TextContent(type="text", text=response.content[0].text)) + + if response.stop_reason == "end_turn": + message_text = "" + for block in response_as_message["content"]: + if isinstance(block, dict) and block.get("type") == "text": + message_text += block.get("text", "") + elif hasattr(block, "type") and block.type == "text": + message_text += block.text + + await self.show_assistant_message(message_text) + + self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'end_turn'") + break + elif response.stop_reason == "stop_sequence": + # We have reached a stop sequence + self.logger.debug( + f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'" + ) + break + elif response.stop_reason == 
"max_tokens": + # We have reached the max tokens limit + + self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'max_tokens'") + if params.maxTokens is not None: + message_text = Text( + f"the assistant has reached the maximum token limit ({params.maxTokens})", + style="dim green italic", + ) + else: + message_text = Text( + "the assistant has reached the maximum token limit", + style="dim green italic", + ) + + await self.show_assistant_message(message_text) + + break + else: + message_text = "" + for block in response_as_message["content"]: + if isinstance(block, dict) and block.get("type") == "text": + message_text += block.get("text", "") + elif hasattr(block, "type") and block.type == "text": + message_text += block.text + + # response.stop_reason == "tool_use": + # First, collect all tool uses in this turn + tool_uses = [c for c in response.content if c.type == "tool_use"] + + if tool_uses: + if message_text == "": + message_text = Text( + "the assistant requested tool calls", + style="dim green italic", + ) + + # Process all tool calls and collect results + tool_results = [] + for i, content in enumerate(tool_uses): + tool_name = content.name + tool_args = content.input + tool_use_id = content.id + + if i == 0: # Only show message for first tool use + await self.show_assistant_message(message_text, tool_name) + + self.show_tool_call(available_tools, tool_name, tool_args) + tool_call_request = CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name=tool_name, arguments=tool_args), + ) + # TODO -- support MCP isError etc. 
+ result = await self.call_tool( + request=tool_call_request, tool_call_id=tool_use_id + ) + self.show_tool_result(result) + + # Add each result to our collection + tool_results.append((tool_use_id, result)) + responses.extend(result.content) + + messages.append(AnthropicConverter.create_tool_results_message(tool_results)) + + # Only save the new conversation messages to history if use_history is true + # Keep the prompt messages separate + if params.use_history: + # Get current prompt messages + prompt_messages = self.history.get(include_history=False) + + # Calculate new conversation messages (excluding prompts) + new_messages = messages[len(prompt_messages) :] + + # Update conversation history + self.history.set(new_messages) + + self._log_chat_finished(model=model) + + return responses + + def _aws_access_key(self, config): + aws_access_key = None + + if hasattr(config, "bedrockanthropic") and config.bedrockanthropic: + aws_access_key = config.bedrockanthropic.aws_access_key + if aws_access_key == "": + aws_access_key = None + + if aws_access_key is None: + aws_access_key = os.getenv("AWS_ACCESS_KEY_ID") + + if not aws_access_key: + raise ProviderKeyError( + "AWS_ACCESS_KEY_ID not configured", + "The AWS_ACCESS_KEY_ID is required but not set.\n" + "Add it to your configuration file under bedrockanthropic.aws_access_key " + "or set the AWS_ACCESS_KEY_ID environment variable.", + ) + + return aws_access_key + + def _aws_secret_key(self, config): + aws_secret_key = None + + if hasattr(config, "bedrockanthropic") and config.bedrockanthropic: + aws_secret_key = config.bedrockanthropic.aws_secret_key + if aws_secret_key == "": + aws_secret_key = None + + if aws_secret_key is None: + aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") + + if not aws_secret_key: + raise ProviderKeyError( + "AWS_SECRET_ACCESS_KEY not configured", + "The AWS_SECRET_ACCESS_KEY is required but not set.\n" + "Add it to your configuration file under bedrockanthropic.aws_secret_key " + "or 
set the AWS_SECRET_ACCESS_KEY environment variable.", + ) + + return aws_secret_key + + # def _aws_session_token(self, config): + # aws_session_token = None + + # if hasattr(config, "bedrockanthropic") and config.bedrockanthropic: + # aws_session_token = config.bedrockanthropic.aws_session_token + # if aws_session_token == "": + # aws_session_token = None + + # if aws_session_token is None: + # aws_session_token = os.getenv("AWS_SESSION_TOKEN") + + # if not aws_session_token: + # raise ProviderKeyError( + # "AWS_SESSION_TOKEN not configured", + # "The AWS_SESSION_TOKEN is required but not set.\n" + # "Add it to your configuration file under bedrockanthropic.aws_session_token " + # "or set the AWS_SESSION_TOKEN environment variable.", + # ) + + # return aws_session_token + + async def generate_messages( + self, + message_param, + request_params: RequestParams | None = None, + ) -> PromptMessageMultipart: + """ + Process a query using an LLM and available tools. + The default implementation uses Claude as the LLM. + Override this method to use a different LLM. 
+ + """ + res = await self.generate_internal( + message_param=message_param, + request_params=request_params, + ) + return Prompt.assistant(*res) + + async def _apply_prompt_provider_specific( + self, + multipart_messages: List["PromptMessageMultipart"], + request_params: RequestParams | None = None, + ) -> PromptMessageMultipart: + # Check the last message role + last_message = multipart_messages[-1] + + # Add all previous messages to history (or all messages if last is from assistant) + messages_to_add = ( + multipart_messages[:-1] if last_message.role == "user" else multipart_messages + ) + converted = [] + for msg in messages_to_add: + converted.append(AnthropicConverter.convert_to_anthropic(msg)) + + self.history.extend(converted, is_prompt=True) + + if last_message.role == "user": + self.logger.debug("Last message in prompt is from user, generating assistant response") + message_param = AnthropicConverter.convert_to_anthropic(last_message) + return await self.generate_messages(message_param, request_params) + else: + # For assistant messages: Return the last message content as text + self.logger.debug("Last message in prompt is from assistant, returning it directly") + return last_message + + @classmethod + def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam: + """Convert a response object to an input parameter object to allow LLM calls to be chained.""" + content = [] + + for content_block in message.content: + if content_block.type == "text": + content.append(TextBlockParam(type="text", text=content_block.text)) + elif content_block.type == "tool_use": + content.append( + ToolUseBlockParam( + type="tool_use", + name=content_block.name, + input=content_block.input, + id=content_block.id, + ) + ) + + return MessageParam(role="assistant", content=content, **kwargs)