1
+ using Microsoft . Extensions . Logging ;
2
+ using Microsoft . Extensions . Logging . Abstractions ;
3
+ using ModelContextProtocol . Protocol ;
4
+ using System . Net ;
5
+ using System . Threading . Channels ;
6
+
7
+ namespace ModelContextProtocol . Client ;
8
+
9
+ /// <summary>
10
+ /// A transport that automatically detects whether to use Streamable HTTP or SSE transport
11
+ /// by trying Streamable HTTP first and falling back to SSE if that fails.
12
+ /// </summary>
13
+ internal sealed partial class AutoDetectingClientSessionTransport : ITransport
14
+ {
15
+ private readonly SseClientTransportOptions _options ;
16
+ private readonly HttpClient _httpClient ;
17
+ private readonly ILoggerFactory ? _loggerFactory ;
18
+ private readonly ILogger _logger ;
19
+ private readonly string _name ;
20
+ private readonly Channel < JsonRpcMessage > _messageChannel ;
21
+
22
+ public AutoDetectingClientSessionTransport ( SseClientTransportOptions transportOptions , HttpClient httpClient , ILoggerFactory ? loggerFactory , string endpointName )
23
+ {
24
+ Throw . IfNull ( transportOptions ) ;
25
+ Throw . IfNull ( httpClient ) ;
26
+
27
+ _options = transportOptions ;
28
+ _httpClient = httpClient ;
29
+ _loggerFactory = loggerFactory ;
30
+ _logger = ( ILogger ? ) loggerFactory ? . CreateLogger < AutoDetectingClientSessionTransport > ( ) ?? NullLogger . Instance ;
31
+ _name = endpointName ;
32
+
33
+ // Same as TransportBase.cs.
34
+ _messageChannel = Channel . CreateUnbounded < JsonRpcMessage > ( new UnboundedChannelOptions
35
+ {
36
+ SingleReader = true ,
37
+ SingleWriter = false ,
38
+ } ) ;
39
+ }
40
+
41
+ /// <summary>
42
+ /// Returns the active transport (either StreamableHttp or SSE)
43
+ /// </summary>
44
+ internal ITransport ? ActiveTransport { get ; private set ; }
45
+
46
+ public ChannelReader < JsonRpcMessage > MessageReader => _messageChannel . Reader ;
47
+
48
+ /// <inheritdoc/>
49
+ public Task SendMessageAsync ( JsonRpcMessage message , CancellationToken cancellationToken = default )
50
+ {
51
+ if ( ActiveTransport is null )
52
+ {
53
+ return InitializeAsync ( message , cancellationToken ) ;
54
+ }
55
+
56
+ return ActiveTransport . SendMessageAsync ( message , cancellationToken ) ;
57
+ }
58
+
59
+ private async Task InitializeAsync ( JsonRpcMessage message , CancellationToken cancellationToken )
60
+ {
61
+ // Try StreamableHttp first
62
+ var streamableHttpTransport = new StreamableHttpClientSessionTransport ( _name , _options , _httpClient , _messageChannel , _loggerFactory ) ;
63
+
64
+ try
65
+ {
66
+ LogAttemptingStreamableHttp ( _name ) ;
67
+ using var response = await streamableHttpTransport . SendHttpRequestAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
68
+
69
+ if ( response . IsSuccessStatusCode )
70
+ {
71
+ LogUsingStreamableHttp ( _name ) ;
72
+ ActiveTransport = streamableHttpTransport ;
73
+ }
74
+ else
75
+ {
76
+ // If the status code is not success, fall back to SSE
77
+ LogStreamableHttpFailed ( _name , response . StatusCode ) ;
78
+
79
+ await streamableHttpTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
80
+ await InitializeSseTransportAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
81
+ }
82
+ }
83
+ catch
84
+ {
85
+ // If nothing threw inside the try block, we've either set streamableHttpTransport as the
86
+ // ActiveTransport, or else we will have disposed it in the !IsSuccessStatusCode else block.
87
+ await streamableHttpTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
88
+ throw ;
89
+ }
90
+ }
91
+
92
+ private async Task InitializeSseTransportAsync ( JsonRpcMessage message , CancellationToken cancellationToken )
93
+ {
94
+ var sseTransport = new SseClientSessionTransport ( _name , _options , _httpClient , _messageChannel , _loggerFactory ) ;
95
+
96
+ try
97
+ {
98
+ LogAttemptingSSE ( _name ) ;
99
+ await sseTransport . ConnectAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
100
+ await sseTransport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
101
+
102
+ LogUsingSSE ( _name ) ;
103
+ ActiveTransport = sseTransport ;
104
+ }
105
+ catch
106
+ {
107
+ await sseTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
108
+ throw ;
109
+ }
110
+ }
111
+
112
+ public async ValueTask DisposeAsync ( )
113
+ {
114
+ try
115
+ {
116
+ if ( ActiveTransport is not null )
117
+ {
118
+ await ActiveTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
119
+ }
120
+ }
121
+ finally
122
+ {
123
+ // In the majority of cases, either the Streamable HTTP transport or SSE transport has completed the channel by now.
124
+ // However, this may not be the case if HttpClient throws during the initial request due to misconfiguration.
125
+ _messageChannel . Writer . TryComplete ( ) ;
126
+ }
127
+ }
128
+
129
+ [ LoggerMessage ( Level = LogLevel . Debug , Message = "{EndpointName} attempting to connect using Streamable HTTP transport." ) ]
130
+ private partial void LogAttemptingStreamableHttp ( string endpointName ) ;
131
+
132
+ [ LoggerMessage ( Level = LogLevel . Information , Message = "{EndpointName} streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport." ) ]
133
+ private partial void LogStreamableHttpFailed ( string endpointName , HttpStatusCode statusCode ) ;
134
+
135
+ [ LoggerMessage ( Level = LogLevel . Information , Message = "{EndpointName} using Streamable HTTP transport." ) ]
136
+ private partial void LogUsingStreamableHttp ( string endpointName ) ;
137
+
138
+ [ LoggerMessage ( Level = LogLevel . Debug , Message = "{EndpointName} attempting to connect using SSE transport." ) ]
139
+ private partial void LogAttemptingSSE ( string endpointName ) ;
140
+
141
+ [ LoggerMessage ( Level = LogLevel . Information , Message = "{EndpointName} using SSE transport." ) ]
142
+ private partial void LogUsingSSE ( string endpointName ) ;
143
+ }
0 commit comments