Skip to content

Commit f3b0774

Browse files
LucaButBoringtzolov
andcommitted
Add sampling functionality tests to MCP client test suites (#255)
- Add testSampling() methods to AbstractMcpAsyncClientTests and AbstractMcpSyncClientTests Signed-off-by: Christian Tzolov <[email protected]> Co-authored-by: Christian Tzolov <[email protected]>
1 parent 8e9857f commit f3b0774

File tree

5 files changed

+194
-37
lines changed

5 files changed

+194
-37
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package io.modelcontextprotocol.client;
22

3-
import com.fasterxml.jackson.databind.ObjectMapper;
43
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
54
import io.modelcontextprotocol.spec.McpClientTransport;
65
import org.junit.jupiter.api.Timeout;
7-
import org.springframework.web.reactive.function.client.WebClient;
86
import org.testcontainers.containers.GenericContainer;
97
import org.testcontainers.containers.wait.strategy.Wait;
108

9+
import org.springframework.web.reactive.function.client.WebClient;
10+
1111
@Timeout(15)
1212
public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests {
1313

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99
import java.util.Objects;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -40,6 +41,7 @@
4041
import static org.assertj.core.api.Assertions.assertThat;
4142
import static org.assertj.core.api.Assertions.assertThatCode;
4243
import static org.assertj.core.api.Assertions.assertThatThrownBy;
44+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4345

4446
/**
4547
* Test suite for the {@link McpAsyncClient} that can be used with different
@@ -193,22 +195,6 @@ void testCallTool() {
193195
});
194196
}
195197

196-
@Test
197-
void testSampling() {
198-
withClient(createMcpTransport(), mcpAsyncClient -> {
199-
CallToolRequest callToolRequest = new CallToolRequest("sampleLLM",
200-
Map.of("prompt", "Hello MCP Spring AI!"));
201-
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
202-
.consumeNextWith(callToolResult -> {
203-
assertThat(callToolResult).isNotNull().satisfies(result -> {
204-
assertThat(result.content()).isNotNull();
205-
assertThat(result.isError()).isNull();
206-
});
207-
})
208-
.verifyComplete();
209-
});
210-
}
211-
212198
@Test
213199
void testCallToolWithInvalidTool() {
214200
withClient(createMcpTransport(), mcpAsyncClient -> {
@@ -525,4 +511,52 @@ void testLoggingWithNullNotification() {
525511
});
526512
}
527513

514+
@Test
515+
void testSampling() {
516+
McpClientTransport transport = createMcpTransport();
517+
518+
final String message = "Hello, world!";
519+
final String response = "Goodbye, world!";
520+
final int maxTokens = 100;
521+
522+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
523+
AtomicReference<String> receivedMessage = new AtomicReference<>();
524+
AtomicInteger receivedMaxTokens = new AtomicInteger();
525+
526+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
527+
.sampling(request -> {
528+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
529+
request.messages().get(0).content());
530+
receivedPrompt.set(request.systemPrompt());
531+
receivedMessage.set(messageText.text());
532+
receivedMaxTokens.set(request.maxTokens());
533+
534+
return Mono
535+
.just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
536+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN));
537+
}), client -> {
538+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
539+
540+
StepVerifier.create(client.callTool(
541+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))))
542+
.consumeNextWith(result -> {
543+
// Verify tool response to ensure our sampling response was passed
544+
// through
545+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
546+
assertThat(result.content()).allSatisfy(content -> {
547+
if (!(content instanceof McpSchema.TextContent text))
548+
return;
549+
550+
assertThat(text.text()).endsWith(response); // Prefixed
551+
});
552+
553+
// Verify sampling request parameters received in our callback
554+
assertThat(receivedPrompt.get()).isNotEmpty();
555+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
556+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
557+
})
558+
.verifyComplete();
559+
});
560+
}
561+
528562
}

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -31,13 +32,12 @@
3132
import org.junit.jupiter.api.BeforeEach;
3233
import org.junit.jupiter.api.Test;
3334
import reactor.core.publisher.Mono;
34-
import reactor.core.scheduler.Scheduler;
35-
import reactor.core.scheduler.Schedulers;
3635
import reactor.test.StepVerifier;
3736

3837
import static org.assertj.core.api.Assertions.assertThat;
3938
import static org.assertj.core.api.Assertions.assertThatCode;
4039
import static org.assertj.core.api.Assertions.assertThatThrownBy;
40+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4141

4242
/**
4343
* Unit tests for MCP Client Session functionality.
@@ -438,4 +438,48 @@ void testLoggingWithNullNotification() {
438438
.hasMessageContaining("Logging level must not be null"));
439439
}
440440

441+
@Test
442+
void testSampling() {
443+
McpClientTransport transport = createMcpTransport();
444+
445+
final String message = "Hello, world!";
446+
final String response = "Goodbye, world!";
447+
final int maxTokens = 100;
448+
449+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
450+
AtomicReference<String> receivedMessage = new AtomicReference<>();
451+
AtomicInteger receivedMaxTokens = new AtomicInteger();
452+
453+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
454+
.sampling(request -> {
455+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
456+
request.messages().get(0).content());
457+
receivedPrompt.set(request.systemPrompt());
458+
receivedMessage.set(messageText.text());
459+
receivedMaxTokens.set(request.maxTokens());
460+
461+
return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
462+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN);
463+
}), client -> {
464+
client.initialize();
465+
466+
McpSchema.CallToolResult result = client.callTool(
467+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)));
468+
469+
// Verify tool response to ensure our sampling response was passed through
470+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
471+
assertThat(result.content()).allSatisfy(content -> {
472+
if (!(content instanceof McpSchema.TextContent text))
473+
return;
474+
475+
assertThat(text.text()).endsWith(response); // Prefixed
476+
});
477+
478+
// Verify sampling request parameters received in our callback
479+
assertThat(receivedPrompt.get()).isNotEmpty();
480+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
481+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
482+
});
483+
}
484+
441485
}

mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99
import java.util.Objects;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -40,6 +41,7 @@
4041
import static org.assertj.core.api.Assertions.assertThat;
4142
import static org.assertj.core.api.Assertions.assertThatCode;
4243
import static org.assertj.core.api.Assertions.assertThatThrownBy;
44+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4345

4446
/**
4547
* Test suite for the {@link McpAsyncClient} that can be used with different
@@ -194,22 +196,6 @@ void testCallTool() {
194196
});
195197
}
196198

197-
@Test
198-
void testSampling() {
199-
withClient(createMcpTransport(), mcpAsyncClient -> {
200-
CallToolRequest callToolRequest = new CallToolRequest("sampleLLM",
201-
Map.of("prompt", "Hello MCP Spring AI!"));
202-
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
203-
.consumeNextWith(callToolResult -> {
204-
assertThat(callToolResult).isNotNull().satisfies(result -> {
205-
assertThat(result.content()).isNotNull();
206-
assertThat(result.isError()).isNull();
207-
});
208-
})
209-
.verifyComplete();
210-
});
211-
}
212-
213199
@Test
214200
void testCallToolWithInvalidTool() {
215201
withClient(createMcpTransport(), mcpAsyncClient -> {
@@ -486,7 +472,6 @@ void testInitializeWithAllCapabilities() {
486472
// ---------------------------------------
487473
// Logging Tests
488474
// ---------------------------------------
489-
490475
@Test
491476
void testLoggingLevelsWithoutInitialization() {
492477
verifyNotificationSucceedsWithImplicitInitialization(
@@ -526,4 +511,52 @@ void testLoggingWithNullNotification() {
526511
});
527512
}
528513

514+
@Test
515+
void testSampling() {
516+
McpClientTransport transport = createMcpTransport();
517+
518+
final String message = "Hello, world!";
519+
final String response = "Goodbye, world!";
520+
final int maxTokens = 100;
521+
522+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
523+
AtomicReference<String> receivedMessage = new AtomicReference<>();
524+
AtomicInteger receivedMaxTokens = new AtomicInteger();
525+
526+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
527+
.sampling(request -> {
528+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
529+
request.messages().get(0).content());
530+
receivedPrompt.set(request.systemPrompt());
531+
receivedMessage.set(messageText.text());
532+
receivedMaxTokens.set(request.maxTokens());
533+
534+
return Mono
535+
.just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
536+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN));
537+
}), client -> {
538+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
539+
540+
StepVerifier.create(client.callTool(
541+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))))
542+
.consumeNextWith(result -> {
543+
// Verify tool response to ensure our sampling response was passed
544+
// through
545+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
546+
assertThat(result.content()).allSatisfy(content -> {
547+
if (!(content instanceof McpSchema.TextContent text))
548+
return;
549+
550+
assertThat(text.text()).endsWith(response); // Prefixed
551+
});
552+
553+
// Verify sampling request parameters received in our callback
554+
assertThat(receivedPrompt.get()).isNotEmpty();
555+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
556+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
557+
})
558+
.verifyComplete();
559+
});
560+
}
561+
529562
}

mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -38,6 +39,7 @@
3839
import static org.assertj.core.api.Assertions.assertThat;
3940
import static org.assertj.core.api.Assertions.assertThatCode;
4041
import static org.assertj.core.api.Assertions.assertThatThrownBy;
42+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4143

4244
/**
4345
* Unit tests for MCP Client Session functionality.
@@ -439,4 +441,48 @@ void testLoggingWithNullNotification() {
439441
.hasMessageContaining("Logging level must not be null"));
440442
}
441443

444+
@Test
445+
void testSampling() {
446+
McpClientTransport transport = createMcpTransport();
447+
448+
final String message = "Hello, world!";
449+
final String response = "Goodbye, world!";
450+
final int maxTokens = 100;
451+
452+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
453+
AtomicReference<String> receivedMessage = new AtomicReference<>();
454+
AtomicInteger receivedMaxTokens = new AtomicInteger();
455+
456+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
457+
.sampling(request -> {
458+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
459+
request.messages().get(0).content());
460+
receivedPrompt.set(request.systemPrompt());
461+
receivedMessage.set(messageText.text());
462+
receivedMaxTokens.set(request.maxTokens());
463+
464+
return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
465+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN);
466+
}), client -> {
467+
client.initialize();
468+
469+
McpSchema.CallToolResult result = client.callTool(
470+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)));
471+
472+
// Verify tool response to ensure our sampling response was passed through
473+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
474+
assertThat(result.content()).allSatisfy(content -> {
475+
if (!(content instanceof McpSchema.TextContent text))
476+
return;
477+
478+
assertThat(text.text()).endsWith(response); // Prefixed
479+
});
480+
481+
// Verify sampling request parameters received in our callback
482+
assertThat(receivedPrompt.get()).isNotEmpty();
483+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
484+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
485+
});
486+
}
487+
442488
}

0 commit comments

Comments
 (0)