Skip to content

Commit 2976a72

Browse files
More efficient http #156
1 parent f348a83 commit 2976a72

File tree

8 files changed

+140
-111
lines changed

8 files changed

+140
-111
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ else if (MESSAGE_EVENT_TYPE.equals(event.event())) {
208208
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data());
209209
s.next(message);
210210
}
211-
catch (IOException ioException) {
212-
s.error(ioException);
211+
catch (RuntimeException ioOrIllegalException) {
212+
s.error(ioOrIllegalException);
213213
}
214214
}
215215
else {

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.modelcontextprotocol.server.transport;
22

33
import java.io.IOException;
4+
import java.io.UncheckedIOException;
45
import java.util.Map;
56
import java.util.concurrent.ConcurrentHashMap;
67

@@ -323,8 +324,8 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
323324
.bodyValue(new McpError(error.getMessage()));
324325
});
325326
}
326-
catch (IllegalArgumentException | IOException e) {
327-
logger.error("Failed to deserialize message: {}", e.getMessage());
327+
catch (IllegalArgumentException | UncheckedIOException e) {
328+
logger.error("Failed to deserialize message", e);
328329
return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
329330
}
330331
});

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
362362
logger.error("Received unrecognized SSE event type: {}", event.type());
363363
}
364364
}
365-
catch (IOException e) {
365+
catch (RuntimeException e) {
366366
logger.error("Error processing SSE event", e);
367367
future.completeExceptionally(e);
368368
}

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package io.modelcontextprotocol.server;
66

7+
import java.time.Duration;
78
import java.util.HashMap;
89
import java.util.List;
910
import java.util.Map;
@@ -264,7 +265,7 @@ private static class AsyncServerImpl extends McpAsyncServer {
264265

265266
private final ConcurrentHashMap<String, McpServerFeatures.AsyncPromptSpecification> prompts = new ConcurrentHashMap<>();
266267

267-
// FIXME: this field is deprecated and should be remvoed together with the
268+
// FIXME: this field is deprecated and should be removed together with the
268269
// broadcasting loggingNotification.
269270
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
270271

@@ -330,9 +331,9 @@ private static class AsyncServerImpl extends McpAsyncServer {
330331
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED,
331332
asyncRootsListChangedNotificationHandler(rootsChangeConsumers));
332333

333-
mcpTransportProvider
334-
.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport,
335-
this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers));
334+
mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(),
335+
transport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers,
336+
Duration.ofSeconds(10)));
336337
}
337338

338339
// ---------------------------------------

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

+19-16
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.io.BufferedReader;
77
import java.io.IOException;
88
import java.io.PrintWriter;
9+
import java.io.UncheckedIOException;
910
import java.util.Map;
1011
import java.util.UUID;
1112
import java.util.concurrent.ConcurrentHashMap;
@@ -240,18 +241,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
240241
@Override
241242
protected void doPost(HttpServletRequest request, HttpServletResponse response)
242243
throws ServletException, IOException {
243-
244244
if (isClosing.get()) {
245245
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
246246
return;
247247
}
248-
249248
String requestURI = request.getRequestURI();
250249
if (!requestURI.endsWith(messageEndpoint)) {
251250
response.sendError(HttpServletResponse.SC_NOT_FOUND);
252251
return;
253252
}
254-
255253
// Get the session ID from the request parameter
256254
String sessionId = request.getParameter("sessionId");
257255
if (sessionId == null) {
@@ -277,24 +275,29 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
277275
writer.flush();
278276
return;
279277
}
280-
281278
try {
282-
BufferedReader reader = request.getReader();
283-
StringBuilder body = new StringBuilder();
284-
String line;
285-
while ((line = reader.readLine()) != null) {
286-
body.append(line);
287-
}
288-
289-
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());
290-
291-
// Process the message through the session's handle method
279+
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, request.getReader());
292280
session.handle(message).block(); // Block for Servlet compatibility
293-
294281
response.setStatus(HttpServletResponse.SC_OK);
295282
}
283+
catch (IllegalArgumentException | UncheckedIOException ex) {
284+
try {
285+
McpError mcpError = new McpError(ex.getMessage());
286+
response.setContentType(APPLICATION_JSON);
287+
response.setCharacterEncoding(UTF_8);
288+
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
289+
String jsonError = objectMapper.writeValueAsString(mcpError);
290+
PrintWriter writer = response.getWriter();
291+
writer.write(jsonError);
292+
writer.flush();
293+
}
294+
catch (IOException ex2) {
295+
logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex2.getMessage());
296+
response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
297+
}
298+
}
296299
catch (Exception e) {
297-
logger.error("Error processing message: {}", e.getMessage());
300+
logger.error("Error processing message", e);
298301
try {
299302
McpError mcpError = new McpError(e.getMessage());
300303
response.setContentType(APPLICATION_JSON);

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

+25-18
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
package io.modelcontextprotocol.spec;
66

7+
import java.io.BufferedReader;
78
import java.io.IOException;
9+
import java.io.Reader;
10+
import java.io.StringReader;
811
import java.util.ArrayList;
912
import java.util.HashMap;
1013
import java.util.List;
@@ -140,32 +143,36 @@ public sealed interface Request
140143
/**
141144
* Deserializes a JSON string into a JSONRPCMessage object.
142145
* @param objectMapper The ObjectMapper instance to use for deserialization
143-
* @param jsonText The JSON string to deserialize
146+
* @param inputStream The JSON string to deserialize
144147
* @return A JSONRPCMessage instance using either the {@link JSONRPCRequest},
145148
* {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes.
146-
* @throws IOException If there's an error during deserialization
147149
* @throws IllegalArgumentException If the JSON structure doesn't match any known
148150
* message type
149151
*/
150-
public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText)
151-
throws IOException {
152-
153-
logger.debug("Received JSON message: {}", jsonText);
154-
155-
var map = objectMapper.readValue(jsonText, MAP_TYPE_REF);
156-
157-
// Determine message type based on specific JSON structure
158-
if (map.containsKey("method") && map.containsKey("id")) {
159-
return objectMapper.convertValue(map, JSONRPCRequest.class);
160-
}
161-
else if (map.containsKey("method") && !map.containsKey("id")) {
162-
return objectMapper.convertValue(map, JSONRPCNotification.class);
152+
public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, BufferedReader inputStream) {
153+
try {
154+
var map = objectMapper.readValue(inputStream, MAP_TYPE_REF);
155+
// Determine message type based on specific JSON structure
156+
if (map.containsKey("method") && map.containsKey("id")) {
157+
return objectMapper.convertValue(map, JSONRPCRequest.class);
158+
}
159+
else if (map.containsKey("method") && !map.containsKey("id")) {
160+
return objectMapper.convertValue(map, JSONRPCNotification.class);
161+
}
162+
else if (map.containsKey("result") || map.containsKey("error")) {
163+
return objectMapper.convertValue(map, JSONRPCResponse.class);
164+
}
165+
throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + map);
163166
}
164-
else if (map.containsKey("result") || map.containsKey("error")) {
165-
return objectMapper.convertValue(map, JSONRPCResponse.class);
167+
catch (IOException e) {
168+
throw new java.io.UncheckedIOException(e);
166169
}
170+
}
167171

168-
throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText);
172+
public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String input) {
173+
Reader inputString = new StringReader(input);
174+
BufferedReader reader = new BufferedReader(inputString);
175+
return deserializeJsonRpcMessage(objectMapper, reader);
169176
}
170177

