Skip to content

Commit 72e9d4c

Browse files
LucaButBoringtzolov
andcommitted
Add sampling functionality tests to MCP client test suites (#255)
- Add testSampling() methods to AbstractMcpAsyncClientTests and AbstractMcpSyncClientTests Signed-off-by: Christian Tzolov <[email protected]> Co-authored-by: Christian Tzolov <[email protected]>
1 parent 9a63538 commit 72e9d4c

File tree

4 files changed

+192
-3
lines changed

4 files changed

+192
-3
lines changed

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99
import java.util.Objects;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -38,6 +39,7 @@
3839
import static org.assertj.core.api.Assertions.assertThat;
3940
import static org.assertj.core.api.Assertions.assertThatCode;
4041
import static org.assertj.core.api.Assertions.assertThatThrownBy;
42+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4143

4244
/**
4345
* Test suite for the {@link McpAsyncClient} that can be used with different
@@ -487,4 +489,52 @@ void testLoggingWithNullNotification() {
487489
});
488490
}
489491

492+
@Test
493+
void testSampling() {
494+
McpClientTransport transport = createMcpTransport();
495+
496+
final String message = "Hello, world!";
497+
final String response = "Goodbye, world!";
498+
final int maxTokens = 100;
499+
500+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
501+
AtomicReference<String> receivedMessage = new AtomicReference<>();
502+
AtomicInteger receivedMaxTokens = new AtomicInteger();
503+
504+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
505+
.sampling(request -> {
506+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
507+
request.messages().get(0).content());
508+
receivedPrompt.set(request.systemPrompt());
509+
receivedMessage.set(messageText.text());
510+
receivedMaxTokens.set(request.maxTokens());
511+
512+
return Mono
513+
.just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
514+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN));
515+
}), client -> {
516+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
517+
518+
StepVerifier.create(client.callTool(
519+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))))
520+
.consumeNextWith(result -> {
521+
// Verify tool response to ensure our sampling response was passed
522+
// through
523+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
524+
assertThat(result.content()).allSatisfy(content -> {
525+
if (!(content instanceof McpSchema.TextContent text))
526+
return;
527+
528+
assertThat(text.text()).endsWith(response); // Prefixed
529+
});
530+
531+
// Verify sampling request parameters received in our callback
532+
assertThat(receivedPrompt.get()).isNotEmpty();
533+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
534+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
535+
})
536+
.verifyComplete();
537+
});
538+
}
539+
490540
}

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -31,13 +32,12 @@
3132
import org.junit.jupiter.api.BeforeEach;
3233
import org.junit.jupiter.api.Test;
3334
import reactor.core.publisher.Mono;
34-
import reactor.core.scheduler.Scheduler;
35-
import reactor.core.scheduler.Schedulers;
3635
import reactor.test.StepVerifier;
3736

3837
import static org.assertj.core.api.Assertions.assertThat;
3938
import static org.assertj.core.api.Assertions.assertThatCode;
4039
import static org.assertj.core.api.Assertions.assertThatThrownBy;
40+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4141

4242
/**
4343
* Unit tests for MCP Client Session functionality.
@@ -438,4 +438,48 @@ void testLoggingWithNullNotification() {
438438
.hasMessageContaining("Logging level must not be null"));
439439
}
440440

441+
@Test
442+
void testSampling() {
443+
McpClientTransport transport = createMcpTransport();
444+
445+
final String message = "Hello, world!";
446+
final String response = "Goodbye, world!";
447+
final int maxTokens = 100;
448+
449+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
450+
AtomicReference<String> receivedMessage = new AtomicReference<>();
451+
AtomicInteger receivedMaxTokens = new AtomicInteger();
452+
453+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
454+
.sampling(request -> {
455+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
456+
request.messages().get(0).content());
457+
receivedPrompt.set(request.systemPrompt());
458+
receivedMessage.set(messageText.text());
459+
receivedMaxTokens.set(request.maxTokens());
460+
461+
return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
462+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN);
463+
}), client -> {
464+
client.initialize();
465+
466+
McpSchema.CallToolResult result = client.callTool(
467+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)));
468+
469+
// Verify tool response to ensure our sampling response was passed through
470+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
471+
assertThat(result.content()).allSatisfy(content -> {
472+
if (!(content instanceof McpSchema.TextContent text))
473+
return;
474+
475+
assertThat(text.text()).endsWith(response); // Prefixed
476+
});
477+
478+
// Verify sampling request parameters received in our callback
479+
assertThat(receivedPrompt.get()).isNotEmpty();
480+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
481+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
482+
});
483+
}
484+
441485
}

mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99
import java.util.Objects;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -40,6 +41,7 @@
4041
import static org.assertj.core.api.Assertions.assertThat;
4142
import static org.assertj.core.api.Assertions.assertThatCode;
4243
import static org.assertj.core.api.Assertions.assertThatThrownBy;
44+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4345

4446
/**
4547
* Test suite for the {@link McpAsyncClient} that can be used with different
@@ -468,7 +470,6 @@ void testInitializeWithAllCapabilities() {
468470
// ---------------------------------------
469471
// Logging Tests
470472
// ---------------------------------------
471-
472473
@Test
473474
void testLoggingLevelsWithoutInitialization() {
474475
verifyNotificationSucceedsWithImplicitInitialization(
@@ -508,4 +509,52 @@ void testLoggingWithNullNotification() {
508509
});
509510
}
510511

512+
@Test
513+
void testSampling() {
514+
McpClientTransport transport = createMcpTransport();
515+
516+
final String message = "Hello, world!";
517+
final String response = "Goodbye, world!";
518+
final int maxTokens = 100;
519+
520+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
521+
AtomicReference<String> receivedMessage = new AtomicReference<>();
522+
AtomicInteger receivedMaxTokens = new AtomicInteger();
523+
524+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
525+
.sampling(request -> {
526+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
527+
request.messages().get(0).content());
528+
receivedPrompt.set(request.systemPrompt());
529+
receivedMessage.set(messageText.text());
530+
receivedMaxTokens.set(request.maxTokens());
531+
532+
return Mono
533+
.just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
534+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN));
535+
}), client -> {
536+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
537+
538+
StepVerifier.create(client.callTool(
539+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))))
540+
.consumeNextWith(result -> {
541+
// Verify tool response to ensure our sampling response was passed
542+
// through
543+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
544+
assertThat(result.content()).allSatisfy(content -> {
545+
if (!(content instanceof McpSchema.TextContent text))
546+
return;
547+
548+
assertThat(text.text()).endsWith(response); // Prefixed
549+
});
550+
551+
// Verify sampling request parameters received in our callback
552+
assertThat(receivedPrompt.get()).isNotEmpty();
553+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
554+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
555+
})
556+
.verifyComplete();
557+
});
558+
}
559+
511560
}

mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -38,6 +39,7 @@
3839
import static org.assertj.core.api.Assertions.assertThat;
3940
import static org.assertj.core.api.Assertions.assertThatCode;
4041
import static org.assertj.core.api.Assertions.assertThatThrownBy;
42+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4143

4244
/**
4345
* Unit tests for MCP Client Session functionality.
@@ -439,4 +441,48 @@ void testLoggingWithNullNotification() {
439441
.hasMessageContaining("Logging level must not be null"));
440442
}
441443

444+
@Test
445+
void testSampling() {
446+
McpClientTransport transport = createMcpTransport();
447+
448+
final String message = "Hello, world!";
449+
final String response = "Goodbye, world!";
450+
final int maxTokens = 100;
451+
452+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
453+
AtomicReference<String> receivedMessage = new AtomicReference<>();
454+
AtomicInteger receivedMaxTokens = new AtomicInteger();
455+
456+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
457+
.sampling(request -> {
458+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
459+
request.messages().get(0).content());
460+
receivedPrompt.set(request.systemPrompt());
461+
receivedMessage.set(messageText.text());
462+
receivedMaxTokens.set(request.maxTokens());
463+
464+
return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
465+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN);
466+
}), client -> {
467+
client.initialize();
468+
469+
McpSchema.CallToolResult result = client.callTool(
470+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)));
471+
472+
// Verify tool response to ensure our sampling response was passed through
473+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
474+
assertThat(result.content()).allSatisfy(content -> {
475+
if (!(content instanceof McpSchema.TextContent text))
476+
return;
477+
478+
assertThat(text.text()).endsWith(response); // Prefixed
479+
});
480+
481+
// Verify sampling request parameters received in our callback
482+
assertThat(receivedPrompt.get()).isNotEmpty();
483+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
484+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
485+
});
486+
}
487+
442488
}

0 commit comments

Comments
 (0)