From 9327cff3f6823fc5980e4ef63f9867bcf722c074 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 12 Jun 2025 13:30:39 +0300 Subject: [PATCH 1/3] Improved Trio support * Removed the asyncio-only parametrization of the anyio_backend except for test_ws, as `websockets` doesn't support Trio yet * Try to close async generators explicitly where possible * Changed nesting order for more predictable closing of async resources * Refactored `__aenter__` and `__aexit__` in some cases to exit the task group if there's a problem during initialization * Fixed test failures in client/test_auth.py where an async fixture was used in sync tests * Fixed subtle bug in `SimpleEventStore` where retrieving the stream ID was timing-dependent --- pyproject.toml | 8 +++-- src/mcp/client/streamable_http.py | 48 +++++++++++++++------------- src/mcp/server/session.py | 15 ++++++--- src/mcp/shared/session.py | 25 +++++++++------ tests/client/test_auth.py | 2 +- tests/client/test_session.py | 8 ++--- tests/conftest.py | 6 ---- tests/shared/test_streamable_http.py | 16 ++++++---- tests/shared/test_ws.py | 2 ++ uv.lock | 11 +++++-- 10 files changed, 83 insertions(+), 58 deletions(-) delete mode 100644 tests/conftest.py diff --git a/pyproject.toml b/pyproject.toml index 9ad50ab58..6e3c00d90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "python-multipart>=0.0.9", "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", + "typing_extensions>=4.12", "uvicorn>=0.23.1; sys_platform != 'emscripten'", ] @@ -48,10 +49,10 @@ required-version = ">=0.7.2" [dependency-groups] dev = [ + "anyio[trio]", "pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5", - "trio>=0.26.2", "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", @@ -122,5 +123,8 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", + # This is to avoid test failures on Trio due to httpx's failure to explicitly close + # async generators + "ignore::pytest.PytestUnraisableExceptionWarning" ] diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 471870533..15fc3393a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -8,9 +8,10 @@ import logging from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager +from contextlib import aclosing, asynccontextmanager from dataclasses import dataclass from datetime import timedelta +from typing import cast import anyio import httpx @@ -284,16 +285,18 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte """Handle SSE response from the server.""" try: event_source = EventSource(response) - async for sse in event_source.aiter_sse(): - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), - ) - # If the SSE event indicates completion, like returning respose/error - # break the loop - if is_complete: - break + sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse()) + async with aclosing(sse_iter) as items: + async for sse in items: + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), + ) + # If the SSE event indicates completion, like returning respose/error + # break the loop + if is_complete: + break except Exception as e: logger.exception("Error reading SSE stream:") await ctx.read_stream_writer.send(e) @@ -434,15 +437,16 @@ async def streamablehttp_client( read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + try: + logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with httpx_client_factory( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - auth=transport.auth, - ) as client: + async with create_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout( + transport.timeout, read=transport.sse_read_timeout + ), + ) as client: + async with anyio.create_task_group() as tg: # Define callbacks that need access to tg def start_get_stream() -> None: tg.start_soon(transport.handle_get_stream, client, read_stream_writer) @@ -467,6 +471,6 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e6611b0d4..df1dd93e9 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +from typing_extensions import Self import mcp.types as types from mcp.server.models import InitializationOptions @@ -93,10 +94,16 @@ def __init__( ) self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ - ServerRequestResponder - ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + + async def __aenter__(self) -> Self: + await super().__aenter__() + self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( + anyio.create_memory_object_stream[ServerRequestResponder](0) + ) + self._exit_stack.push_async_callback( + self._incoming_message_stream_reader.aclose + ) + return self @property def client_params(self) -> types.InitializeRequestParams | None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 294986acb..a8c9b29d7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -7,6 +7,7 @@ import anyio import httpx +from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel from typing_extensions import Self @@ -177,6 +178,8 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _exit_stack: AsyncExitStack + _task_group: TaskGroup def __init__( self, @@ -196,12 +199,19 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - self._task_group.start_soon(self._receive_loop) + async with AsyncExitStack() as exit_stack: + self._task_group = await exit_stack.enter_async_context( + anyio.create_task_group() + ) + self._task_group.start_soon(self._receive_loop) + # Using BaseSession as a context manager should not block on exit (this + # would be very surprising behavior), so make sure to cancel the tasks + # in the task group. + exit_stack.callback(self._task_group.cancel_scope.cancel) + self._exit_stack = exit_stack.pop_all() + return self async def __aexit__( @@ -210,12 +220,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - await self._exit_stack.aclose() - # Using BaseSession as a context manager should not block on exit (this - # would be very surprising behavior), so make sure to cancel the tasks - # in the task group. - self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) async def send_request( self, diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index de4eb70af..e514bb5f7 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -99,7 +99,7 @@ def oauth_token(): @pytest.fixture -async def oauth_provider(client_metadata, mock_storage): +def oauth_provider(client_metadata, mock_storage): async def mock_redirect_handler(url: str) -> None: pass diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 327d1a9e4..12c043fe7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -334,15 +334,15 @@ async def mock_server(): ) async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, ClientSession( server_to_client_receive, client_to_server_send, ) as session, anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, ): tg.start_soon(mock_server) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index af7e47993..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -import pytest - - -@pytest.fixture -def anyio_backend(): - return "asyncio" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 615e68efc..d61538c2a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -87,16 +87,17 @@ async def replay_events_after( """Replay events after the specified ID.""" # Find the index of the last event ID start_index = None - for i, (_, event_id, _) in enumerate(self._events): + stream_id = None + for i, (stream_id_, event_id, _) in enumerate(self._events): if event_id == last_event_id: start_index = i + 1 + stream_id = stream_id_ break if start_index is None: # If event ID not found, start from beginning start_index = 0 - stream_id = None # Replay events for _, event_id, message in self._events[start_index:]: await send_callback(EventMessage(message, event_id)) @@ -1003,7 +1004,8 @@ async def test_streamablehttp_client_resumption(event_server): captured_session_id = None captured_resumption_token = None captured_notifications = [] - tool_started = False + tool_started_event = anyio.Event() + session_resumption_token_received_event = anyio.Event() async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1013,12 +1015,12 @@ async def message_handler( # Look for our special notification that indicates the tool is running if isinstance(message.root, types.LoggingMessageNotification): if message.root.params.data == "Tool started": - nonlocal tool_started - tool_started = True + tool_started_event.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + session_resumption_token_received_event.set() # First, start the client session and begin the long-running tool async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( @@ -1055,8 +1057,8 @@ async def run_tool(): # Wait for the tool to start and at least one notification # and then kill the task group - while not tool_started or not captured_resumption_token: - await anyio.sleep(0.1) + await tool_started_event.wait() + await session_resumption_token_received_event.wait() tg.cancel_scope.cancel() # Store pre notifications and clear the captured notifications diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 5081f1d53..084043236 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -27,6 +27,8 @@ SERVER_NAME = "test_server_for_WS" +pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.fixture def server_port() -> int: diff --git a/uv.lock b/uv.lock index 180d5a9c1..6ae56f94b 100644 --- a/uv.lock +++ b/uv.lock @@ -40,6 +40,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250, upload-time = "2024-09-19T09:28:42.699Z" }, ] +[package.optional-dependencies] +trio = [ + { name = "trio" }, +] + [[package]] name = "asttokens" version = "2.4.1" @@ -537,6 +542,7 @@ dependencies = [ { name = "python-multipart" }, { name = "sse-starlette" }, { name = "starlette" }, + { name = "typing-extensions" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ] @@ -554,6 +560,7 @@ ws = [ [package.dev-dependencies] dev = [ + { name = "anyio", extra = ["trio"] }, { name = "inline-snapshot" }, { name = "pyright" }, { name = "pytest" }, @@ -562,7 +569,6 @@ dev = [ { name = "pytest-pretty" }, { name = "pytest-xdist" }, { name = "ruff" }, - { name = "trio" }, ] docs = [ { name = "mkdocs" }, @@ -584,6 +590,7 @@ requires-dist = [ { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, + { name = "typing-extensions", specifier = ">=4.12" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] @@ -591,6 +598,7 @@ provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ + { name = "anyio", extras = ["trio"] }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=8.3.4" }, @@ -599,7 +607,6 @@ dev = [ { name = "pytest-pretty", specifier = ">=1.2.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, - { name = "trio", specifier = ">=0.26.2" }, ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, From 0ce5b79dc54377f30741e480f08df54faaefc161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 12 Jun 2025 14:08:08 +0300 Subject: [PATCH 2/3] Fixed pre-commit errors --- src/mcp/client/streamable_http.py | 4 +--- src/mcp/server/session.py | 10 ++++------ src/mcp/shared/session.py | 4 +--- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 15fc3393a..1b025e83e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -442,9 +442,7 @@ async def streamablehttp_client( async with create_mcp_http_client( headers=transport.request_headers, - timeout=httpx.Timeout( - transport.timeout, read=transport.sse_read_timeout - ), + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), ) as client: async with anyio.create_task_group() as tg: # Define callbacks that need access to tg diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index df1dd93e9..61d654744 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -97,12 +97,10 @@ def __init__( async def __aenter__(self) -> Self: await super().__aenter__() - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ServerRequestResponder](0) - ) - self._exit_stack.push_async_callback( - self._incoming_message_stream_reader.aclose - ) + self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ + ServerRequestResponder + ](0) + self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose) return self @property diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index a8c9b29d7..2ff29304a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -202,9 +202,7 @@ def __init__( async def __aenter__(self) -> Self: async with AsyncExitStack() as exit_stack: - self._task_group = await exit_stack.enter_async_context( - anyio.create_task_group() - ) + self._task_group = await exit_stack.enter_async_context(anyio.create_task_group()) self._task_group.start_soon(self._receive_loop) # Using BaseSession as a context manager should not block on exit (this # would be very surprising behavior), so make sure to cancel the tasks From 83d23ad0e324c31911ed86a14d41e2290bff60e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 18 Jun 2025 16:56:20 +0300 Subject: [PATCH 3/3] Fixed uses of async generators and removed the pytest warning ignore --- pyproject.toml | 1 - src/mcp/client/streamable_http.py | 24 +++++++++++------------- tests/client/test_auth.py | 31 ++++++++++++++++--------------- tests/shared/test_sse.py | 18 ++++++++++-------- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e3c00d90..f20a81946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,5 +126,4 @@ filterwarnings = [ "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", # This is to avoid test failures on Trio due to httpx's failure to explicitly close # async generators - "ignore::pytest.PytestUnraisableExceptionWarning" ] diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 970c3c682..bfb0f0aa1 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,7 +11,6 @@ from contextlib import aclosing, asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import cast import anyio import httpx @@ -241,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") - async for sse in event_source.aiter_sse(): - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - original_request_id, - ctx.metadata.on_resumption_token_update if ctx.metadata else None, - ) - if is_complete: - break + async with aclosing(event_source.aiter_sse()) as iterator: + async for sse in iterator: + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" @@ -320,9 +320,7 @@ async def _handle_sse_response( ) -> None: """Handle SSE response from the server.""" try: - event_source = EventSource(response) - sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse()) - async with aclosing(sse_iter) as items: + async with aclosing(EventSource(response).aiter_sse()) as items: async for sse in items: is_complete = await self._handle_sse_event( sse, diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index e514bb5f7..ef202facd 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -5,6 +5,7 @@ import base64 import hashlib import time +from contextlib import aclosing from unittest.mock import AsyncMock, Mock, patch from urllib.parse import parse_qs, urlparse @@ -654,17 +655,17 @@ async def test_async_auth_flow_401_response(self, oauth_provider, oauth_token): mock_response = Mock() mock_response.status_code = 401 - auth_flow = oauth_provider.async_auth_flow(request) - await auth_flow.__anext__() + async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow: + await auth_flow.__anext__() - # Send 401 response - try: - await auth_flow.asend(mock_response) - except StopAsyncIteration: - pass + # Send 401 response + try: + await auth_flow.asend(mock_response) + except StopAsyncIteration: + pass - # Should clear current tokens - assert oauth_provider._current_tokens is None + # Should clear current tokens + assert oauth_provider._current_tokens is None @pytest.mark.anyio async def test_async_auth_flow_no_token(self, oauth_provider): @@ -675,14 +676,14 @@ async def test_async_auth_flow_no_token(self, oauth_provider): patch.object(oauth_provider, "initialize") as mock_init, patch.object(oauth_provider, "ensure_token") as mock_ensure, ): - auth_flow = oauth_provider.async_auth_flow(request) - updated_request = await auth_flow.__anext__() + async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow: + updated_request = await auth_flow.__anext__() - mock_init.assert_called_once() - mock_ensure.assert_called_once() + mock_init.assert_called_once() + mock_ensure.assert_called_once() - # No Authorization header should be added if no token - assert "Authorization" not in updated_request.headers + # No Authorization header should be added if no token + assert "Authorization" not in updated_request.headers @pytest.mark.anyio async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info): diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4d8f7717e..43bb3320f 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,6 +3,7 @@ import socket import time from collections.abc import AsyncGenerator, Generator +from contextlib import aclosing import anyio import httpx @@ -160,14 +161,15 @@ async def connection_test() -> None: assert response.headers["content-type"] == "text/event-stream; charset=utf-8" line_number = 0 - async for line in response.aiter_lines(): - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 + async with aclosing(response.aiter_lines()) as lines: + async for line in lines: + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 # Add timeout to prevent test from hanging if it fails with anyio.fail_after(3):