171178
// ---------------------------

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

+12-10
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,30 @@ public class McpServerSession implements McpSession {
5353

5454
private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);
5555

56+
private final Duration requestTimeout;
57+
5658
/**
5759
* Creates a new server session with the given parameters and the transport to use.
5860
* @param id session id
5961
* @param transport the transport to use
6062
* @param initHandler called when a
6163
* {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the
6264
* server
63-
* @param initNotificationHandler called when a
64-
* {@link McpSchema.METHOD_NOTIFICATION_INITIALIZED} is received.
65+
* @param initNotificationHandler called when a {@link McpSchema.METHOD_INITIALIZE }
66+
* is received.
6567
* @param requestHandlers map of request handlers to use
6668
* @param notificationHandlers map of notification handlers to use
6769
*/
6870
public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler,
6971
InitNotificationHandler initNotificationHandler, Map<String, RequestHandler<?>> requestHandlers,
70-
Map<String, NotificationHandler> notificationHandlers) {
72+
Map<String, NotificationHandler> notificationHandlers, Duration requestTimeout) {
7173
this.id = id;
7274
this.transport = transport;
7375
this.initRequestHandler = initHandler;
7476
this.initNotificationHandler = initNotificationHandler;
7577
this.requestHandlers = requestHandlers;
7678
this.notificationHandlers = notificationHandlers;
79+
this.requestTimeout = requestTimeout;
7780
}
7881

7982
/**
@@ -116,7 +119,7 @@ public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReferenc
116119
this.pendingResponses.remove(requestId);
117120
sink.error(error);
118121
});
119-
}).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> {
122+
}).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> {
120123
if (jsonRpcResponse.error() != null) {
121124
sink.error(new McpError(jsonRpcResponse.error()));
122125
}
@@ -197,6 +200,7 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
197200
return Mono.defer(() -> {
198201
Mono<?> resultMono;
199202
if (McpSchema.METHOD_INITIALIZE.equals(request.method())) {
203+
200204
// TODO handle situation where already initialized!
201205
McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(),
202206
new TypeReference<McpSchema.InitializeRequest>() {
@@ -254,13 +258,11 @@ record MethodNotFoundError(String method, String message, Object data) {
254258
}
255259

256260
static MethodNotFoundError getMethodNotFoundError(String method) {
257-
switch (method) {
258-
case McpSchema.METHOD_ROOTS_LIST:
259-
return new MethodNotFoundError(method, "Roots not supported",
260-
Map.of("reason", "Client does not have roots capability"));
261-
default:
262-
return new MethodNotFoundError(method, "Method not found: " + method, null);
261+
if (method.equals(McpSchema.METHOD_ROOTS_LIST)) {
262+
return new MethodNotFoundError(method, "Roots not supported",
263+
Map.of("reason", "Client does not have roots capability"));
263264
}
265+
return new MethodNotFoundError(method, "Method not found: " + method, null);
264266
}
265267

266268
@Override

0 commit comments

Comments
 (0)