Skip to content

fix: implemented exception handling for client indefinite blocks #500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
"exceptiongroup>=1.2.0",
]

[project.optional-dependencies]
Expand Down
211 changes: 112 additions & 99 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from exceptiongroup import BaseExceptionGroup, catch
from httpx_sse import aconnect_sse

import mcp.types as types
Expand All @@ -19,6 +20,11 @@ def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)


def handle_exception(exc: BaseExceptionGroup[Exception]) -> str:
"""Handle ExceptionGroup and Exceptions for Client transport for SSE"""
messages = "; ".join(str(e) for e in exc.exceptions)
raise Exception(f"TaskGroup failed with: {messages}") from None

@asynccontextmanager
async def sse_client(
url: str,
Expand All @@ -41,110 +47,117 @@ async def sse_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers) as client:
async with aconnect_sse(
client,
"GET",
url,
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")

async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)

url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
with catch({Exception: handle_exception}):
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers) as client:
async with aconnect_sse(
client,
"GET",
url,
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")

async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)

task_status.started(endpoint_url)

case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc
!= endpoint_parsed.netloc
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)

task_status.started(endpoint_url)

case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
"Received server message: "
f"{message}"

)
except Exception as exc:
logger.error(
"Error parsing server message: "
f"{exc}"
)
await read_stream_writer.send(exc)
continue

session_message = SessionMessage(message)
await read_stream_writer.send(
session_message
)
logger.debug(
f"Received server message: {message}"
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
)
except Exception as exc:
logger.error(
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc)
continue

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()

async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(
f"Sending client message: {session_message}"
)
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
finally:
await write_stream.aclose()

endpoint_url = await tg.start(sse_reader)
logger.info(
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url)

async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(
f"Sending client message: {session_message}"
)
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
yield read_stream, write_stream
finally:
await write_stream.aclose()

endpoint_url = await tg.start(sse_reader)
logger.info(
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url)

try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading