Skip to content

Commit 5d8eaf7

Browse files
authored
Streamable Http - clean up server memory streams (#604)
1 parent 74f5fcf commit 5d8eaf7

File tree

4 files changed

+108
-69
lines changed

4 files changed

+108
-69
lines changed

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -185,20 +185,22 @@ async def handle_streamable_http(scope, receive, send):
185185
)
186186
server_instances[http_transport.mcp_session_id] = http_transport
187187
logger.info(f"Created new transport with session ID: {new_session_id}")
188-
async with http_transport.connect() as streams:
189-
read_stream, write_stream = streams
190188

191-
async def run_server():
192-
await app.run(
193-
read_stream,
194-
write_stream,
195-
app.create_initialization_options(),
196-
)
189+
async def run_server(task_status=None):
190+
async with http_transport.connect() as streams:
191+
read_stream, write_stream = streams
192+
if task_status:
193+
task_status.started()
194+
await app.run(
195+
read_stream,
196+
write_stream,
197+
app.create_initialization_options(),
198+
)
197199

198200
if not task_group:
199201
raise RuntimeError("Task group is not initialized")
200202

201-
task_group.start_soon(run_server)
203+
await task_group.start(run_server)
202204

203205
# Handle the HTTP request and return the response
204206
await http_transport.handle_request(scope, receive, send)

src/mcp/server/lowlevel/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ async def run(
480480
# but also make tracing exceptions much easier during testing and when using
481481
# in-process servers.
482482
raise_exceptions: bool = False,
483-
# When True, the server as stateless deployments where
483+
# When True, the server is stateless and
484484
# clients can perform initialization with any node. The client must still follow
485485
# the initialization lifecycle, but can do so with any available node
486486
# rather than requiring initialization for each connection.

src/mcp/server/streamable_http.py

+80-44
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
129129
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
130130
None
131131
)
132+
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
133+
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
132134
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
133135

