Skip to content

Commit 8e9857f

Browse files
chemicLtzolov
authored andcommitted
Fix Streamable HTTP WebClient GET SSE handling
Signed-off-by: Dariusz Jędrzejczyk <[email protected]>
1 parent 9a63538 commit 8e9857f

File tree

4 files changed

+75
-14
lines changed

4 files changed

+75
-14
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,14 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
125125
}
126126

127127
private DefaultMcpTransportSession createTransportSession() {
128-
Supplier<Publisher<Void>> onClose = () -> {
129-
DefaultMcpTransportSession transportSession = this.activeSession.get();
130-
return transportSession.sessionId().isEmpty() ? Mono.empty()
131-
: webClient.delete().uri(this.endpoint).headers(httpHeaders -> {
132-
httpHeaders.add("mcp-session-id", transportSession.sessionId().get());
133-
}).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then();
134-
};
128+
Function<String, Publisher<Void>> onClose = sessionId -> sessionId == null ? Mono.empty()
129+
: webClient.delete().uri(this.endpoint).headers(httpHeaders -> {
130+
httpHeaders.add("mcp-session-id", sessionId);
131+
})
132+
.retrieve()
133+
.toBodilessEntity()
134+
.doOnError(e -> logger.warn("Got error when closing transport", e))
135+
.then();
135136
return new DefaultMcpTransportSession(onClose);
136137
}
137138

@@ -192,6 +193,7 @@ private Mono<Disposable> reconnect(McpTransportStream<Disposable> stream) {
192193
})
193194
.exchangeToFlux(response -> {
194195
if (isEventStream(response)) {
196+
logger.debug("Established SSE stream via GET");
195197
return eventStream(stream, response);
196198
}
197199
else if (isNotAllowed(response)) {
@@ -208,6 +210,7 @@ else if (isNotFound(response)) {
208210
}).flux();
209211
}
210212
})
213+
.flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage)))
211214
.onErrorComplete(t -> {
212215
this.handleException(t);
213216
return true;
@@ -274,6 +277,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
274277
else {
275278
MediaType mediaType = contentType.get();
276279
if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
280+
logger.debug("Established SSE stream via POST");
277281
// communicate to caller that the message was delivered
278282
sink.success();
279283
// starting a stream

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
2020
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
2121
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
22+
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
23+
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
2224
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
2325
import io.modelcontextprotocol.spec.McpSchema.Prompt;
2426
import io.modelcontextprotocol.spec.McpSchema.Resource;
@@ -77,7 +79,9 @@ McpAsyncClient client(McpClientTransport transport, Function<McpClient.AsyncSpec
7779
McpClient.AsyncSpec builder = McpClient.async(transport)
7880
.requestTimeout(getRequestTimeout())
7981
.initializationTimeout(getInitializationTimeout())
80-
.capabilities(ClientCapabilities.builder().roots(true).build());
82+
.sampling(req -> Mono.just(new CreateMessageResult(McpSchema.Role.USER,
83+
new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN)))
84+
.capabilities(ClientCapabilities.builder().roots(true).sampling().build());
8185
builder = customizer.apply(builder);
8286
client.set(builder.build());
8387
}).doesNotThrowAnyException();
@@ -189,6 +193,22 @@ void testCallTool() {
189193
});
190194
}
191195

