1
1
package io .modelcontextprotocol .server .transport ;
2
2
3
3
import java .io .IOException ;
4
- import java .util .Map ;
5
4
import java .util .concurrent .ConcurrentHashMap ;
6
5
7
6
import com .fasterxml .jackson .core .type .TypeReference ;
14
13
import io .modelcontextprotocol .util .Assert ;
15
14
import org .slf4j .Logger ;
16
15
import org .slf4j .LoggerFactory ;
16
+ import org .springframework .web .reactive .function .server .*;
17
17
import reactor .core .Exceptions ;
18
18
import reactor .core .publisher .Flux ;
19
19
import reactor .core .publisher .FluxSink ;
22
22
import org .springframework .http .HttpStatus ;
23
23
import org .springframework .http .MediaType ;
24
24
import org .springframework .http .codec .ServerSentEvent ;
25
- import org .springframework .web .reactive .function .server .RouterFunction ;
26
- import org .springframework .web .reactive .function .server .RouterFunctions ;
27
- import org .springframework .web .reactive .function .server .ServerRequest ;
28
- import org .springframework .web .reactive .function .server .ServerResponse ;
29
25
30
26
/**
31
27
* Server-side implementation of the MCP (Model Context Protocol) HTTP transport using
@@ -84,6 +80,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
84
80
85
81
public static final String DEFAULT_BASE_URL = "" ;
86
82
83
+ /**
84
+ * Default filter function for handling requests, do nothing
85
+ */
86
+ public static final HandlerFilterFunction <ServerResponse , ServerResponse > DEFAULT_REQUEST_FILTER = ((request ,
87
+ next ) -> next .handle (request ));
88
+
87
89
private final ObjectMapper objectMapper ;
88
90
89
91
/**
@@ -149,10 +151,28 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
149
151
*/
150
152
public WebFluxSseServerTransportProvider (ObjectMapper objectMapper , String baseUrl , String messageEndpoint ,
151
153
String sseEndpoint ) {
154
+ this (objectMapper , baseUrl , messageEndpoint , sseEndpoint , DEFAULT_REQUEST_FILTER );
155
+ }
156
+
157
+ /**
158
+ * Constructs a new WebFlux SSE server transport provider instance.
159
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
160
+ * of MCP messages. Must not be null.
161
+ * @param baseUrl webflux message base path
162
+ * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
163
+ * messages. This endpoint will be communicated to clients during SSE connection
164
+ * setup. Must not be null.
165
+ * @param requestFilter The filter function to apply to incoming requests, which may
166
+ * be sse or message request.
167
+ * @throws IllegalArgumentException if either parameter is null
168
+ */
169
+ public WebFluxSseServerTransportProvider (ObjectMapper objectMapper , String baseUrl , String messageEndpoint ,
170
+ String sseEndpoint , HandlerFilterFunction <ServerResponse , ServerResponse > requestFilter ) {
152
171
Assert .notNull (objectMapper , "ObjectMapper must not be null" );
153
172
Assert .notNull (baseUrl , "Message base path must not be null" );
154
173
Assert .notNull (messageEndpoint , "Message endpoint must not be null" );
155
174
Assert .notNull (sseEndpoint , "SSE endpoint must not be null" );
175
+ Assert .notNull (requestFilter , "Request filter must not be null" );
156
176
157
177
this .objectMapper = objectMapper ;
158
178
this .baseUrl = baseUrl ;
@@ -161,6 +181,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
161
181
this .routerFunction = RouterFunctions .route ()
162
182
.GET (this .sseEndpoint , this ::handleSseConnection )
163
183
.POST (this .messageEndpoint , this ::handleMessage )
184
+ .filter (requestFilter )
164
185
.build ();
165
186
}
166
187
@@ -245,6 +266,14 @@ public RouterFunction<?> getRouterFunction() {
245
266
return this .routerFunction ;
246
267
}
247
268
269
+ /**
270
+ * Returns the McpServerSession associated with the given session ID.
271
+ * @return session The McpServerSession associated with the given session ID, or null
272
+ */
273
+ public McpServerSession getSession (String sessionId ) {
274
+ return sessions .get (sessionId );
275
+ }
276
+
248
277
/**
249
278
* Handles new SSE connection requests from clients. Creates a new session for each
250
279
* connection and sets up the SSE event stream.
@@ -397,6 +426,8 @@ public static class Builder {
397
426
398
427
private String sseEndpoint = DEFAULT_SSE_ENDPOINT ;
399
428
429
+ private HandlerFilterFunction <ServerResponse , ServerResponse > requestFilter = DEFAULT_REQUEST_FILTER ;
430
+
400
431
/**
401
432
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
402
433
* messages.
@@ -447,6 +478,12 @@ public Builder sseEndpoint(String sseEndpoint) {
447
478
return this ;
448
479
}
449
480
481
+ public Builder requestFilter (HandlerFilterFunction <ServerResponse , ServerResponse > requestFilter ) {
482
+ Assert .notNull (requestFilter , "requestFilter must not be null" );
483
+ this .requestFilter = requestFilter ;
484
+ return this ;
485
+ }
486
+
450
487
/**
451
488
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
452
489
* configured settings.
@@ -457,7 +494,8 @@ public WebFluxSseServerTransportProvider build() {
457
494
Assert .notNull (objectMapper , "ObjectMapper must be set" );
458
495
Assert .notNull (messageEndpoint , "Message endpoint must be set" );
459
496
460
- return new WebFluxSseServerTransportProvider (objectMapper , baseUrl , messageEndpoint , sseEndpoint );
497
+ return new WebFluxSseServerTransportProvider (objectMapper , baseUrl , messageEndpoint , sseEndpoint ,
498
+ requestFilter );
461
499
}
462
500
463
501
}
0 commit comments