Skip to content

Commit 83968b5

Browse files
authored
Handle SSE Disconnects Properly (#612)
1 parent 5d8eaf7 commit 83968b5

File tree

7 files changed

+38
-11
lines changed

7 files changed

+38
-11
lines changed

examples/servers/simple-prompt/mcp_simple_prompt/server.py

+2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def get_prompt(
9090
if transport == "sse":
9191
from mcp.server.sse import SseServerTransport
9292
from starlette.applications import Starlette
93+
from starlette.responses import Response
9394
from starlette.routing import Mount, Route
9495

9596
sse = SseServerTransport("/messages/")
@@ -101,6 +102,7 @@ async def handle_sse(request):
101102
await app.run(
102103
streams[0], streams[1], app.create_initialization_options()
103104
)
105+
return Response()
104106

105107
starlette_app = Starlette(
106108
debug=True,

examples/servers/simple-resource/mcp_simple_resource/server.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ async def read_resource(uri: FileUrl) -> str | bytes:
4646
if transport == "sse":
4747
from mcp.server.sse import SseServerTransport
4848
from starlette.applications import Starlette
49+
from starlette.responses import Response
4950
from starlette.routing import Mount, Route
5051

5152
sse = SseServerTransport("/messages/")
@@ -57,11 +58,12 @@ async def handle_sse(request):
5758
await app.run(
5859
streams[0], streams[1], app.create_initialization_options()
5960
)
61+
return Response()
6062

6163
starlette_app = Starlette(
6264
debug=True,
6365
routes=[
64-
Route("/sse", endpoint=handle_sse),
66+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
6567
Mount("/messages/", app=sse.handle_post_message),
6668
],
6769
)

examples/servers/simple-tool/mcp_simple_tool/server.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ async def list_tools() -> list[types.Tool]:
6060
if transport == "sse":
6161
from mcp.server.sse import SseServerTransport
6262
from starlette.applications import Starlette
63+
from starlette.responses import Response
6364
from starlette.routing import Mount, Route
6465

6566
sse = SseServerTransport("/messages/")
@@ -71,11 +72,12 @@ async def handle_sse(request):
7172
await app.run(
7273
streams[0], streams[1], app.create_initialization_options()
7374
)
75+
return Response()
7476

7577
starlette_app = Starlette(
7678
debug=True,
7779
routes=[
78-
Route("/sse", endpoint=handle_sse),
80+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
7981
Mount("/messages/", app=sse.handle_post_message),
8082
],
8183
)

src/mcp/server/fastmcp/server.py

+1
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
589589
streams[1],
590590
self._mcp_server.create_initialization_options(),
591591
)
592+
return Response()
592593

593594
# Create routes
594595
routes: list[Route | Mount] = []

src/mcp/server/session.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,6 @@ def __init__(
104104
self._exit_stack.push_async_callback(
105105
lambda: self._incoming_message_stream_reader.aclose()
106106
)
107-
self._exit_stack.push_async_callback(
108-
lambda: self._incoming_message_stream_writer.aclose()
109-
)
110107

111108
@property
112109
def client_params(self) -> types.InitializeRequestParams | None:
@@ -144,6 +141,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
144141

145142
return True
146143

144+
async def _receive_loop(self) -> None:
145+
async with self._incoming_message_stream_writer:
146+
await super()._receive_loop()
147+
147148
async def _received_request(
148149
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
149150
):

src/mcp/server/sse.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
1111
# Create Starlette routes for SSE and message handling
1212
routes = [
13-
Route("/sse", endpoint=handle_sse),
13+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
1414
Mount("/messages/", app=sse.handle_post_message),
1515
]
1616
@@ -22,12 +22,18 @@ async def handle_sse(request):
2222
await app.run(
2323
streams[0], streams[1], app.create_initialization_options()
2424
)
25+
# Return empty response to avoid NoneType error
26+
return Response()
2527
2628
# Create and run Starlette app
2729
starlette_app = Starlette(routes=routes)
2830
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
2931
```
3032
33+
Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
34+
object is not callable" error when client disconnects. The example above returns
35+
an empty Response() after the SSE connection ends to fix this.
36+
3137
See SseServerTransport class documentation for more details.
3238
"""
3339

@@ -120,11 +126,22 @@ async def sse_writer():
120126
)
121127

122128
async with anyio.create_task_group() as tg:
123-
response = EventSourceResponse(
124-
content=sse_stream_reader, data_sender_callable=sse_writer
125-
)
129+
130+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
131+
"""
132+
The EventSourceResponse returning signals a client close / disconnect.
133+
In this case we close our side of the streams to signal the client that
134+
the connection has been closed.
135+
"""
136+
await EventSourceResponse(
137+
content=sse_stream_reader, data_sender_callable=sse_writer
138+
)(scope, receive, send)
139+
await read_stream_writer.aclose()
140+
await write_stream_reader.aclose()
141+
logging.debug(f"Client session disconnected {session_id}")
142+
126143
logger.debug("Starting SSE response task")
127-
tg.start_soon(response, scope, receive, send)
144+
tg.start_soon(response_wrapper, scope, receive, send)
128145

129146
logger.debug("Yielding read and write streams")
130147
yield (read_stream, write_stream)

tests/shared/test_sse.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pydantic import AnyUrl
1111
from starlette.applications import Starlette
1212
from starlette.requests import Request
13+
from starlette.responses import Response
1314
from starlette.routing import Mount, Route
1415

1516
from mcp.client.session import ClientSession
@@ -83,13 +84,14 @@ def make_server_app() -> Starlette:
8384
sse = SseServerTransport("/messages/")
8485
server = ServerTest()
8586

86-
async def handle_sse(request: Request) -> None:
87+
async def handle_sse(request: Request) -> Response:
8788
async with sse.connect_sse(
8889
request.scope, request.receive, request._send
8990
) as streams:
9091
await server.run(
9192
streams[0], streams[1], server.create_initialization_options()
9293
)
94+
return Response()
9395

9496
app = Starlette(
9597
routes=[

0 commit comments

Comments
 (0)