134136
def __init__(
@@ -163,7 +165,11 @@ def __init__(
163165
self.is_json_response_enabled = is_json_response_enabled
164166
self._event_store = event_store
165167
self._request_streams: dict[
166-
RequestId, MemoryObjectSendStream[EventMessage]
168+
RequestId,
169+
tuple[
170+
MemoryObjectSendStream[EventMessage],
171+
MemoryObjectReceiveStream[EventMessage],
172+
],
167173
] = {}
168174
self._terminated = False
169175

@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
239245

240246
return event_data
241247

248+
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
249+
"""Clean up memory streams for a given request ID."""
250+
if request_id in self._request_streams:
251+
try:
252+
# Close the request stream
253+
await self._request_streams[request_id][0].aclose()
254+
await self._request_streams[request_id][1].aclose()
255+
except Exception as e:
256+
logger.debug(f"Error closing memory streams: {e}")
257+
finally:
258+
# Remove the request stream from the mapping
259+
self._request_streams.pop(request_id, None)
260+
242261
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
243262
"""Application entry point that handles all HTTP requests"""
244263
request = Request(scope, receive)
@@ -386,13 +405,11 @@ async def _handle_post_request(
386405

387406
# Extract the request ID outside the try block for proper scope
388407
request_id = str(message.root.id)
389-
# Create promise stream for getting response
390-
request_stream_writer, request_stream_reader = (
391-
anyio.create_memory_object_stream[EventMessage](0)
392-
)
393-
394408
# Register this stream for the request ID
395-
self._request_streams[request_id] = request_stream_writer
409+
self._request_streams[request_id] = anyio.create_memory_object_stream[
410+
EventMessage
411+
](0)
412+
request_stream_reader = self._request_streams[request_id][1]
396413

397414
if self.is_json_response_enabled:
398415
# Process the message
@@ -441,11 +458,7 @@ async def _handle_post_request(
441458
)
442459
await response(scope, receive, send)
443460
finally:
444-
# Clean up the request stream
445-
if request_id in self._request_streams:
446-
self._request_streams.pop(request_id, None)
447-
await request_stream_reader.aclose()
448-
await request_stream_writer.aclose()
461+
await self._clean_up_memory_streams(request_id)
449462
else:
450463
# Create SSE stream
451464
sse_stream_writer, sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
467480
event_message.message.root,
468481
JSONRPCResponse | JSONRPCError,
469482
):
470-
if request_id:
471-
self._request_streams.pop(request_id, None)
472483
break
473484
except Exception as e:
474485
logger.exception(f"Error in SSE writer: {e}")
475486
finally:
476487
logger.debug("Closing SSE writer")
477-
# Clean up the request-specific streams
478-
if request_id and request_id in self._request_streams:
479-
self._request_streams.pop(request_id, None)
488+
await self._clean_up_memory_streams(request_id)
480489

481490
# Create and start EventSourceResponse
482491
# SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
507516
await writer.send(session_message)
508517
except Exception:
509518
logger.exception("SSE response error")
510-
# Clean up the request stream if something goes wrong
511-
if request_id and request_id in self._request_streams:
512-
self._request_streams.pop(request_id, None)
519+
await sse_stream_writer.aclose()
520+
await sse_stream_reader.aclose()
521+
await self._clean_up_memory_streams(request_id)
513522

514523
except Exception as err:
515524
logger.exception("Error handling POST request")
@@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
581590
async def standalone_sse_writer():
582591
try:
583592
# Create a standalone message stream for server-initiated messages
584-
standalone_stream_writer, standalone_stream_reader = (
593+
594+
self._request_streams[GET_STREAM_KEY] = (
585595
anyio.create_memory_object_stream[EventMessage](0)
586596
)
587-
588-
# Register this stream using the special key
589-
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
597+
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
590598

591599
async with sse_stream_writer, standalone_stream_reader:
592600
# Process messages from the standalone stream
@@ -603,8 +611,7 @@ async def standalone_sse_writer():
603611
logger.exception(f"Error in standalone SSE writer: {e}")
604612
finally:
605613
logger.debug("Closing standalone SSE writer")
606-
# Remove the stream from request_streams
607-
self._request_streams.pop(GET_STREAM_KEY, None)
614+
await self._clean_up_memory_streams(GET_STREAM_KEY)
608615

609616
# Create and start EventSourceResponse
610617
response = EventSourceResponse(
@@ -618,8 +625,9 @@ async def standalone_sse_writer():
618625
await response(request.scope, request.receive, send)
619626
except Exception as e:
620627
logger.exception(f"Error in standalone SSE response: {e}")
621-
# Clean up the request stream
622-
self._request_streams.pop(GET_STREAM_KEY, None)
628+
await sse_stream_writer.aclose()
629+
await sse_stream_reader.aclose()
630+
await self._clean_up_memory_streams(GET_STREAM_KEY)
623631

624632
async def _handle_delete_request(self, request: Request, send: Send) -> None:
625633
"""Handle DELETE requests for explicit session termination."""
@@ -636,15 +644,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
636644
if not await self._validate_session(request, send):
637645
return
638646

639-
self._terminate_session()
647+
await self._terminate_session()
640648

641649
response = self._create_json_response(
642650
None,
643651
HTTPStatus.OK,
644652
)
645653
await response(request.scope, request.receive, send)
646654

647-
def _terminate_session(self) -> None:
655+
async def _terminate_session(self) -> None:
648656
"""Terminate the current session, closing all streams.
649657
650658
Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ def _terminate_session(self) -> None:
656664
# We need a copy of the keys to avoid modification during iteration
657665
request_stream_keys = list(self._request_streams.keys())
658666

659-
# Close all request streams (synchronously)
667+
# Close all request streams asynchronously
660668
for key in request_stream_keys:
661669
try:
662-
# Get the stream
663-
stream = self._request_streams.get(key)
664-
if stream:
665-
# We must use close() here, not aclose() since this is a sync method
666-
stream.close()
670+
await self._clean_up_memory_streams(key)
667671
except Exception as e:
668672
logger.debug(f"Error closing stream {key} during termination: {e}")
669673

670674
# Clear the request streams dictionary immediately
671675
self._request_streams.clear()
676+
try:
677+
if self._read_stream_writer is not None:
678+
await self._read_stream_writer.aclose()
679+
if self._read_stream is not None:
680+
await self._read_stream.aclose()
681+
if self._write_stream_reader is not None:
682+
await self._write_stream_reader.aclose()
683+
if self._write_stream is not None:
684+
await self._write_stream.aclose()
685+
except Exception as e:
686+
logger.debug(f"Error closing streams: {e}")
672687

673688
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
674689
"""Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None:
756771

757772
# If stream ID not in mapping, create it
758773
if stream_id and stream_id not in self._request_streams:
759-
msg_writer, msg_reader = anyio.create_memory_object_stream[
760-
EventMessage
761-
](0)
762-
self._request_streams[stream_id] = msg_writer
774+
self._request_streams[stream_id] = (
775+
anyio.create_memory_object_stream[EventMessage](0)
776+
)
777+
msg_reader = self._request_streams[stream_id][1]
763778

764779
# Forward messages to SSE
765780
async with msg_reader:
@@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None:
781796
await response(request.scope, request.receive, send)
782797
except Exception as e:
783798
logger.exception(f"Error in replay response: {e}")
799+
finally:
800+
await sse_stream_writer.aclose()
801+
await sse_stream_reader.aclose()
784802

785803
except Exception as e:
786804
logger.exception(f"Error replaying events: {e}")
@@ -818,7 +836,9 @@ async def connect(
818836

819837
# Store the streams
820838
self._read_stream_writer = read_stream_writer
839+
self._read_stream = read_stream
821840
self._write_stream_reader = write_stream_reader
841+
self._write_stream = write_stream
822842

823843
# Start a task group for message routing
824844
async with anyio.create_task_group() as tg:
@@ -863,7 +883,7 @@ async def message_router():
863883
if request_stream_id in self._request_streams:
864884
try:
865885
# Send both the message and the event ID
866-
await self._request_streams[request_stream_id].send(
886+
await self._request_streams[request_stream_id][0].send(
867887
EventMessage(message, event_id)
868888
)
869889
except (
@@ -872,6 +892,12 @@ async def message_router():
872892
):
873893
# Stream might be closed, remove from registry
874894
self._request_streams.pop(request_stream_id, None)
895+
else:
896+
logging.debug(
897+
f"""Request stream {request_stream_id} not found
898+
for message. Still processing message as the client
899+
might reconnect and replay."""
900+
)
875901
except Exception as e:
876902
logger.exception(f"Error in message router: {e}")
877903

@@ -882,9 +908,19 @@ async def message_router():
882908
# Yield the streams for the caller to use
883909
yield read_stream, write_stream
884910
finally:
885-
for stream in list(self._request_streams.values()):
911+
for stream_id in list(self._request_streams.keys()):
886912
try:
887-
await stream.aclose()
888-
except Exception:
913+
await self._clean_up_memory_streams(stream_id)
914+
except Exception as e:
915+
logger.debug(f"Error closing request stream: {e}")
889916
pass
890917
self._request_streams.clear()
918+
919+
# Clean up the read and write streams
920+
try:
921+
await read_stream_writer.aclose()
922+
await read_stream.aclose()
923+
await write_stream_reader.aclose()
924+
await write_stream.aclose()
925+
except Exception as e:
926+
logger.debug(f"Error closing streams: {e}")

tests/shared/test_streamable_http.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -234,29 +234,30 @@ async def handle_streamable_http(scope, receive, send):
234234
event_store=event_store,
235235
)
236236

237-
async with http_transport.connect() as streams:
238-
read_stream, write_stream = streams
239-
240-
async def run_server():
237+
async def run_server(task_status=None):
238+
async with http_transport.connect() as streams:
239+
read_stream, write_stream = streams
240+
if task_status:
241+
task_status.started()
241242
await server.run(
242243
read_stream,
243244
write_stream,
244245
server.create_initialization_options(),
245246
)
246247

247-
if task_group is None:
248-
response = Response(
249-
"Internal Server Error: Task group is not initialized",
250-
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
251-
)
252-
await response(scope, receive, send)
253-
return
248+
if task_group is None:
249+
response = Response(
250+
"Internal Server Error: Task group is not initialized",
251+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
252+
)
253+
await response(scope, receive, send)
254+
return
254255

255-
# Store the instance before starting the task to prevent races
256-
server_instances[http_transport.mcp_session_id] = http_transport
257-
task_group.start_soon(run_server)
256+
# Store the instance before starting the task to prevent races
257+
server_instances[http_transport.mcp_session_id] = http_transport
258+
await task_group.start(run_server)
258259

259-
await http_transport.handle_request(scope, receive, send)
260+
await http_transport.handle_request(scope, receive, send)
260261
else:
261262
response = Response(
262263
"Bad Request: No valid session ID provided",

0 commit comments

Comments
 (0)