|
8 | 8 | import java.util.List;
|
9 | 9 | import java.util.Map;
|
10 | 10 | import java.util.concurrent.ConcurrentHashMap;
|
| 11 | +import java.util.concurrent.TimeUnit; |
11 | 12 | import java.util.concurrent.atomic.AtomicReference;
|
12 | 13 | import java.util.function.Function;
|
13 | 14 | import java.util.stream.Collectors;
|
|
48 | 49 | import org.springframework.web.reactive.function.server.RouterFunctions;
|
49 | 50 |
|
50 | 51 | import static org.assertj.core.api.Assertions.assertThat;
|
| 52 | +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; |
51 | 53 | import static org.awaitility.Awaitility.await;
|
52 | 54 | import static org.mockito.Mockito.mock;
|
53 | 55 |
|
@@ -196,6 +198,153 @@ void testCreateMessageSuccess(String clientType) {
|
196 | 198 | mcpServer.close();
|
197 | 199 | }
|
198 | 200 |
|
| 201 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 202 | + @ValueSource(strings = { "httpclient", "webflux" }) |
| 203 | + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { |
| 204 | + |
| 205 | + // Client |
| 206 | + var clientBuilder = clientBuilders.get(clientType); |
| 207 | + |
| 208 | + Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> { |
| 209 | + assertThat(request.messages()).hasSize(1); |
| 210 | + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); |
| 211 | + try { |
| 212 | + TimeUnit.SECONDS.sleep(2); |
| 213 | + } |
| 214 | + catch (InterruptedException e) { |
| 215 | + throw new RuntimeException(e); |
| 216 | + } |
| 217 | + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", |
| 218 | + CreateMessageResult.StopReason.STOP_SEQUENCE); |
| 219 | + }; |
| 220 | + |
| 221 | + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) |
| 222 | + .capabilities(ClientCapabilities.builder().sampling().build()) |
| 223 | + .sampling(samplingHandler) |
| 224 | + .build(); |
| 225 | + |
| 226 | + // Server |
| 227 | + |
| 228 | + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), |
| 229 | + null); |
| 230 | + |
| 231 | + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( |
| 232 | + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { |
| 233 | + |
| 234 | + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() |
| 235 | + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, |
| 236 | + new McpSchema.TextContent("Test message")))) |
| 237 | + .modelPreferences(ModelPreferences.builder() |
| 238 | + .hints(List.of()) |
| 239 | + .costPriority(1.0) |
| 240 | + .speedPriority(1.0) |
| 241 | + .intelligencePriority(1.0) |
| 242 | + .build()) |
| 243 | + .build(); |
| 244 | + |
| 245 | + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { |
| 246 | + assertThat(result).isNotNull(); |
| 247 | + assertThat(result.role()).isEqualTo(Role.USER); |
| 248 | + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); |
| 249 | + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); |
| 250 | + assertThat(result.model()).isEqualTo("MockModelName"); |
| 251 | + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); |
| 252 | + }).verifyComplete(); |
| 253 | + |
| 254 | + return Mono.just(callResponse); |
| 255 | + }); |
| 256 | + |
| 257 | + var mcpServer = McpServer.async(mcpServerTransportProvider) |
| 258 | + .requestTimeout(Duration.ofSeconds(4)) |
| 259 | + .serverInfo("test-server", "1.0.0") |
| 260 | + .tools(tool) |
| 261 | + .build(); |
| 262 | + |
| 263 | + InitializeResult initResult = mcpClient.initialize(); |
| 264 | + assertThat(initResult).isNotNull(); |
| 265 | + |
| 266 | + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); |
| 267 | + |
| 268 | + assertThat(response).isNotNull(); |
| 269 | + assertThat(response).isEqualTo(callResponse); |
| 270 | + |
| 271 | + mcpClient.close(); |
| 272 | + mcpServer.close(); |
| 273 | + } |
| 274 | + |
| 275 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 276 | + @ValueSource(strings = { "httpclient", "webflux" }) |
| 277 | + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { |
| 278 | + |
| 279 | + // Client |
| 280 | + var clientBuilder = clientBuilders.get(clientType); |
| 281 | + |
| 282 | + Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> { |
| 283 | + assertThat(request.messages()).hasSize(1); |
| 284 | + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); |
| 285 | + try { |
| 286 | + TimeUnit.SECONDS.sleep(3); |
| 287 | + } |
| 288 | + catch (InterruptedException e) { |
| 289 | + throw new RuntimeException(e); |
| 290 | + } |
| 291 | + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", |
| 292 | + CreateMessageResult.StopReason.STOP_SEQUENCE); |
| 293 | + }; |
| 294 | + |
| 295 | + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) |
| 296 | + .capabilities(ClientCapabilities.builder().sampling().build()) |
| 297 | + .sampling(samplingHandler) |
| 298 | + .build(); |
| 299 | + |
| 300 | + // Server |
| 301 | + |
| 302 | + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), |
| 303 | + null); |
| 304 | + |
| 305 | + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( |
| 306 | + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { |
| 307 | + |
| 308 | + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() |
| 309 | + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, |
| 310 | + new McpSchema.TextContent("Test message")))) |
| 311 | + .modelPreferences(ModelPreferences.builder() |
| 312 | + .hints(List.of()) |
| 313 | + .costPriority(1.0) |
| 314 | + .speedPriority(1.0) |
| 315 | + .intelligencePriority(1.0) |
| 316 | + .build()) |
| 317 | + .build(); |
| 318 | + |
| 319 | + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { |
| 320 | + assertThat(result).isNotNull(); |
| 321 | + assertThat(result.role()).isEqualTo(Role.USER); |
| 322 | + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); |
| 323 | + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); |
| 324 | + assertThat(result.model()).isEqualTo("MockModelName"); |
| 325 | + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); |
| 326 | + }).verifyComplete(); |
| 327 | + |
| 328 | + return Mono.just(callResponse); |
| 329 | + }); |
| 330 | + |
| 331 | + var mcpServer = McpServer.async(mcpServerTransportProvider) |
| 332 | + .requestTimeout(Duration.ofSeconds(1)) |
| 333 | + .serverInfo("test-server", "1.0.0") |
| 334 | + .tools(tool) |
| 335 | + .build(); |
| 336 | + |
| 337 | + InitializeResult initResult = mcpClient.initialize(); |
| 338 | + assertThat(initResult).isNotNull(); |
| 339 | + |
| 340 | + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { |
| 341 | + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); |
| 342 | + }).withMessageContaining("Timeout"); |
| 343 | + |
| 344 | + mcpClient.close(); |
| 345 | + mcpServer.close(); |
| 346 | + } |
| 347 | + |
199 | 348 | // ---------------------------------------
|
200 | 349 | // Roots Tests
|
201 | 350 | // ---------------------------------------
|
|
0 commit comments