Skip to content

Commit efdd5cc

Browse files
committed
feat(mcp): webflux support filter
1 parent 261554b commit efdd5cc

File tree

4 files changed

+316
-7
lines changed

4 files changed

+316
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package io.modelcontextprotocol.server.filter;
2+
3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
6+
import io.modelcontextprotocol.spec.McpSchema;
7+
import io.modelcontextprotocol.spec.McpServerSession;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
import org.springframework.web.reactive.function.server.HandlerFilterFunction;
11+
import org.springframework.web.reactive.function.server.HandlerFunction;
12+
import org.springframework.web.reactive.function.server.ServerRequest;
13+
import org.springframework.web.reactive.function.server.ServerResponse;
14+
import reactor.core.publisher.Mono;
15+
16+
import java.util.Objects;
17+
import java.util.function.Function;
18+
19+
public abstract class CallToolHandlerFilter implements HandlerFilterFunction<ServerResponse, ServerResponse> {
20+
21+
private static final Logger logger = LoggerFactory.getLogger(CallToolHandlerFilter.class);
22+
23+
private final ObjectMapper objectMapper = new ObjectMapper();
24+
25+
private static final String DEFAULT_MESSAGE_PATH = "/mcp/message";
26+
27+
public final static McpSchema.CallToolResult PASS = null;
28+
29+
private Function<String, McpServerSession> sessionFunction;
30+
31+
/**
32+
* Filter incoming requests to handle tool calls. Processes
33+
* {@linkplain io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest JSONRPCRequest}
34+
* that match the configured path and method.
35+
* @param request The incoming server request
36+
* @param next The next handler in the chain
37+
* @return The filtered response
38+
*/
39+
@Override
40+
public Mono<ServerResponse> filter(ServerRequest request, HandlerFunction<ServerResponse> next) {
41+
if (!Objects.equals(request.path(), matchPath())) {
42+
return next.handle(request);
43+
}
44+
45+
return request.bodyToMono(McpSchema.JSONRPCRequest.class)
46+
.flatMap(jsonrpcRequest -> handleJsonRpcRequest(request, jsonrpcRequest, next));
47+
}
48+
49+
private Mono<ServerResponse> handleJsonRpcRequest(ServerRequest request, McpSchema.JSONRPCRequest jsonrpcRequest,
50+
HandlerFunction<ServerResponse> next) {
51+
ServerRequest newRequest;
52+
try {
53+
newRequest = ServerRequest.from(request).body(objectMapper.writeValueAsString(jsonrpcRequest)).build();
54+
}
55+
catch (JsonProcessingException e) {
56+
return Mono.error(e);
57+
}
58+
59+
if (skipFilter(jsonrpcRequest)) {
60+
return next.handle(newRequest);
61+
}
62+
63+
return handleToolCallRequest(newRequest, jsonrpcRequest, next);
64+
}
65+
66+
private Mono<ServerResponse> handleToolCallRequest(ServerRequest newRequest,
67+
McpSchema.JSONRPCRequest jsonrpcRequest, HandlerFunction<ServerResponse> next) {
68+
McpServerSession session = newRequest.queryParam("sessionId")
69+
.map(sessionId -> sessionFunction.apply(sessionId))
70+
.orElse(null);
71+
72+
if (Objects.isNull(session)) {
73+
return next.handle(newRequest);
74+
}
75+
76+
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(jsonrpcRequest.params(),
77+
McpSchema.CallToolRequest.class);
78+
McpSchema.CallToolResult callToolResult = doFilter(newRequest, callToolRequest);
79+
if (Objects.equals(PASS, callToolResult)) {
80+
return next.handle(newRequest);
81+
}
82+
else {
83+
return session.sendResponse(jsonrpcRequest.id(), callToolResult, null).then(ServerResponse.ok().build());
84+
}
85+
}
86+
87+
private boolean skipFilter(McpSchema.JSONRPCRequest jsonrpcRequest) {
88+
if (!Objects.equals(jsonrpcRequest.method(), matchMethod())) {
89+
return true;
90+
}
91+
92+
if (Objects.isNull(sessionFunction)) {
93+
logger.error("No session function provided, skip CallToolRequest filter");
94+
return true;
95+
}
96+
97+
return false;
98+
}
99+
100+
/**
101+
* Abstract method to be implemented by subclasses to handle tool call requests.
102+
* @param request The incoming server request. Contains HTTP information such as:
103+
* request path, request headers, request parameters. Note that the request body has
104+
* already been extracted and deserialized into the callToolRequest parameter, so
105+
* there's no need to extract the body from the ServerRequest again.
106+
* @param callToolRequest The deserialized call tool request object
107+
* @return A CallToolResult object if the current filter handles the request
108+
* (subsequent filters will not be executed), or
109+
* {@linkplain CallToolHandlerFilter#PASS PASS} if the current filter does not handle
110+
* the request (execution will continue to subsequent filters in the chain).
111+
*/
112+
public abstract McpSchema.CallToolResult doFilter(ServerRequest request, McpSchema.CallToolRequest callToolRequest);
113+
114+
/**
115+
* Returns the method name to match for handling tool calls.
116+
* @return The method name to match
117+
*/
118+
public String matchMethod() {
119+
return McpSchema.METHOD_TOOLS_CALL;
120+
}
121+
122+
/**
123+
* Returns the path to match for handling tool calls.
124+
* @return The path to match
125+
*/
126+
public String matchPath() {
127+
return DEFAULT_MESSAGE_PATH;
128+
}
129+
130+
/**
131+
* Set the session provider function.
132+
* @param transportProvider The SSE server transport provider used to obtain sessions
133+
*/
134+
public void applySession(WebFluxSseServerTransportProvider transportProvider) {
135+
this.sessionFunction = transportProvider::getSession;
136+
}
137+
138+
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

+44-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package io.modelcontextprotocol.server.transport;
22

33
import java.io.IOException;
4-
import java.util.Map;
54
import java.util.concurrent.ConcurrentHashMap;
65

76
import com.fasterxml.jackson.core.type.TypeReference;
@@ -14,6 +13,7 @@
1413
import io.modelcontextprotocol.util.Assert;
1514
import org.slf4j.Logger;
1615
import org.slf4j.LoggerFactory;
16+
import org.springframework.web.reactive.function.server.*;
1717
import reactor.core.Exceptions;
1818
import reactor.core.publisher.Flux;
1919
import reactor.core.publisher.FluxSink;
@@ -22,10 +22,6 @@
2222
import org.springframework.http.HttpStatus;
2323
import org.springframework.http.MediaType;
2424
import org.springframework.http.codec.ServerSentEvent;
25-
import org.springframework.web.reactive.function.server.RouterFunction;
26-
import org.springframework.web.reactive.function.server.RouterFunctions;
27-
import org.springframework.web.reactive.function.server.ServerRequest;
28-
import org.springframework.web.reactive.function.server.ServerResponse;
2925

3026
/**
3127
* Server-side implementation of the MCP (Model Context Protocol) HTTP transport using
@@ -84,6 +80,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
8480

8581
public static final String DEFAULT_BASE_URL = "";
8682

83+
/**
84+
* Default filter function for handling requests, do nothing
85+
*/
86+
public static final HandlerFilterFunction<ServerResponse, ServerResponse> DEFAULT_REQUEST_FILTER = ((request,
87+
next) -> next.handle(request));
88+
8789
private final ObjectMapper objectMapper;
8890

8991
/**
@@ -149,10 +151,28 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
149151
*/
150152
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
151153
String sseEndpoint) {
154+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, DEFAULT_REQUEST_FILTER);
155+
}
156+
157+
/**
158+
* Constructs a new WebFlux SSE server transport provider instance.
159+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
160+
* of MCP messages. Must not be null.
161+
* @param baseUrl webflux message base path
162+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
163+
* messages. This endpoint will be communicated to clients during SSE connection
164+
* setup. Must not be null.
165+
* @param requestFilter The filter function to apply to incoming requests, which may
166+
* be sse or message request.
167+
* @throws IllegalArgumentException if either parameter is null
168+
*/
169+
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
170+
String sseEndpoint, HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter) {
152171
Assert.notNull(objectMapper, "ObjectMapper must not be null");
153172
Assert.notNull(baseUrl, "Message base path must not be null");
154173
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
155174
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
175+
Assert.notNull(requestFilter, "Request filter must not be null");
156176

157177
this.objectMapper = objectMapper;
158178
this.baseUrl = baseUrl;
@@ -161,6 +181,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
161181
this.routerFunction = RouterFunctions.route()
162182
.GET(this.sseEndpoint, this::handleSseConnection)
163183
.POST(this.messageEndpoint, this::handleMessage)
184+
.filter(requestFilter)
164185
.build();
165186
}
166187

