Skip to content

Commit 82bd8bc

Browse files
authored
Properly clean up response streams in BaseSession (#515)
1 parent 1a330ac commit 82bd8bc

File tree

2 files changed

+107
-37
lines changed

2 files changed

+107
-37
lines changed

src/mcp/shared/session.py

+39-37
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def __init__(
187187
self._receive_notification_type = receive_notification_type
188188
self._session_read_timeout_seconds = read_timeout_seconds
189189
self._in_flight = {}
190-
191190
self._exit_stack = AsyncExitStack()
192191

193192
async def __aenter__(self) -> Self:
@@ -232,45 +231,48 @@ async def send_request(
232231
](1)
233232
self._response_streams[request_id] = response_stream
234233

235-
self._exit_stack.push_async_callback(lambda: response_stream.aclose())
236-
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
237-
238-
jsonrpc_request = JSONRPCRequest(
239-
jsonrpc="2.0",
240-
id=request_id,
241-
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
242-
)
243-
244-
# TODO: Support progress callbacks
245-
246-
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
247-
248-
# request read timeout takes precedence over session read timeout
249-
timeout = None
250-
if request_read_timeout_seconds is not None:
251-
timeout = request_read_timeout_seconds.total_seconds()
252-
elif self._session_read_timeout_seconds is not None:
253-
timeout = self._session_read_timeout_seconds.total_seconds()
254-
255234
try:
256-
with anyio.fail_after(timeout):
257-
response_or_error = await response_stream_reader.receive()
258-
except TimeoutError:
259-
raise McpError(
260-
ErrorData(
261-
code=httpx.codes.REQUEST_TIMEOUT,
262-
message=(
263-
f"Timed out while waiting for response to "
264-
f"{request.__class__.__name__}. Waited "
265-
f"{timeout} seconds."
266-
),
267-
)
235+
jsonrpc_request = JSONRPCRequest(
236+
jsonrpc="2.0",
237+
id=request_id,
238+
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
268239
)
269240

270-
if isinstance(response_or_error, JSONRPCError):
271-
raise McpError(response_or_error.error)
272-
else:
273-
return result_type.model_validate(response_or_error.result)
241+
# TODO: Support progress callbacks
242+
243+
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
244+
245+
# request read timeout takes precedence over session read timeout
246+
timeout = None
247+
if request_read_timeout_seconds is not None:
248+
timeout = request_read_timeout_seconds.total_seconds()
249+
elif self._session_read_timeout_seconds is not None:
250+
timeout = self._session_read_timeout_seconds.total_seconds()
251+
252+
try:
253+
with anyio.fail_after(timeout):
254+
response_or_error = await response_stream_reader.receive()
255+
except TimeoutError:
256+
raise McpError(
257+
ErrorData(
258+
code=httpx.codes.REQUEST_TIMEOUT,
259+
message=(
260+
f"Timed out while waiting for response to "
261+
f"{request.__class__.__name__}. Waited "
262+
f"{timeout} seconds."
263+
),
264+
)
265+
)
266+
267+
if isinstance(response_or_error, JSONRPCError):
268+
raise McpError(response_or_error.error)
269+
else:
270+
return result_type.model_validate(response_or_error.result)
271+
272+
finally:
273+
self._response_streams.pop(request_id, None)
274+
await response_stream.aclose()
275+
await response_stream_reader.aclose()
274276

275277
async def send_notification(self, notification: SendNotificationT) -> None:
276278
"""

tests/client/test_resource_cleanup.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from unittest.mock import patch
2+
3+
import anyio
4+
import pytest
5+
6+
from mcp.shared.session import BaseSession
7+
from mcp.types import (
8+
ClientRequest,
9+
EmptyResult,
10+
PingRequest,
11+
)
12+
13+
14+
@pytest.mark.anyio
15+
async def test_send_request_stream_cleanup():
16+
"""
17+
Test that send_request properly cleans up streams when an exception occurs.
18+
19+
This test mocks out most of the session functionality to focus on stream cleanup.
20+
"""
21+
22+
# Create a mock session with the minimal required functionality
23+
class TestSession(BaseSession):
24+
async def _send_response(self, request_id, response):
25+
pass
26+
27+
# Create streams
28+
write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1)
29+
read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1)
30+
31+
# Create the session
32+
session = TestSession(
33+
read_stream_receive,
34+
write_stream_send,
35+
object, # Request type doesn't matter for this test
36+
object, # Notification type doesn't matter for this test
37+
)
38+
39+
# Create a test request
40+
request = ClientRequest(
41+
PingRequest(
42+
method="ping",
43+
)
44+
)
45+
46+
# Patch the _write_stream.send method to raise an exception
47+
async def mock_send(*args, **kwargs):
48+
raise RuntimeError("Simulated network error")
49+
50+
# Record the response streams before the test
51+
initial_stream_count = len(session._response_streams)
52+
53+
# Run the test with the patched method
54+
with patch.object(session._write_stream, "send", mock_send):
55+
with pytest.raises(RuntimeError):
56+
await session.send_request(request, EmptyResult)
57+
58+
# Verify that no response streams were leaked
59+
assert len(session._response_streams) == initial_stream_count, (
60+
f"Expected {initial_stream_count} response streams after request, "
61+
f"but found {len(session._response_streams)}"
62+
)
63+
64+
# Clean up
65+
await write_stream_send.aclose()
66+
await write_stream_receive.aclose()
67+
await read_stream_send.aclose()
68+
await read_stream_receive.aclose()

0 commit comments

Comments
 (0)