Improved Trio support #946

Open · wants to merge 5 commits into base: main
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"python-multipart>=0.0.9",
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"typing_extensions>=4.12",
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
]

@@ -48,10 +49,10 @@ required-version = ">=0.7.2"

[dependency-groups]
dev = [
"anyio[trio]",
"pyright>=1.1.391",
"pytest>=8.3.4",
"ruff>=0.8.5",
"trio>=0.26.2",
"pytest-flakefinder>=1.1.0",
"pytest-xdist>=3.6.1",
"pytest-examples>=0.0.14",
@@ -122,5 +123,7 @@ filterwarnings = [
# This should be fixed on Uvicorn's side.
"ignore::DeprecationWarning:websockets",
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
# This is to avoid test failures on Trio due to httpx's failure to explicitly close
# async generators
]
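The comment added to `filterwarnings` and the transport changes below share one theme: Trio is stricter than asyncio about async generators that get garbage-collected without being closed explicitly. A minimal sketch of the `contextlib.aclosing` pattern this PR wraps around SSE iteration (the generator below is illustrative, not httpx's real iterator):

```python
from collections.abc import AsyncGenerator
from contextlib import aclosing

import anyio


async def events() -> AsyncGenerator[int, None]:
    # Stand-in for an SSE iterator such as EventSource.aiter_sse().
    try:
        for i in range(10):
            yield i
    finally:
        # aclosing() guarantees this cleanup runs deterministically, instead of
        # whenever the garbage collector finds the abandoned generator (which
        # Trio turns into warnings or test failures).
        print("generator closed")


async def main() -> None:
    async with aclosing(events()) as items:
        async for item in items:
            if item == 2:
                break  # the generator is still closed on early exit


anyio.run(main)
```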
66 changes: 33 additions & 33 deletions src/mcp/client/streamable_http.py
@@ -8,7 +8,7 @@

import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from contextlib import aclosing, asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta

@@ -240,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
async with aclosing(event_source.aiter_sse()) as iterator:
async for sse in iterator:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
@@ -319,18 +320,18 @@ async def _handle_sse_response(
) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
async with aclosing(EventSource(response).aiter_sse()) as items:
async for sse in items:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, e.g. a response or error was returned,
# break the loop
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
@@ -471,15 +472,14 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
async with create_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
async with anyio.create_task_group() as tg:
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
@@ -504,6 +504,6 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
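Two structural changes in this file are easy to miss: the task group is now created inside the HTTP client context, so background SSE readers are joined or cancelled before the client closes, and the memory streams are closed in a plain `try/finally` that runs regardless of how the task group exits. A rough, self-contained sketch of that shape (names and the fetched URL are illustrative only):

```python
import anyio
import httpx


async def connect(url: str) -> None:
    send_stream, receive_stream = anyio.create_memory_object_stream[str](0)
    try:
        async with httpx.AsyncClient() as client:  # outermost resource
            async with anyio.create_task_group() as tg:  # exits before the client closes

                async def pump() -> None:
                    # Background work may keep using `client`: the task group is
                    # nested inside the client context, so it finishes first.
                    response = await client.get(url)
                    await send_stream.send(response.text[:80])

                tg.start_soon(pump)
                print(await receive_stream.receive())
                tg.cancel_scope.cancel()
    finally:
        # Close the streams unconditionally so consumers unblock on any backend.
        await send_stream.aclose()
        await receive_stream.aclose()


anyio.run(connect, "https://example.com")
```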
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
@@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import Self

import mcp.types as types
from mcp.server.models import InitializationOptions
@@ -93,10 +94,14 @@ def __init__(
)

self._init_options = init_options

async def __aenter__(self) -> Self:
await super().__aenter__()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
ServerRequestResponder
](0)
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose)
return self

@property
def client_params(self) -> types.InitializeRequestParams | None:
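The `ServerSession` change moves stream creation out of `__init__` and into `__aenter__`, and registers the reader's `aclose` as a bound method on the exit stack instead of wrapping it in a lambda. A stripped-down sketch of that lifecycle, assuming a class that owns its own `AsyncExitStack`:

```python
from contextlib import AsyncExitStack

import anyio
from typing_extensions import Self


class Session:
    async def __aenter__(self) -> Self:
        self._exit_stack = AsyncExitStack()
        # Resources are created only when the session is actually entered...
        self._writer, self._reader = anyio.create_memory_object_stream[str](0)
        # ...and their cleanup is registered right away, as a bound method.
        self._exit_stack.push_async_callback(self._reader.aclose)
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        await self._exit_stack.aclose()


async def main() -> None:
    async with Session():
        pass  # on exit, the stack runs _reader.aclose() automatically


anyio.run(main)
```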
23 changes: 13 additions & 10 deletions src/mcp/shared/session.py
@@ -7,6 +7,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from typing_extensions import Self
@@ -177,6 +178,8 @@ class BaseSession(
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_exit_stack: AsyncExitStack
_task_group: TaskGroup

def __init__(
self,
@@ -196,12 +199,17 @@ def __init__(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(anyio.create_task_group())
self._task_group.start_soon(self._receive_loop)
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
exit_stack.callback(self._task_group.cancel_scope.cancel)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(
@@ -210,12 +218,7 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)

async def send_request(
self,
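The `pop_all()` idiom in the new `__aenter__` is the key to exception safety here: the stack is assembled inside an `async with AsyncExitStack()` block, so if any setup step raises, everything entered so far is unwound before the exception leaves `__aenter__`; on success, `pop_all()` hands the registered cleanups to a fresh stack whose `__aexit__` the session later delegates to. A minimal sketch of just that idiom (the resources are placeholders, not the session's real setup):

```python
from contextlib import AsyncExitStack, asynccontextmanager
from types import TracebackType

import anyio


@asynccontextmanager
async def resource(name: str):
    print(f"open {name}")
    try:
        yield name
    finally:
        print(f"close {name}")


class Managed:
    async def __aenter__(self) -> "Managed":
        async with AsyncExitStack() as stack:
            await stack.enter_async_context(resource("task group"))
            # If this second step raised, "task group" would already be closed
            # by the time the exception left __aenter__.
            await stack.enter_async_context(resource("stream"))
            # Success: keep the cleanups alive beyond this block.
            self._exit_stack = stack.pop_all()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)


async def main() -> None:
    async with Managed():
        print("using the session")


anyio.run(main)
```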
33 changes: 17 additions & 16 deletions tests/client/test_auth.py
@@ -5,6 +5,7 @@
import base64
import hashlib
import time
from contextlib import aclosing
from unittest.mock import AsyncMock, Mock, patch
from urllib.parse import parse_qs, urlparse

@@ -99,7 +100,7 @@ def oauth_token():


@pytest.fixture
async def oauth_provider(client_metadata, mock_storage):
def oauth_provider(client_metadata, mock_storage):
async def mock_redirect_handler(url: str) -> None:
pass

@@ -654,17 +655,17 @@ async def test_async_auth_flow_401_response(self, oauth_provider, oauth_token):
mock_response = Mock()
mock_response.status_code = 401

auth_flow = oauth_provider.async_auth_flow(request)
await auth_flow.__anext__()
async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow:
await auth_flow.__anext__()

# Send 401 response
try:
await auth_flow.asend(mock_response)
except StopAsyncIteration:
pass
# Send 401 response
try:
await auth_flow.asend(mock_response)
except StopAsyncIteration:
pass

# Should clear current tokens
assert oauth_provider._current_tokens is None
# Should clear current tokens
assert oauth_provider._current_tokens is None

@pytest.mark.anyio
async def test_async_auth_flow_no_token(self, oauth_provider):
@@ -675,14 +676,14 @@ async def test_async_auth_flow_no_token(self, oauth_provider):
patch.object(oauth_provider, "initialize") as mock_init,
patch.object(oauth_provider, "ensure_token") as mock_ensure,
):
auth_flow = oauth_provider.async_auth_flow(request)
updated_request = await auth_flow.__anext__()
async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow:
updated_request = await auth_flow.__anext__()

mock_init.assert_called_once()
mock_ensure.assert_called_once()
mock_init.assert_called_once()
mock_ensure.assert_called_once()

# No Authorization header should be added if no token
assert "Authorization" not in updated_request.headers
# No Authorization header should be added if no token
assert "Authorization" not in updated_request.headers

@pytest.mark.anyio
async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):
8 changes: 4 additions & 4 deletions tests/client/test_session.py
@@ -334,15 +334,15 @@ async def mock_server():
)

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
ClientSession(
server_to_client_receive,
client_to_server_send,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)

6 changes: 0 additions & 6 deletions tests/conftest.py

This file was deleted.

18 changes: 10 additions & 8 deletions tests/shared/test_sse.py
@@ -3,6 +3,7 @@
import socket
import time
from collections.abc import AsyncGenerator, Generator
from contextlib import aclosing

import anyio
import httpx
@@ -165,14 +166,15 @@ async def connection_test() -> None:
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

line_number = 0
async for line in response.aiter_lines():
if line_number == 0:
assert line == "event: endpoint"
elif line_number == 1:
assert line.startswith("data: /messages/?session_id=")
else:
return
line_number += 1
async with aclosing(response.aiter_lines()) as lines:
async for line in lines:
if line_number == 0:
assert line == "event: endpoint"
elif line_number == 1:
assert line.startswith("data: /messages/?session_id=")
else:
return
line_number += 1

# Add timeout to prevent test from hanging if it fails
with anyio.fail_after(3):
16 changes: 9 additions & 7 deletions tests/shared/test_streamable_http.py
@@ -100,16 +100,17 @@ async def replay_events_after(
"""Replay events after the specified ID."""
# Find the index of the last event ID
start_index = None
for i, (_, event_id, _) in enumerate(self._events):
stream_id = None
for i, (stream_id_, event_id, _) in enumerate(self._events):
if event_id == last_event_id:
start_index = i + 1
stream_id = stream_id_
break

if start_index is None:
# If event ID not found, start from beginning
start_index = 0

stream_id = None
# Replay events
for _, event_id, message in self._events[start_index:]:
await send_callback(EventMessage(message, event_id))
@@ -1055,7 +1056,8 @@ async def test_streamablehttp_client_resumption(event_server):
captured_session_id = None
captured_resumption_token = None
captured_notifications = []
tool_started = False
tool_started_event = anyio.Event()
session_resumption_token_received_event = anyio.Event()
captured_protocol_version = None

async def message_handler(
@@ -1066,12 +1068,12 @@ async def message_handler(
# Look for our special notification that indicates the tool is running
if isinstance(message.root, types.LoggingMessageNotification):
if message.root.params.data == "Tool started":
nonlocal tool_started
tool_started = True
tool_started_event.set()

async def on_resumption_token_update(token: str) -> None:
nonlocal captured_resumption_token
captured_resumption_token = token
session_resumption_token_received_event.set()

# First, start the client session and begin the long-running tool
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
@@ -1110,8 +1112,8 @@ async def run_tool():

# Wait for the tool to start and at least one notification
# and then kill the task group
while not tool_started or not captured_resumption_token:
await anyio.sleep(0.1)
await tool_started_event.wait()
await session_resumption_token_received_event.wait()
tg.cancel_scope.cancel()

# Store pre notifications and clear the captured notifications
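The resumption test previously polled shared booleans in a `while ... await anyio.sleep(0.1)` loop; it now waits on `anyio.Event` objects, which removes the timing assumptions that made the test flaky across backends. A small sketch of the signalling pattern (names are illustrative):

```python
import anyio


async def main() -> None:
    tool_started = anyio.Event()

    async def tool() -> None:
        await anyio.sleep(0.01)  # stand-in for the long-running tool's setup
        tool_started.set()       # wakes every task blocked in wait()

    async with anyio.create_task_group() as tg:
        tg.start_soon(tool)
        # Suspends until set() is called: no polling loop, no sleep interval.
        await tool_started.wait()
        tg.cancel_scope.cancel()  # analogous to cancelling the in-flight request


anyio.run(main)
```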
2 changes: 2 additions & 0 deletions tests/shared/test_ws.py
@@ -27,6 +27,8 @@

SERVER_NAME = "test_server_for_WS"

pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"])


@pytest.fixture
def server_port() -> int:
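The new module-level `pytestmark` pins every test in this file to the asyncio backend: anyio's pytest plugin takes the backend from the `anyio_backend` fixture, and parametrizing that fixture at module scope overrides any project-wide setting (presumably because the WebSocket tests still rely on asyncio-only behavior). A hedged sketch of how such a marker behaves:

```python
import anyio
import pytest

# Applies to every test in this module: run only on the asyncio backend.
pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"])


@pytest.mark.anyio
async def test_backend_is_asyncio(anyio_backend: str) -> None:
    assert anyio_backend == "asyncio"
    await anyio.sleep(0)
```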