From 85514b44d74392ba91091d2fe5a5fd868dffca9a Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 19 Mar 2025 21:32:54 +0000 Subject: [PATCH 1/2] Move incoming message stream from BaseSession to ServerSession MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes GitHub issue #201 by moving the incoming message stream and related methods from BaseSession to ServerSession where they are actually needed. This change follows the principle of having functionality only where it's required. GitHub-Issue:#201 🤖 Generated with [Claude Code](https://claude.ai/code) --- src/mcp/server/session.py | 24 ++++++++++++++++++++++++ src/mcp/shared/session.py | 37 ++++++++++--------------------------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 788bb9f83..d35e02e94 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -61,6 +61,12 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +ServerRequestResponder = ( + RequestResponder[types.ClientRequest, types.ServerResult] + | types.ClientNotification + | Exception +) + class ServerSession( BaseSession[ @@ -85,6 +91,15 @@ def __init__( ) self._initialization_state = InitializationState.NotInitialized self._init_options = init_options + self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( + anyio.create_memory_object_stream[ServerRequestResponder]() + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_reader.aclose() + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_writer.aclose() + ) @property def client_params(self) -> types.InitializeRequestParams | None: @@ -291,3 +306,12 @@ async def send_prompt_list_changed(self) -> None: ) ) ) + + async def _handle_incoming(self, req: ServerRequestResponder) -> None: + return await self._incoming_message_stream_writer.send(req) + + @property + def incoming_messages( + self, + ) -> MemoryObjectReceiveStream[ServerRequestResponder]: + return self._incoming_message_stream_reader diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 31f888246..7f02d0a6b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -182,19 +182,6 @@ def __init__( self._in_flight = {} self._exit_stack = AsyncExitStack() - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ - RequestResponder[ReceiveRequestT, SendResultT] - | ReceiveNotificationT - | Exception - ]() - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_reader.aclose() - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_writer.aclose() - ) async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() @@ -300,11 +287,10 @@ async def _receive_loop(self) -> None: async with ( self._read_stream, self._write_stream, - self._incoming_message_stream_writer, ): async for message in self._read_stream: if isinstance(message, Exception): - await self._incoming_message_stream_writer.send(message) + await self._handle_incoming(message) elif isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( message.root.model_dump( @@ -325,7 +311,7 @@ async def _receive_loop(self) -> None: self._in_flight[responder.request_id] = responder await self._received_request(responder) if not responder._completed: - await self._incoming_message_stream_writer.send(responder) + await self._handle_incoming(responder) elif isinstance(message.root, JSONRPCNotification): try: @@ -341,9 +327,7 @@ async def _receive_loop(self) -> None: await self._in_flight[cancelled_id].cancel() else: await self._received_notification(notification) - await self._incoming_message_stream_writer.send( - notification - ) + await self._handle_incoming(notification) except Exception as e: # For other validation errors, log and continue logging.warning( @@ -355,7 +339,7 @@ async def _receive_loop(self) -> None: if stream: await stream.send(message.root) else: - await self._incoming_message_stream_writer.send( + await self._handle_incoming( RuntimeError( "Received response with an unknown " f"request ID: {message}" @@ -387,12 +371,11 @@ async def send_progress_notification( processed. """ - @property - def incoming_messages( + async def _handle_incoming( self, - ) -> MemoryObjectReceiveStream[ - RequestResponder[ReceiveRequestT, SendResultT] + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT - | Exception - ]: - return self._incoming_message_stream_reader + | Exception, + ) -> None: + """A generic handler for incoming messages. Overwritten by subclasses.""" + await anyio.lowlevel.checkpoint() From 4e0de002624a4e12dc1c99220bf5d5f7139a06ee Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 19 Mar 2025 22:10:05 +0000 Subject: [PATCH 2/2] Handle message callbacks in ClientSession MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change adds a message_handler callback to ClientSession to allow for direct handling of incoming messages instead of requiring an async iterator. The change simplifies the client code by removing the need for a separate receive loop task. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/mcp/client/__main__.py | 27 ++++++----- src/mcp/client/session.py | 30 ++++++++++++ src/mcp/server/session.py | 4 +- src/mcp/shared/memory.py | 10 +++- src/mcp/shared/session.py | 2 +- tests/client/test_logging_callback.py | 70 +++++++++++++-------------- tests/client/test_session.py | 21 +++++--- tests/issues/test_88_random_error.py | 40 ++++++++------- tests/server/test_session.py | 19 +++++--- 9 files changed, 137 insertions(+), 86 deletions(-) diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index baf815c0e..39b4f45c1 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -7,9 +7,11 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.session import RequestResponder from mcp.types import JSONRPCMessage if not sys.warnoptions: @@ -21,26 +23,25 @@ logger = logging.getLogger("client") -async def receive_loop(session: ClientSession): - logger.info("Starting receive loop") - async for message in session.incoming_messages: - if isinstance(message, Exception): - logger.error("Error: %s", message) - continue +async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, +) -> None: + if isinstance(message, Exception): + logger.error("Error: %s", message) + return - logger.info("Received message from server: %s", message) + logger.info("Received message from server: %s", message) async def run_session( read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], ): - async with ( - ClientSession(read_stream, write_stream) as session, - anyio.create_task_group() as tg, - ): - tg.start_soon(receive_loop, session) - + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: logger.info("Initializing session") await session.initialize() logger.info("Initialized") diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8acf32950..65d5e11e2 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,6 +1,7 @@ from datetime import timedelta from typing import Any, Protocol +import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl, TypeAdapter @@ -31,6 +32,23 @@ async def __call__( ) -> None: ... +class MessageHandlerFnT(Protocol): + async def __call__( + self, + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: ... + + +async def _default_message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, +) -> None: + await anyio.lowlevel.checkpoint() + + async def _default_sampling_callback( context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, @@ -78,6 +96,7 @@ def __init__( sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, ) -> None: super().__init__( read_stream, @@ -89,6 +108,7 @@ def __init__( self._sampling_callback = sampling_callback or _default_sampling_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback + self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -337,10 +357,20 @@ async def _received_request( types.ClientResult(root=types.EmptyResult()) ) + async def _handle_incoming( + self, + req: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + """Handle incoming messages by forwarding to the message handler.""" + await self._message_handler(req) + async def _received_notification( self, notification: types.ServerNotification ) -> None: """Handle notifications from the server.""" + # Process specific notification types match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index d35e02e94..568ecd4b9 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -92,7 +92,7 @@ def __init__( self._initialization_state = InitializationState.NotInitialized self._init_options = init_options self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ServerRequestResponder]() + anyio.create_memory_object_stream[ServerRequestResponder](0) ) self._exit_stack.push_async_callback( lambda: self._incoming_message_stream_reader.aclose() @@ -308,7 +308,7 @@ async def send_prompt_list_changed(self) -> None: ) async def _handle_incoming(self, req: ServerRequestResponder) -> None: - return await self._incoming_message_stream_writer.send(req) + await self._incoming_message_stream_writer.send(req) @property def incoming_messages( diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 495f0c1e5..8ab39f4dd 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,13 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) from mcp.server import Server from mcp.types import JSONRPCMessage @@ -57,6 +63,7 @@ async def create_connected_server_and_client_session( sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -86,6 +93,7 @@ async def create_connected_server_and_client_session( sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, logging_callback=logging_callback, + message_handler=message_handler, ) as client_session: await client_session.initialize() yield client_session diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 7f02d0a6b..da61cccd8 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -378,4 +378,4 @@ async def _handle_incoming( | Exception, ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" - await anyio.lowlevel.checkpoint() + pass diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index ead4f0925..3dad59243 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,11 +1,12 @@ from typing import List, Literal -import anyio import pytest +import mcp.types as types from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, ) +from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, TextContent, @@ -46,40 +47,37 @@ async def test_tool_with_log( ) return True - async with anyio.create_task_group() as tg: - async with create_session( - server._mcp_server, logging_callback=logging_collector - ) as client_session: + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message - async def listen_session(): - try: - async for message in client_session.incoming_messages: - if isinstance(message, Exception): - raise message - except anyio.EndOfStream: - pass + async with create_session( + server._mcp_server, + logging_callback=logging_collector, + message_handler=message_handler, + ) as client_session: + # First verify our test tool works + result = await client_session.call_tool("test_tool", {}) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" - tg.start_soon(listen_session) - - # First verify our test tool works - result = await client_session.call_tool("test_tool", {}) - assert result.isError is False - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "true" - - # Now send a log message via our tool - log_result = await client_session.call_tool( - "test_tool_with_log", - { - "message": "Test log message", - "level": "info", - "logger": "test_logger", - }, - ) - assert log_result.isError is False - assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[ - 0 - ] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Now send a log message via our tool + log_result = await client_session.call_tool( + "test_tool_with_log", + { + "message": "Test log message", + "level": "info", + "logger": "test_logger", + }, + ) + assert log_result.isError is False + assert len(logging_collector.log_messages) == 1 + assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( + level="info", logger="test_logger", data="Test log message" + ) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 7d579cdac..f250a05bb 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,7 +1,9 @@ import anyio import pytest +import mcp.types as types from mcp.client.session import ClientSession +from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientNotification, @@ -75,13 +77,21 @@ async def mock_server(): ) ) - async def listen_session(): - async for message in session.incoming_messages: - if isinstance(message, Exception): - raise message + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message async with ( - ClientSession(server_to_client_receive, client_to_server_send) as session, + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -89,7 +99,6 @@ async def listen_session(): server_to_client_receive, ): tg.start_soon(mock_server) - tg.start_soon(listen_session) result = await session.initialize() # Assert the result diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 8609c209e..6ac98ca7b 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -6,6 +6,7 @@ import anyio import pytest +from anyio.abc import TaskStatus from mcp.client.session import ClientSession from mcp.server.lowlevel import Server @@ -54,15 +55,21 @@ async def slow_tool( return [TextContent(type="text", text=f"fast {request_count}")] return [TextContent(type="text", text=f"unknown {request_count}")] - async def server_handler(read_stream, write_stream): - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - raise_exceptions=True, - ) - - async def client(read_stream, write_stream): + async def server_handler( + read_stream, + write_stream, + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + with anyio.CancelScope() as scope: + task_status.started(scope) # type: ignore + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + raise_exceptions=True, + ) + + async def client(read_stream, write_stream, scope): # Use a timeout that's: # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) @@ -90,22 +97,13 @@ async def client(read_stream, write_stream): # proving server is still responsive result = await session.call_tool("fast") assert result.content == [TextContent(type="text", text="fast 3")] + scope.cancel() # Run server and client in separate task groups to avoid cancellation server_writer, server_reader = anyio.create_memory_object_stream(1) client_writer, client_reader = anyio.create_memory_object_stream(1) - server_ready = anyio.Event() - - async def wrapped_server_handler(read_stream, write_stream): - server_ready.set() - await server_handler(read_stream, write_stream) - async with anyio.create_task_group() as tg: - tg.start_soon(wrapped_server_handler, server_reader, client_writer) - # Wait for server to start and initialize - with anyio.fail_after(1): # Timeout after 1 second - await server_ready.wait() + scope = await tg.start(server_handler, server_reader, client_writer) # Run client in a separate task to avoid cancellation - async with anyio.create_task_group() as client_tg: - client_tg.start_soon(client, client_reader, server_writer) + tg.start_soon(client, client_reader, server_writer, scope) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 333196c96..561a94b64 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,11 +1,13 @@ import anyio import pytest +import mcp.types as types from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, InitializedNotification, @@ -25,10 +27,14 @@ async def test_server_session_initialize(): JSONRPCMessage ](1) - async def run_client(client: ClientSession): - async for message in client_session.incoming_messages: - if isinstance(message, Exception): - raise message + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message received_initialized = False @@ -57,11 +63,12 @@ async def run_server(): try: async with ( ClientSession( - server_to_client_receive, client_to_server_send + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, ) as client_session, anyio.create_task_group() as tg, ): - tg.start_soon(run_client, client_session) tg.start_soon(run_server) await client_session.initialize()