196+
@Test
197+
void testSampling() {
198+
withClient(createMcpTransport(), mcpAsyncClient -> {
199+
CallToolRequest callToolRequest = new CallToolRequest("sampleLLM",
200+
Map.of("prompt", "Hello MCP Spring AI!"));
201+
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
202+
.consumeNextWith(callToolResult -> {
203+
assertThat(callToolResult).isNotNull().satisfies(result -> {
204+
assertThat(result.content()).isNotNull();
205+
assertThat(result.isError()).isNull();
206+
});
207+
})
208+
.verifyComplete();
209+
});
210+
}
211+
192212
@Test
193213
void testCallToolWithInvalidTool() {
194214
withClient(createMcpTransport(), mcpAsyncClient -> {
@@ -424,6 +444,20 @@ void testInitializeWithSamplingCapability() {
424444
});
425445
}
426446

447+
@Test
448+
void testInitializeWithElicitationCapability() {
449+
ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build();
450+
ElicitResult elicitResult = ElicitResult.builder()
451+
.message(ElicitResult.Action.ACCEPT)
452+
.content(Map.of("foo", "bar"))
453+
.build();
454+
withClient(createMcpTransport(),
455+
builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)),
456+
client -> {
457+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
458+
});
459+
}
460+
427461
@Test
428462
void testInitializeWithAllCapabilities() {
429463
var capabilities = ClientCapabilities.builder()
@@ -435,7 +469,11 @@ void testInitializeWithAllCapabilities() {
435469
Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler = request -> Mono
436470
.just(CreateMessageResult.builder().message("test").model("test-model").build());
437471

438-
withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler),
472+
Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler = request -> Mono
473+
.just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build());
474+
475+
withClient(createMcpTransport(),
476+
builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler),
439477
client ->
440478

441479
StepVerifier.create(client.initialize()).assertNext(result -> {

mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import java.util.Optional;
1111
import java.util.concurrent.atomic.AtomicBoolean;
1212
import java.util.concurrent.atomic.AtomicReference;
13-
import java.util.function.Supplier;
13+
import java.util.function.Function;
1414

1515
/**
1616
* Default implementation of {@link McpTransportSession} which manages the open
@@ -29,9 +29,9 @@ public class DefaultMcpTransportSession implements McpTransportSession<Disposabl
2929

3030
private final AtomicReference<String> sessionId = new AtomicReference<>();
3131

32-
private final Supplier<Publisher<Void>> onClose;
32+
private final Function<String, Publisher<Void>> onClose;
3333

34-
public DefaultMcpTransportSession(Supplier<Publisher<Void>> onClose) {
34+
public DefaultMcpTransportSession(Function<String, Publisher<Void>> onClose) {
3535
this.onClose = onClose;
3636
}
3737

@@ -73,7 +73,8 @@ public void close() {
7373

7474
@Override
7575
public Mono<Void> closeGracefully() {
76-
return Mono.from(this.onClose.get()).then(Mono.fromRunnable(this.openConnections::dispose));
76+
return Mono.from(this.onClose.apply(this.sessionId.get()))
77+
.then(Mono.fromRunnable(this.openConnections::dispose));
7778
}
7879

7980
}

mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ McpAsyncClient client(McpClientTransport transport, Function<McpClient.AsyncSpec
8080
McpClient.AsyncSpec builder = McpClient.async(transport)
8181
.requestTimeout(getRequestTimeout())
8282
.initializationTimeout(getInitializationTimeout())
83-
.capabilities(ClientCapabilities.builder().roots(true).build());
83+
.sampling(req -> Mono.just(new CreateMessageResult(McpSchema.Role.USER,
84+
new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN)))
85+
.capabilities(ClientCapabilities.builder().roots(true).sampling().build());
8486
builder = customizer.apply(builder);
8587
client.set(builder.build());
8688
}).doesNotThrowAnyException();
@@ -192,6 +194,22 @@ void testCallTool() {
192194
});
193195
}
194196

197+
@Test
198+
void testSampling() {
199+
withClient(createMcpTransport(), mcpAsyncClient -> {
200+
CallToolRequest callToolRequest = new CallToolRequest("sampleLLM",
201+
Map.of("prompt", "Hello MCP Spring AI!"));
202+
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
203+
.consumeNextWith(callToolResult -> {
204+
assertThat(callToolResult).isNotNull().satisfies(result -> {
205+
assertThat(result.content()).isNotNull();
206+
assertThat(result.isError()).isNull();
207+
});
208+
})
209+
.verifyComplete();
210+
});
211+
}
212+
195213
@Test
196214
void testCallToolWithInvalidTool() {
197215
withClient(createMcpTransport(), mcpAsyncClient -> {

0 commit comments

Comments
 (0)