@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
129
129
_read_stream_writer : MemoryObjectSendStream [SessionMessage | Exception ] | None = (
130
130
None
131
131
)
132
+ _read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ] | None = None
133
+ _write_stream : MemoryObjectSendStream [SessionMessage ] | None = None
132
134
_write_stream_reader : MemoryObjectReceiveStream [SessionMessage ] | None = None
133
135
134
136
def __init__ (
@@ -163,7 +165,11 @@ def __init__(
163
165
self .is_json_response_enabled = is_json_response_enabled
164
166
self ._event_store = event_store
165
167
self ._request_streams : dict [
166
- RequestId , MemoryObjectSendStream [EventMessage ]
168
+ RequestId ,
169
+ tuple [
170
+ MemoryObjectSendStream [EventMessage ],
171
+ MemoryObjectReceiveStream [EventMessage ],
172
+ ],
167
173
] = {}
168
174
self ._terminated = False
169
175
@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
239
245
240
246
return event_data
241
247
248
+ async def _clean_up_memory_streams (self , request_id : RequestId ) -> None :
249
+ """Clean up memory streams for a given request ID."""
250
+ if request_id in self ._request_streams :
251
+ try :
252
+ # Close the request stream
253
+ await self ._request_streams [request_id ][0 ].aclose ()
254
+ await self ._request_streams [request_id ][1 ].aclose ()
255
+ except Exception as e :
256
+ logger .debug (f"Error closing memory streams: { e } " )
257
+ finally :
258
+ # Remove the request stream from the mapping
259
+ self ._request_streams .pop (request_id , None )
260
+
242
261
async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
243
262
"""Application entry point that handles all HTTP requests"""
244
263
request = Request (scope , receive )
@@ -386,13 +405,11 @@ async def _handle_post_request(
386
405
387
406
# Extract the request ID outside the try block for proper scope
388
407
request_id = str (message .root .id )
389
- # Create promise stream for getting response
390
- request_stream_writer , request_stream_reader = (
391
- anyio .create_memory_object_stream [EventMessage ](0 )
392
- )
393
-
394
408
# Register this stream for the request ID
395
- self ._request_streams [request_id ] = request_stream_writer
409
+ self ._request_streams [request_id ] = anyio .create_memory_object_stream [
410
+ EventMessage
411
+ ](0 )
412
+ request_stream_reader = self ._request_streams [request_id ][1 ]
396
413
397
414
if self .is_json_response_enabled :
398
415
# Process the message
@@ -441,11 +458,7 @@ async def _handle_post_request(
441
458
)
442
459
await response (scope , receive , send )
443
460
finally :
444
- # Clean up the request stream
445
- if request_id in self ._request_streams :
446
- self ._request_streams .pop (request_id , None )
447
- await request_stream_reader .aclose ()
448
- await request_stream_writer .aclose ()
461
+ await self ._clean_up_memory_streams (request_id )
449
462
else :
450
463
# Create SSE stream
451
464
sse_stream_writer , sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
467
480
event_message .message .root ,
468
481
JSONRPCResponse | JSONRPCError ,
469
482
):
470
- if request_id :
471
- self ._request_streams .pop (request_id , None )
472
483
break
473
484
except Exception as e :
474
485
logger .exception (f"Error in SSE writer: { e } " )
475
486
finally :
476
487
logger .debug ("Closing SSE writer" )
477
- # Clean up the request-specific streams
478
- if request_id and request_id in self ._request_streams :
479
- self ._request_streams .pop (request_id , None )
488
+ await self ._clean_up_memory_streams (request_id )
480
489
481
490
# Create and start EventSourceResponse
482
491
# SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
507
516
await writer .send (session_message )
508
517
except Exception :
509
518
logger .exception ("SSE response error" )
510
- # Clean up the request stream if something goes wrong
511
- if request_id and request_id in self . _request_streams :
512
- self ._request_streams . pop (request_id , None )
519
+ await sse_stream_writer . aclose ()
520
+ await sse_stream_reader . aclose ()
521
+ await self ._clean_up_memory_streams (request_id )
513
522
514
523
except Exception as err :
515
524
logger .exception ("Error handling POST request" )
@@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
581
590
async def standalone_sse_writer ():
582
591
try :
583
592
# Create a standalone message stream for server-initiated messages
584
- standalone_stream_writer , standalone_stream_reader = (
593
+
594
+ self ._request_streams [GET_STREAM_KEY ] = (
585
595
anyio .create_memory_object_stream [EventMessage ](0 )
586
596
)
587
-
588
- # Register this stream using the special key
589
- self ._request_streams [GET_STREAM_KEY ] = standalone_stream_writer
597
+ standalone_stream_reader = self ._request_streams [GET_STREAM_KEY ][1 ]
590
598
591
599
async with sse_stream_writer , standalone_stream_reader :
592
600
# Process messages from the standalone stream
@@ -603,8 +611,7 @@ async def standalone_sse_writer():
603
611
logger .exception (f"Error in standalone SSE writer: { e } " )
604
612
finally :
605
613
logger .debug ("Closing standalone SSE writer" )
606
- # Remove the stream from request_streams
607
- self ._request_streams .pop (GET_STREAM_KEY , None )
614
+ await self ._clean_up_memory_streams (GET_STREAM_KEY )
608
615
609
616
# Create and start EventSourceResponse
610
617
response = EventSourceResponse (
@@ -618,8 +625,9 @@ async def standalone_sse_writer():
618
625
await response (request .scope , request .receive , send )
619
626
except Exception as e :
620
627
logger .exception (f"Error in standalone SSE response: { e } " )
621
- # Clean up the request stream
622
- self ._request_streams .pop (GET_STREAM_KEY , None )
628
+ await sse_stream_writer .aclose ()
629
+ await sse_stream_reader .aclose ()
630
+ await self ._clean_up_memory_streams (GET_STREAM_KEY )
623
631
624
632
async def _handle_delete_request (self , request : Request , send : Send ) -> None :
625
633
"""Handle DELETE requests for explicit session termination."""
@@ -636,15 +644,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
636
644
if not await self ._validate_session (request , send ):
637
645
return
638
646
639
- self ._terminate_session ()
647
+ await self ._terminate_session ()
640
648
641
649
response = self ._create_json_response (
642
650
None ,
643
651
HTTPStatus .OK ,
644
652
)
645
653
await response (request .scope , request .receive , send )
646
654
647
- def _terminate_session (self ) -> None :
655
+ async def _terminate_session (self ) -> None :
648
656
"""Terminate the current session, closing all streams.
649
657
650
658
Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ def _terminate_session(self) -> None:
656
664
# We need a copy of the keys to avoid modification during iteration
657
665
request_stream_keys = list (self ._request_streams .keys ())
658
666
659
- # Close all request streams (synchronously)
667
+ # Close all request streams asynchronously
660
668
for key in request_stream_keys :
661
669
try :
662
- # Get the stream
663
- stream = self ._request_streams .get (key )
664
- if stream :
665
- # We must use close() here, not aclose() since this is a sync method
666
- stream .close ()
670
+ await self ._clean_up_memory_streams (key )
667
671
except Exception as e :
668
672
logger .debug (f"Error closing stream { key } during termination: { e } " )
669
673
670
674
# Clear the request streams dictionary immediately
671
675
self ._request_streams .clear ()
676
+ try :
677
+ if self ._read_stream_writer is not None :
678
+ await self ._read_stream_writer .aclose ()
679
+ if self ._read_stream is not None :
680
+ await self ._read_stream .aclose ()
681
+ if self ._write_stream_reader is not None :
682
+ await self ._write_stream_reader .aclose ()
683
+ if self ._write_stream is not None :
684
+ await self ._write_stream .aclose ()
685
+ except Exception as e :
686
+ logger .debug (f"Error closing streams: { e } " )
672
687
673
688
async def _handle_unsupported_request (self , request : Request , send : Send ) -> None :
674
689
"""Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None:
756
771
757
772
# If stream ID not in mapping, create it
758
773
if stream_id and stream_id not in self ._request_streams :
759
- msg_writer , msg_reader = anyio . create_memory_object_stream [
760
- EventMessage
761
- ]( 0 )
762
- self ._request_streams [stream_id ] = msg_writer
774
+ self . _request_streams [ stream_id ] = (
775
+ anyio . create_memory_object_stream [ EventMessage ]( 0 )
776
+ )
777
+ msg_reader = self ._request_streams [stream_id ][ 1 ]
763
778
764
779
# Forward messages to SSE
765
780
async with msg_reader :
@@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None:
781
796
await response (request .scope , request .receive , send )
782
797
except Exception as e :
783
798
logger .exception (f"Error in replay response: { e } " )
799
+ finally :
800
+ await sse_stream_writer .aclose ()
801
+ await sse_stream_reader .aclose ()
784
802
785
803
except Exception as e :
786
804
logger .exception (f"Error replaying events: { e } " )
@@ -818,7 +836,9 @@ async def connect(
818
836
819
837
# Store the streams
820
838
self ._read_stream_writer = read_stream_writer
839
+ self ._read_stream = read_stream
821
840
self ._write_stream_reader = write_stream_reader
841
+ self ._write_stream = write_stream
822
842
823
843
# Start a task group for message routing
824
844
async with anyio .create_task_group () as tg :
@@ -863,7 +883,7 @@ async def message_router():
863
883
if request_stream_id in self ._request_streams :
864
884
try :
865
885
# Send both the message and the event ID
866
- await self ._request_streams [request_stream_id ].send (
886
+ await self ._request_streams [request_stream_id ][ 0 ] .send (
867
887
EventMessage (message , event_id )
868
888
)
869
889
except (
@@ -872,6 +892,12 @@ async def message_router():
872
892
):
873
893
# Stream might be closed, remove from registry
874
894
self ._request_streams .pop (request_stream_id , None )
895
+ else :
896
+ logging .debug (
897
+ f"""Request stream { request_stream_id } not found
898
+ for message. Still processing message as the client
899
+ might reconnect and replay."""
900
+ )
875
901
except Exception as e :
876
902
logger .exception (f"Error in message router: { e } " )
877
903
@@ -882,9 +908,19 @@ async def message_router():
882
908
# Yield the streams for the caller to use
883
909
yield read_stream , write_stream
884
910
finally :
885
- for stream in list (self ._request_streams .values ()):
911
+ for stream_id in list (self ._request_streams .keys ()):
886
912
try :
887
- await stream .aclose ()
888
- except Exception :
913
+ await self ._clean_up_memory_streams (stream_id )
914
+ except Exception as e :
915
+ logger .debug (f"Error closing request stream: { e } " )
889
916
pass
890
917
self ._request_streams .clear ()
918
+
919
+ # Clean up the read and write streams
920
+ try :
921
+ await read_stream_writer .aclose ()
922
+ await read_stream .aclose ()
923
+ await write_stream_reader .aclose ()
924
+ await write_stream .aclose ()
925
+ except Exception as e :
926
+ logger .debug (f"Error closing streams: { e } " )
0 commit comments