From c202846659038f588ac3af68faf291436b4ab451 Mon Sep 17 00:00:00 2001 From: kavinkumarbaskar <61575461+kavinkumar807@users.noreply.github.com> Date: Sun, 13 Apr 2025 20:50:38 +0530 Subject: [PATCH 1/4] fix: implemented exception handling for client indefinite blocks - propagated the error once value error is raised instead of send it to stream - added exception groups to handle value error as well as connection error etc with exception groups - added errors list to add the exception groups and convert it to normal exception and throw to client --- src/mcp/client/sse.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a7..511bc2ec 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -40,6 +40,8 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + errors: list[Exception] = [] + async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") @@ -104,7 +106,7 @@ async def sse_reader( ) except Exception as exc: logger.error(f"Error in sse_reader: {exc}") - await read_stream_writer.send(exc) + raise finally: await read_stream_writer.aclose() @@ -141,6 +143,12 @@ async def post_writer(endpoint_url: str): yield read_stream, write_stream finally: tg.cancel_scope.cancel() + except* ValueError as eg: + errors.extend(eg.exceptions) + except* Exception as eg: + errors.extend(eg.exceptions) finally: await read_stream_writer.aclose() await write_stream.aclose() + if errors: + raise Exception("TaskGroup failed with: " + " ".join([str(e) for e in errors])) From c2dc97fc092398de9be4b20eb8da60c72d486e70 Mon Sep 17 00:00:00 2001 From: kavinkumarbaskar <61575461+kavinkumar807@users.noreply.github.com> Date: Mon, 14 Apr 2025 11:06:42 +0530 Subject: [PATCH 2/4] fix: except* python 3.10 incompatibility issue * added exceptiongroup dependency in toml for handling exception group issue from python 3.10 to python 3.13 * added handle_exception function to handle the exceptions * updated the uv.lock --- pyproject.toml | 1 + src/mcp/client/sse.py | 219 ++++++++++++++++++++++-------------------- uv.lock | 4 +- 3 files changed, 118 insertions(+), 106 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 25514cd6..269c3246 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", + "exceptiongroup>=1.2.0", ] [project.optional-dependencies] diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 511bc2ec..1fb56240 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,6 +7,7 @@ import httpx from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from exceptiongroup import ExceptionGroup, catch from httpx_sse import aconnect_sse import mcp.types as types @@ -18,6 +19,14 @@ def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) +def handle_exception(exc: Exception) -> str: + """Handle ExceptionGroup and Exceptions for Client transport for SSE""" + if isinstance(exc, ExceptionGroup): + messages = "; ".join(str(e) for e in exc.exceptions) + raise Exception(f"TaskGroup failed with: {messages}") from None + else: + raise Exception(f"TaskGroup failed with: {exc}") from None + @asynccontextmanager async def sse_client( url: str, @@ -40,115 +49,115 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - errors: list[Exception] = [] - - async with anyio.create_task_group() as tg: - try: - logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: - async with aconnect_sse( - client, - "GET", - url, - timeout=httpx.Timeout(timeout, read=sse_read_timeout), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def sse_reader( - task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ): - try: - async for sse in event_source.aiter_sse(): - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.info( - f"Received endpoint URL: {endpoint_url}" - ) - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme - ): - error_msg = ( - "Endpoint origin does not match " - f"connection origin: {endpoint_url}" + with catch({ + Exception: handle_exception, + }): + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx.AsyncClient(headers=headers) as client: + async with aconnect_sse( + client, + "GET", + url, + timeout=httpx.Timeout(timeout, read=sse_read_timeout), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + async def sse_reader( + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + try: + async for sse in event_source.aiter_sse(): + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.info( + f"Received endpoint URL: {endpoint_url}" ) - logger.error(error_msg) - raise ValueError(error_msg) - - task_status.started(endpoint_url) - case "message": - try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data - ) - logger.debug( - f"Received server message: {message}" + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc + != endpoint_parsed.netloc + or url_parsed.scheme + != endpoint_parsed.scheme + ): + error_msg = ( + "Endpoint origin does not match " + f"connection origin: {endpoint_url}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + task_status.started(endpoint_url) + + case "message": + try: + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data + ) + logger.debug( + f"Received server message: " + f"{message}" + ) + except Exception as exc: + logger.error( + f"Error parsing server message: " + f"{exc}" + ) + await read_stream_writer.send(exc) + continue + + await read_stream_writer.send(message) + case _: + logger.warning( + f"Unknown SSE event: {sse.event}" ) - except Exception as exc: - logger.error( - f"Error parsing server message: {exc}" - ) - await read_stream_writer.send(exc) - continue - - await read_stream_writer.send(message) - case _: - logger.warning( - f"Unknown SSE event: {sse.event}" + except Exception as exc: + logger.error(f"Error in sse_reader: {exc}") + raise + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader: + async for message in write_stream_reader: + logger.debug( + f"Sending client message: {message}" ) - except Exception as exc: - logger.error(f"Error in sse_reader: {exc}") - raise - finally: - await read_stream_writer.aclose() + response = await client.post( + endpoint_url, + json=message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug( + "Client message sent successfully: " + f"{response.status_code}" + ) + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await write_stream.aclose() + + endpoint_url = await tg.start(sse_reader) + logger.info( + f"Starting post writer with endpoint URL: {endpoint_url}" + ) + tg.start_soon(post_writer, endpoint_url) - async def post_writer(endpoint_url: str): try: - async with write_stream_reader: - async for message in write_stream_reader: - logger.debug(f"Sending client message: {message}") - response = await client.post( - endpoint_url, - json=message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), - ) - response.raise_for_status() - logger.debug( - "Client message sent successfully: " - f"{response.status_code}" - ) - except Exception as exc: - logger.error(f"Error in post_writer: {exc}") + yield read_stream, write_stream finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.info( - f"Starting post writer with endpoint URL: {endpoint_url}" - ) - tg.start_soon(post_writer, endpoint_url) - - try: - yield read_stream, write_stream - finally: - tg.cancel_scope.cancel() - except* ValueError as eg: - errors.extend(eg.exceptions) - except* Exception as eg: - errors.extend(eg.exceptions) - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - if errors: - raise Exception("TaskGroup failed with: " + " ".join([str(e) for e in errors])) + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/uv.lock b/uv.lock index 424e2d48..28c6e58b 100644 --- a/uv.lock +++ b/uv.lock @@ -490,6 +490,7 @@ name = "mcp" source = { editable = "." } dependencies = [ { name = "anyio" }, + { name = "exceptiongroup" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "pydantic" }, @@ -531,6 +532,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.5" }, + { name = "exceptiongroup", specifier = ">=1.2.0" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, @@ -1618,4 +1620,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155 }, { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, -] \ No newline at end of file +] From ea74ee7c59f3127e1463e8de25b3104ff6b4fdc5 Mon Sep 17 00:00:00 2001 From: kavinkumarbaskar <61575461+kavinkumar807@users.noreply.github.com> Date: Mon, 14 Apr 2025 11:31:13 +0530 Subject: [PATCH 3/4] fix: type check issue for Exception - changed the handle_exception functions implementation - replaced the Exception parameter with BaseExceptionGroup[Exception] --- src/mcp/client/sse.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 1fb56240..e9e5e13f 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,7 +7,7 @@ import httpx from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from exceptiongroup import ExceptionGroup, catch +from exceptiongroup import BaseExceptionGroup, catch from httpx_sse import aconnect_sse import mcp.types as types @@ -19,13 +19,10 @@ def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) -def handle_exception(exc: Exception) -> str: +def handle_exception(exc: BaseExceptionGroup[Exception]) -> str: """Handle ExceptionGroup and Exceptions for Client transport for SSE""" - if isinstance(exc, ExceptionGroup): - messages = "; ".join(str(e) for e in exc.exceptions) - raise Exception(f"TaskGroup failed with: {messages}") from None - else: - raise Exception(f"TaskGroup failed with: {exc}") from None + messages = "; ".join(str(e) for e in exc.exceptions) + raise Exception(f"TaskGroup failed with: {messages}") from None @asynccontextmanager async def sse_client( @@ -50,7 +47,7 @@ async def sse_client( write_stream, write_stream_reader = anyio.create_memory_object_stream(0) with catch({ - Exception: handle_exception, + Exception: handle_exception }): async with anyio.create_task_group() as tg: try: From 7d2df66220e841cd537cbbd63976862a228e10f9 Mon Sep 17 00:00:00 2001 From: kavinkumarbaskar Date: Fri, 2 May 2025 20:36:25 +0530 Subject: [PATCH 4/4] fix lint issue --- src/mcp/client/sse.py | 85 ++++++++++--------------------------------- 1 file changed, 19 insertions(+), 66 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 0e134df9..b05191f9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -47,9 +47,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - with catch({ - Exception: handle_exception - }): + with catch({Exception: handle_exception}): async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") @@ -99,72 +97,46 @@ async def sse_reader( sse.data ) logger.debug( - f"Received server message: " + "Received server message: " f"{message}" + ) except Exception as exc: logger.error( - f"Error parsing server message: " + "Error parsing server message: " f"{exc}" ) await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send( + session_message + ) case _: logger.warning( f"Unknown SSE event: {sse.event}" ) except Exception as exc: logger.error(f"Error in sse_reader: {exc}") - raise + await read_stream_writer.send(exc) finally: await read_stream_writer.aclose() async def post_writer(endpoint_url: str): try: async with write_stream_reader: - async for message in write_stream_reader: + async for session_message in write_stream_reader: logger.debug( - f"Sending client message: {message}" + f"Sending client message: {session_message}" ) - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme - ): - error_msg = ( - "Endpoint origin does not match " - f"connection origin: {endpoint_url}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - - task_status.started(endpoint_url) - - case "message": - try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data - ) - logger.debug( - f"Received server message: {message}" - ) - except Exception as exc: - logger.error( - f"Error parsing server message: {exc}" - ) - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - case _: - logger.warning( - f"Unknown SSE event: {sse.event}" + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), ) response.raise_for_status() logger.debug( @@ -183,26 +155,7 @@ async def post_writer(endpoint_url: str): tg.start_soon(post_writer, endpoint_url) try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug( - f"Sending client message: {session_message}" - ) - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), - ) - response.raise_for_status() - logger.debug( - "Client message sent successfully: " - f"{response.status_code}" - ) - except Exception as exc: - logger.error(f"Error in post_writer: {exc}") + yield read_stream, write_stream finally: tg.cancel_scope.cancel() finally: