diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/filter/CallToolHandlerFilter.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/filter/CallToolHandlerFilter.java new file mode 100644 index 00000000..987999ca --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/filter/CallToolHandlerFilter.java @@ -0,0 +1,137 @@ +package io.modelcontextprotocol.server.filter; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.reactive.function.server.HandlerFilterFunction; +import org.springframework.web.reactive.function.server.HandlerFunction; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.publisher.Mono; + +import java.util.Objects; +import java.util.function.Function; + +public abstract class CallToolHandlerFilter implements HandlerFilterFunction { + + private static final Logger logger = LoggerFactory.getLogger(CallToolHandlerFilter.class); + + private final ObjectMapper objectMapper = new ObjectMapper(); + + private static final String DEFAULT_MESSAGE_PATH = "/mcp/message"; + + public final static McpSchema.CallToolResult PASS = null; + + private Function sessionFunction; + + /** + * Filter incoming requests to handle tool calls. Processes + * {@linkplain io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest JSONRPCRequest} + * that match the configured path and method. + * @param request The incoming server request + * @param next The next handler in the chain + * @return The filtered response + */ + @Override + public Mono filter(ServerRequest request, HandlerFunction next) { + if (!Objects.equals(request.path(), matchPath())) { + return next.handle(request); + } + + return request.bodyToMono(McpSchema.JSONRPCRequest.class) + .flatMap(jsonrpcRequest -> handleJsonRpcRequest(request, jsonrpcRequest, next)); + } + + private Mono handleJsonRpcRequest(ServerRequest request, McpSchema.JSONRPCRequest jsonrpcRequest, + HandlerFunction next) { + ServerRequest newRequest; + try { + newRequest = ServerRequest.from(request).body(objectMapper.writeValueAsString(jsonrpcRequest)).build(); + } + catch (JsonProcessingException e) { + return Mono.error(e); + } + + if (skipFilter(jsonrpcRequest)) { + return next.handle(newRequest); + } + + return handleToolCallRequest(newRequest, jsonrpcRequest, next); + } + + private Mono handleToolCallRequest(ServerRequest newRequest, + McpSchema.JSONRPCRequest jsonrpcRequest, HandlerFunction next) { + McpServerSession session = newRequest.queryParam("sessionId") + .map(sessionId -> sessionFunction.apply(sessionId)) + .orElse(null); + + if (Objects.isNull(session)) { + return next.handle(newRequest); + } + + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(jsonrpcRequest.params(), + McpSchema.CallToolRequest.class); + McpSchema.CallToolResult callToolResult = doFilter(newRequest, callToolRequest); + if (Objects.equals(PASS, callToolResult)) { + return next.handle(newRequest); + } + else { + return session.sendResponse(jsonrpcRequest.id(), callToolResult, null).then(ServerResponse.ok().build()); + } + } + + private boolean skipFilter(McpSchema.JSONRPCRequest jsonrpcRequest) { + if (!Objects.equals(jsonrpcRequest.method(), matchMethod())) { + return true; + } + + if (Objects.isNull(sessionFunction)) { + logger.error("No session function provided, skip CallToolRequest filter"); + return true; + } + + return false; + } + + /** + * Abstract method to be implemented by subclasses to handle tool call requests. + * @param request The incoming server request. Contains HTTP information such as: + * request path, request headers, request parameters. Note that the request body has + * already been extracted and deserialized into the callToolRequest parameter, so + * there's no need to extract the body from the ServerRequest again. + * @param callToolRequest The deserialized call tool request object + * @return A CallToolResult object if not pass current filter (subsequent filters will + * not be executed), or {@linkplain CallToolHandlerFilter#PASS PASS} pass current + * filter(subsequent filters will be executed) + */ + public abstract McpSchema.CallToolResult doFilter(ServerRequest request, McpSchema.CallToolRequest callToolRequest); + + /** + * Returns the method name to match for handling tool calls. + * @return The method name to match + */ + public String matchMethod() { + return McpSchema.METHOD_TOOLS_CALL; + } + + /** + * Returns the path to match for handling tool calls. + * @return The path to match + */ + public String matchPath() { + return DEFAULT_MESSAGE_PATH; + } + + /** + * Set the session provider function. + * @param transportProvider The SSE server transport provider used to obtain sessions + */ + public void applySession(WebFluxSseServerTransportProvider transportProvider) { + this.sessionFunction = transportProvider::getSession; + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..216180e0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,7 +1,6 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; @@ -14,6 +13,7 @@ import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.web.reactive.function.server.*; import reactor.core.Exceptions; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; @@ -22,10 +22,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.server.RouterFunction; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; -import org.springframework.web.reactive.function.server.ServerResponse; /** * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using @@ -84,6 +80,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv public static final String DEFAULT_BASE_URL = ""; + /** + * Default filter function for handling requests, do nothing + */ + public static final HandlerFilterFunction DEFAULT_REQUEST_FILTER = ((request, + next) -> next.handle(request)); + private final ObjectMapper objectMapper; /** @@ -149,10 +151,28 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, DEFAULT_REQUEST_FILTER); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param requestFilter The filter function to apply to incoming requests, which may + * be sse or message request. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, HandlerFilterFunction requestFilter) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(requestFilter, "Request filter must not be null"); this.objectMapper = objectMapper; this.baseUrl = baseUrl; @@ -161,6 +181,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) + .filter(requestFilter) .build(); } @@ -245,6 +266,14 @@ public RouterFunction getRouterFunction() { return this.routerFunction; } + /** + * Returns the McpServerSession associated with the given session ID. + * @return session The McpServerSession associated with the given session ID, or null + */ + public McpServerSession getSession(String sessionId) { + return sessions.get(sessionId); + } + /** * Handles new SSE connection requests from clients. Creates a new session for each * connection and sets up the SSE event stream. @@ -397,6 +426,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private HandlerFilterFunction requestFilter = DEFAULT_REQUEST_FILTER; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -447,6 +478,12 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + public Builder requestFilter(HandlerFilterFunction requestFilter) { + Assert.notNull(requestFilter, "requestFilter must not be null"); + this.requestFilter = requestFilter; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -457,7 +494,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + requestFilter); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 660f814d..4ff65589 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -22,15 +22,16 @@ import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.filter.CallToolHandlerFilter; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.web.reactive.function.server.ServerRequest; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -557,6 +558,128 @@ void testToolCallSuccess(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallWithFilterFail(String clientType) { + // Server + // Restart http server to add server filter + httpServer.disposeNow(); + var failCallResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("FAIL RESPONSE")), true); + var successCallResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("SUCCESS RESPONSE")), + false); + // User custom request filter + CallToolHandlerFilter customFilter = new CallToolHandlerFilter() { + @Override + public McpSchema.CallToolResult doFilter(ServerRequest request, CallToolRequest callToolRequest) { + if (request.headers().header("X-Custom-Header").isEmpty()) { + return failCallResponse; + } + return CallToolHandlerFilter.PASS; + } + + @Override + public String matchPath() { + return CUSTOM_MESSAGE_ENDPOINT; + } + }; + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .requestFilter(customFilter) + .build(); + customFilter.applySession(mcpServerTransportProvider); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool", "tool description", emptyJsonSchema), + (exchange, request) -> successCallResponse); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + // Client + var clientBuilder = clientBuilders.get(clientType); + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(failCallResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallWithFilterSuccess(String clientType) { + // Server + // Restart http server to add server filter + httpServer.disposeNow(); + var successCallResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("SUCCESS RESPONSE")), + false); + // User custom request filter + CallToolHandlerFilter customFilter = new CallToolHandlerFilter() { + @Override + public McpSchema.CallToolResult doFilter(ServerRequest request, CallToolRequest callToolRequest) { + return CallToolHandlerFilter.PASS; + } + + @Override + public String matchPath() { + return CUSTOM_MESSAGE_ENDPOINT; + } + }; + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .requestFilter(customFilter) + .build(); + customFilter.applySession(mcpServerTransportProvider); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool", "tool description", emptyJsonSchema), + (exchange, request) -> successCallResponse); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + // Client + var clientBuilder = clientBuilders.get(clientType); + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(successCallResponse); + } + + mcpServer.close(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testToolListChangeHandlingSuccess(String clientType) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d85..d9939628 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -143,6 +143,12 @@ public Mono sendNotification(String method, Object params) { return this.transport.sendMessage(jsonrpcNotification); } + public Mono sendResponse(Object id, Object result, McpSchema.JSONRPCResponse.JSONRPCError error) { + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, id, result, + error); + return this.transport.sendMessage(response); + } + /** * Called by the {@link McpServerTransportProvider} once the session is determined. * The purpose of this method is to dispatch the message to an appropriate handler as