Skip to content

Commit d7d9eab

Browse files
Pass session id's to MCP endpoints. (#466)
* Pass session id's to MCP endpoints. * Update src/ModelContextProtocol/Client/McpClient.cs Co-authored-by: Stephen Halter <[email protected]> * Update src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs Co-authored-by: Stephen Halter <[email protected]> * Make init callback asynchronous. * Address feedback. --------- Co-authored-by: Stephen Halter <[email protected]>
1 parent fa017c0 commit d7d9eab

File tree

20 files changed

+166
-21
lines changed

20 files changed

+166
-21
lines changed

src/ModelContextProtocol.AspNetCore/SseHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public async Task HandleSseRequestAsync(HttpContext context)
3131

3232
var requestPath = (context.Request.PathBase + context.Request.Path).ToString();
3333
var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)];
34-
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}");
34+
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId);
3535

3636
var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
3737
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider);

src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using Microsoft.AspNetCore.WebUtilities;
55
using Microsoft.Extensions.Logging;
66
using Microsoft.Extensions.Options;
7-
using Microsoft.Extensions.Primitives;
87
using Microsoft.Net.Http.Headers;
98
using ModelContextProtocol.AspNetCore.Stateless;
109
using ModelContextProtocol.Protocol;
@@ -136,6 +135,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
136135
var transport = new StreamableHttpServerTransport
137136
{
138137
Stateless = true,
138+
SessionId = sessionId,
139139
};
140140
session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId);
141141
}
@@ -184,7 +184,10 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> StartNewS
184184
if (!HttpServerTransportOptions.Stateless)
185185
{
186186
sessionId = MakeNewSessionId();
187-
transport = new();
187+
transport = new()
188+
{
189+
SessionId = sessionId,
190+
};
188191
context.Response.Headers["mcp-session-id"] = sessionId;
189192
}
190193
else
@@ -286,21 +289,19 @@ internal static string MakeNewSessionId()
286289

287290
private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport)
288291
{
289-
context.Response.OnStarting(() =>
292+
transport.OnInitRequestReceived = initRequestParams =>
290293
{
291294
var statelessId = new StatelessSessionId
292295
{
293-
ClientInfo = transport?.InitializeRequest?.ClientInfo,
296+
ClientInfo = initRequestParams?.ClientInfo,
294297
UserIdClaim = GetUserIdClaim(context.User),
295298
};
296299

297300
var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId);
298-
var sessionId = Protector.Protect(sessionJson);
299-
300-
context.Response.Headers["mcp-session-id"] = sessionId;
301-
302-
return Task.CompletedTask;
303-
});
301+
transport.SessionId = Protector.Protect(sessionJson);
302+
context.Response.Headers["mcp-session-id"] = transport.SessionId;
303+
return ValueTask.CompletedTask;
304+
};
304305
}
305306

306307
internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)

src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOp
4545

4646
public ChannelReader<JsonRpcMessage> MessageReader => _messageChannel.Reader;
4747

48+
string? ITransport.SessionId => ActiveTransport?.SessionId;
49+
4850
/// <inheritdoc/>
4951
public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
5052
{

src/ModelContextProtocol.Core/Client/McpClient.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.Extensions.Logging;
22
using ModelContextProtocol.Protocol;
3+
using System.Diagnostics;
34
using System.Text.Json;
45

56
namespace ModelContextProtocol.Client;
@@ -93,6 +94,20 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL
9394
}
9495
}
9596

97+
/// <inheritdoc/>
98+
public string? SessionId
99+
{
100+
get
101+
{
102+
if (_sessionTransport is null)
103+
{
104+
throw new InvalidOperationException("Must have already initialized a session when invoking this property.");
105+
}
106+
107+
return _sessionTransport.SessionId;
108+
}
109+
}
110+
96111
/// <inheritdoc/>
97112
public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected.");
98113

src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
2727
private readonly CancellationTokenSource _connectionCts;
2828
private readonly ILogger _logger;
2929

30-
private string? _mcpSessionId;
3130
private Task? _getReceiveTask;
3231

3332
public StreamableHttpClientSessionTransport(
@@ -85,7 +84,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
8584
},
8685
};
8786

88-
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId);
87+
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId);
8988

9089
var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
9190

@@ -124,7 +123,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
124123
// We've successfully initialized! Copy session-id and start GET request if any.
125124
if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues))
126125
{
127-
_mcpSessionId = sessionIdValues.FirstOrDefault();
126+
SessionId = sessionIdValues.FirstOrDefault();
128127
}
129128

130129
_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
@@ -170,7 +169,7 @@ private async Task ReceiveUnsolicitedMessagesAsync()
170169
// Send a GET request to handle any unsolicited messages not sent over a POST response.
171170
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
172171
request.Headers.Accept.Add(s_textEventStreamMediaType);
173-
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId);
172+
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId);
174173

175174
using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false);
176175

