diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index eaac5c71..c4cb0c5d 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -101,6 +101,7 @@ describe("ResourceTemplate", () => { const abortController = new AbortController(); const result = await template.listCallback?.({ signal: abortController.signal, + requestId: 'not-implemented', sendRequest: () => { throw new Error("Not implemented") }, sendNotification: () => { throw new Error("Not implemented") } }); @@ -646,6 +647,59 @@ describe("tool()", () => { expect(receivedSessionId).toBe("test-session-123"); }); + test("should pass requestId to tool callback via RequestHandlerExtra", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + let receivedRequestId: string | number | undefined; + mcpServer.tool("request-id-test", async (extra) => { + receivedRequestId = extra.requestId; + return { + content: [ + { + type: "text", + text: `Received request ID: ${extra.requestId}`, + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "request-id-test", + }, + }, + CallToolResultSchema, + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.content[0].text).toContain("Received request ID:"); + }); + test("should provide sendNotification within tool call", async () => { const mcpServer = new McpServer( { @@ -1702,6 +1756,59 @@ describe("resource()", () => { expect(result.completion.values).toEqual(["movies", "music"]); expect(result.completion.total).toBe(2); }); + + test("should pass requestId to resource callback via RequestHandlerExtra", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + let receivedRequestId: string | number | undefined; + mcpServer.resource("request-id-test", "test://resource", async (_uri, extra) => { + receivedRequestId = extra.requestId; + return { + contents: [ + { + uri: "test://resource", + text: `Received request ID: ${extra.requestId}`, + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/read", + params: { + uri: "test://resource", + }, + }, + ReadResourceResultSchema, + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.contents[0].text).toContain("Received request ID:"); + }); }); describe("prompt()", () => { @@ -2511,4 +2618,60 @@ describe("prompt()", () => { expect(result.completion.values).toEqual(["Alice"]); expect(result.completion.total).toBe(1); }); + + test("should pass requestId to prompt callback via RequestHandlerExtra", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + let receivedRequestId: string | number | undefined; + mcpServer.prompt("request-id-test", async (extra) => { + receivedRequestId = extra.requestId; + return { + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Received request ID: ${extra.requestId}`, + }, + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/get", + params: { + name: "request-id-test", + }, + }, + GetPromptResultSchema, + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.messages[0].content.text).toContain("Received request ID:"); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index c9ea79fd..2a6bb7f2 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -115,6 +115,12 @@ export type RequestHandlerExtra this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), authInfo: extra?.authInfo, + requestId: request.id, }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well.