diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 0c05c6def..4541935a7 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -51,8 +51,8 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: - try: + try: + async with anyio.create_task_group() as tg: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) @@ -139,6 +139,14 @@ async def post_writer(endpoint_url: str): yield read_stream, write_stream finally: tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + except Exception as e: + logger.error(f"TaskGroup exception in SSE transport: {e}") + try: + await read_stream_writer.send(e) + except Exception: + logger.error(f"Failed to send TaskGroup exception to read stream: {e}") + raise + + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 39ac34d8a..aa48232d5 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -471,8 +471,8 @@ async def streamablehttp_client( read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - async with anyio.create_task_group() as tg: - try: + try: + async with anyio.create_task_group() as tg: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") async with httpx_client_factory( @@ -504,6 +504,13 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + except Exception as e: + logger.error(f"TaskGroup exception in StreamableHTTP transport: {e}") + try: + await read_stream_writer.send(e) + except Exception: + logger.error(f"Failed to send TaskGroup exception to read stream: {e}") + raise + finally: + await read_stream_writer.aclose() + await write_stream.aclose()