@@ -245,6 +266,14 @@ public RouterFunction<?> getRouterFunction() {
245266
return this.routerFunction;
246267
}
247268

269+
/**
270+
* Returns the McpServerSession associated with the given session ID.
271+
* @return session The McpServerSession associated with the given session ID, or null
272+
*/
273+
public McpServerSession getSession(String sessionId) {
274+
return sessions.get(sessionId);
275+
}
276+
248277
/**
249278
* Handles new SSE connection requests from clients. Creates a new session for each
250279
* connection and sets up the SSE event stream.
@@ -397,6 +426,8 @@ public static class Builder {
397426

398427
private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
399428

429+
private HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter = DEFAULT_REQUEST_FILTER;
430+
400431
/**
401432
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
402433
* messages.
@@ -447,6 +478,12 @@ public Builder sseEndpoint(String sseEndpoint) {
447478
return this;
448479
}
449480

481+
public Builder requestFilter(HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter) {
482+
Assert.notNull(requestFilter, "requestFilter must not be null");
483+
this.requestFilter = requestFilter;
484+
return this;
485+
}
486+
450487
/**
451488
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
452489
* configured settings.
@@ -457,7 +494,8 @@ public WebFluxSseServerTransportProvider build() {
457494
Assert.notNull(objectMapper, "ObjectMapper must be set");
458495
Assert.notNull(messageEndpoint, "Message endpoint must be set");
459496

460-
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
497+
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
498+
requestFilter);
461499
}
462500

463501
}

0 commit comments

Comments
 (0)