src/ModelContextProtocol.Core/IMcpEndpoint.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ namespace ModelContextProtocol;
2828
/// </remarks>
2929
public interface IMcpEndpoint : IAsyncDisposable
3030
{
31+
/// <summary>Gets an identifier associated with the current MCP session.</summary>
32+
/// <remarks>
33+
/// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE.
34+
/// Can return <see langword="null"/> if the session hasn't initialized or if the transport doesn't
35+
/// support multiple sessions (as is the case with STDIO).
36+
/// </remarks>
37+
string? SessionId { get; }
38+
3139
/// <summary>
3240
/// Sends a JSON-RPC request to the connected endpoint and waits for a response.
3341
/// </summary>

src/ModelContextProtocol.Core/Protocol/ITransport.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ namespace ModelContextProtocol.Protocol;
2525
/// </remarks>
2626
public interface ITransport : IAsyncDisposable
2727
{
28+
/// <summary>Gets an identifier associated with the current MCP session.</summary>
29+
/// <remarks>
30+
/// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE.
31+
/// Can return <see langword="null"/> if the session hasn't initialized or if the transport doesn't
32+
/// support multiple sessions (as is the case with STDIO).
33+
/// </remarks>
34+
string? SessionId { get; }
35+
2836
/// <summary>
2937
/// Gets a channel reader for receiving messages from the transport.
3038
/// </summary>

src/ModelContextProtocol.Core/Protocol/TransportBase.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ internal TransportBase(string name, Channel<JsonRpcMessage>? messageChannel, ILo
5959
/// <summary>Gets the logger used by this transport.</summary>
6060
private protected ILogger Logger => _logger;
6161

62+
/// <inheritdoc/>
63+
public virtual string? SessionId { get; protected set; }
64+
6265
/// <summary>
6366
/// Gets the name that identifies this transport endpoint in logs.
6467
/// </summary>

src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace ModelContextProtocol.Server;
66
internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer
77
{
88
public string EndpointName => server.EndpointName;
9+
public string? SessionId => transport?.SessionId ?? server.SessionId;
910
public ClientCapabilities? ClientCapabilities => server.ClientCapabilities;
1011
public Implementation? ClientInfo => server.ClientInfo;
1112
public McpServerOptions ServerOptions => server.ServerOptions;

src/ModelContextProtocol.Core/Server/McpServer.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ void Register<TPrimitive>(McpServerPrimitiveCollection<TPrimitive>? collection,
9696
InitializeSession(transport);
9797
}
9898

99+
/// <inheritdoc/>
100+
public string? SessionId => _sessionTransport.SessionId;
101+
102+
/// <inheritdoc/>
99103
public ServerCapabilities ServerCapabilities { get; } = new();
100104

101105
/// <inheritdoc />

src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace ModelContextProtocol.Server;
2323
/// These messages should be passed to <see cref="OnMessageReceivedAsync(JsonRpcMessage, CancellationToken)"/>.
2424
/// Defaults to "/message".
2525
/// </param>
26-
public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message") : ITransport
26+
/// <param name="sessionId">The identifier corresponding to the current MCP session.</param>
27+
public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message", string? sessionId = null) : ITransport
2728
{
2829
private readonly SseWriter _sseWriter = new(messageEndpoint);
2930
private readonly Channel<JsonRpcMessage> _incomingChannel = Channel.CreateBounded<JsonRpcMessage>(new BoundedChannelOptions(1)
@@ -49,6 +50,9 @@ public async Task RunAsync(CancellationToken cancellationToken)
4950
/// <inheritdoc/>
5051
public ChannelReader<JsonRpcMessage> MessageReader => _incomingChannel.Reader;
5152

53+
/// <inheritdoc/>
54+
public string? SessionId { get; } = sessionId;
55+
5256
/// <inheritdoc/>
5357
public async ValueTask DisposeAsync()
5458
{

src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport
1818

1919
public ChannelReader<JsonRpcMessage> MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages.");
2020

21+
string? ITransport.SessionId => parentTransport.SessionId;
22+
2123
/// <returns>
2224
/// True, if data was written to the respond body.
2325
/// False, if nothing was written because the request body did not contain any <see cref="JsonRpcRequest"/> messages to respond to.
@@ -79,10 +81,11 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella
7981
{
8082
_pendingRequest = request.Id;
8183

82-
// Store client capabilities so they can be serialized by "stateless" callers for use in later requests.
83-
if (parentTransport.Stateless && request.Method == RequestMethods.Initialize)
84+
// Invoke the initialize request callback if applicable.
85+
if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize)
8486
{
85-
parentTransport.InitializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams);
87+
var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams);
88+
await onInitRequest(initializeRequest).ConfigureAwait(false);
8689
}
8790
}
8891

src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,18 @@ public sealed class StreamableHttpServerTransport : ITransport
4646
public bool Stateless { get; init; }
4747

4848
/// <summary>
49-
/// Gets the initialize request if it was received by <see cref="HandlePostRequest(IDuplexPipe, CancellationToken)"/> and <see cref="Stateless"/> is set to <see langword="true"/>.
49+
/// Gets or sets a callback to be invoked before handling the initialize request.
5050
/// </summary>
51-
public InitializeRequestParams? InitializeRequest { get; internal set; }
51+
public Func<InitializeRequestParams?, ValueTask>? OnInitRequestReceived { get; set; }
5252

5353
/// <inheritdoc/>
5454
public ChannelReader<JsonRpcMessage> MessageReader => _incomingChannel.Reader;
5555

5656
internal ChannelWriter<JsonRpcMessage> MessageWriter => _incomingChannel.Writer;
5757

58+
/// <inheritdoc/>
59+
public string? SessionId { get; set; }
60+
5861
/// <summary>
5962
/// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by
6063
/// writing any unsolicited JSON-RPC messages sent via <see cref="SendMessageAsync"/>

tests/Common/Utils/TestServerTransport.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public class TestServerTransport : ITransport
1616

1717
public Action<JsonRpcMessage>? OnMessageSent { get; set; }
1818

19+
public string? SessionId => null;
20+
1921
public TestServerTransport()
2022
{
2123
_messageChannel = Channel.CreateUnbounded<JsonRpcMessage>(new UnboundedChannelOptions

tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ public async Task Connect_TestServer_ShouldProvideServerFields()
5252
// Assert
5353
Assert.NotNull(client.ServerCapabilities);
5454
Assert.NotNull(client.ServerInfo);
55+
56+
if (ClientTransportOptions.Endpoint.AbsolutePath.EndsWith("/sse"))
57+
{
58+
Assert.Null(client.SessionId);
59+
}
60+
else
61+
{
62+
Assert.NotNull(client.SessionId);
63+
}
5564
}
5665

5766
[Fact]
@@ -90,6 +99,35 @@ public async Task CallTool_Sse_EchoServer()
9099
Assert.Equal("Echo: Hello MCP!", textContent.Text);
91100
}
92101

102+
[Fact]
103+
public async Task CallTool_EchoSessionId_ReturnsTheSameSessionId()
104+
{
105+
// arrange
106+
107+
// act
108+
await using var client = await GetClientAsync();
109+
var result1 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken);
110+
var result2 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken);
111+
var result3 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken);
112+
113+
// assert
114+
Assert.NotNull(result1);
115+
Assert.NotNull(result2);
116+
Assert.NotNull(result3);
117+
118+
Assert.False(result1.IsError);
119+
Assert.False(result2.IsError);
120+
Assert.False(result3.IsError);
121+
122+
var textContent1 = Assert.Single(result1.Content);
123+
var textContent2 = Assert.Single(result2.Content);
124+
var textContent3 = Assert.Single(result3.Content);
125+
126+
Assert.NotNull(textContent1.Text);
127+
Assert.Equal(textContent1.Text, textContent2.Text);
128+
Assert.Equal(textContent1.Text, textContent3.Text);
129+
}
130+
93131
[Fact]
94132
public async Task ListResources_Sse_TestServer()
95133
{

tests/ModelContextProtocol.TestServer/Program.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ private static ToolsCapability ConfigureTools()
133133
"""),
134134
},
135135
new Tool()
136+
{
137+
Name = "echoSessionId",
138+
Description = "Echoes the session id back to the client.",
139+
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
140+
{
141+
"type": "object"
142+
}
143+
""", McpJsonUtilities.DefaultOptions),
144+
},
145+
new Tool()
136146
{
137147
Name = "sampleLLM",
138148
Description = "Samples from an LLM using MCP's sampling feature.",
@@ -170,6 +180,13 @@ private static ToolsCapability ConfigureTools()
170180
Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }]
171181
};
172182
}
183+
else if (request.Params?.Name == "echoSessionId")
184+
{
185+
return new CallToolResponse()
186+
{
187+
Content = [new Content() { Text = request.Server.SessionId, Type = "text" }]
188+
};
189+
}
173190
else if (request.Params?.Name == "sampleLLM")
174191
{
175192
if (request.Params?.Arguments is null ||

tests/ModelContextProtocol.TestSseServer/Program.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
128128
""", McpJsonUtilities.DefaultOptions),
129129
},
130130
new Tool()
131+
{
132+
Name = "echoSessionId",
133+
Description = "Echoes the session id back to the client.",
134+
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
135+
{
136+
"type": "object"
137+
}
138+
""", McpJsonUtilities.DefaultOptions),
139+
},
140+
new Tool()
131141
{
132142
Name = "sampleLLM",
133143
Description = "Samples from an LLM using MCP's sampling feature.",
@@ -168,6 +178,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
168178
Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }]
169179
};
170180
}
181+
else if (request.Params.Name == "echoSessionId")
182+
{
183+
return new CallToolResponse()
184+
{
185+
Content = [new Content() { Text = request.Server.SessionId, Type = "text" }]
186+
};
187+
}
171188
else if (request.Params.Name == "sampleLLM")
172189
{
173190
if (request.Params.Arguments is null ||

tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ private class NopTransport : ITransport, IClientTransport
109109
private readonly Channel<JsonRpcMessage> _channel = Channel.CreateUnbounded<JsonRpcMessage>();
110110

111111
public bool IsConnected => true;
112+
public string? SessionId => null;
112113

113114
public ChannelReader<JsonRpcMessage> MessageReader => _channel.Reader;
114115

0 commit comments

Comments
 (0)