diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index a0f2e0bb..b961f6c4 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -5,6 +5,7 @@ import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from "./ import { McpServer } from "./mcp.js"; import { CallToolResult, JSONRPCMessage } from "../types.js"; import { z } from "zod"; +import { AuthInfo } from "./auth/types.js"; /** * Test server configuration for StreamableHTTPServerTransport tests @@ -70,6 +71,61 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator: return { server, transport, mcpServer, baseUrl }; } +/** + * Helper to create and start authenticated test HTTP server with MCP setup + */ +async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: (() => randomUUID()) }): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; +}> { + const mcpServer = new McpServer( + { name: "test-server", version: "1.0.0" }, + { capabilities: { logging: {} } } + ); + + mcpServer.tool( + "profile", + "A user profile data tool", + { active: z.boolean().describe("Profile status") }, + async ({ active }, { authInfo }): Promise => { + return { content: [{ type: "text", text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] }; + } + ); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + enableJsonResponse: config.enableJsonResponse ?? false, + eventStore: config.eventStore + }); + + await mcpServer.connect(transport); + + const server = createServer(async (req: IncomingMessage & { auth?: AuthInfo }, res) => { + try { + if (config.customRequestHandler) { + await config.customRequestHandler(req, res); + } else { + req.auth = { token: req.headers["authorization"]?.split(" ")[1] } as AuthInfo; + await transport.handleRequest(req, res); + } + } catch (error) { + console.error("Error handling request:", error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, transport, mcpServer, baseUrl }; +} + /** * Helper to stop test server */ @@ -120,10 +176,11 @@ async function readSSEEvent(response: Response): Promise { /** * Helper to send JSON-RPC request */ -async function sendPostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string): Promise { +async function sendPostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record): Promise { const headers: Record = { "Content-Type": "application/json", Accept: "application/json, text/event-stream", + ...extraHeaders }; if (sessionId) { @@ -673,6 +730,105 @@ describe("StreamableHTTPServerTransport", () => { }); }); +describe("StreamableHTTPServerTransport with AuthInfo", () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + + beforeEach(async () => { + const result = await createTestAuthServer(); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + async function initializeServer(): Promise { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(response.status).toBe(200); + const newSessionId = response.headers.get("mcp-session-id"); + expect(newSessionId).toBeDefined(); + return newSessionId as string; + } + + it("should call a tool with authInfo", async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "tools/call", + params: { + name: "profile", + arguments: {active: true}, + }, + id: "call-1", + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, {'authorization': 'Bearer test-token'}); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split("\n"); + const dataLine = eventLines.find(line => line.startsWith("data:")); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: "2.0", + result: { + content: [ + { + type: "text", + text: "Active profile from token: test-token!", + }, + ], + }, + id: "call-1", + }); + }); + + it("should calls tool without authInfo when it is optional", async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "tools/call", + params: { + name: "profile", + arguments: {active: false}, + }, + id: "call-1", + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split("\n"); + const dataLine = eventLines.find(line => line.startsWith("data:")); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: "2.0", + result: { + content: [ + { + type: "text", + text: "Inactive profile from token: undefined!", + }, + ], + }, + id: "call-1", + }); + }); +}); + // Test JSON Response Mode describe("StreamableHTTPServerTransport with JSON Response Mode", () => { let server: Server; diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index c9051073..65180566 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -4,6 +4,7 @@ import { isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCRespons import getRawBody from "raw-body"; import contentType from "content-type"; import { randomUUID } from "node:crypto"; +import { AuthInfo } from "./auth/types.js"; const MAXIMUM_MESSAGE_SIZE = "4mb"; @@ -112,7 +113,7 @@ export class StreamableHTTPServerTransport implements Transport { sessionId?: string | undefined; onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; constructor(options: StreamableHTTPServerTransportOptions) { this.sessionIdGenerator = options.sessionIdGenerator; @@ -135,7 +136,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles an incoming HTTP request, whether GET or POST */ - async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { + async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { if (req.method === "POST") { await this.handlePostRequest(req, res, parsedBody); } else if (req.method === "GET") { @@ -286,7 +287,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles POST requests containing JSON-RPC messages */ - private async handlePostRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { + private async handlePostRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { try { // Validate the Accept header const acceptHeader = req.headers.accept; @@ -316,6 +317,8 @@ export class StreamableHTTPServerTransport implements Transport { return; } + const authInfo: AuthInfo | undefined = req.auth; + let rawMessage; if (parsedBody !== undefined) { rawMessage = parsedBody; @@ -392,7 +395,7 @@ export class StreamableHTTPServerTransport implements Transport { // handle each message for (const message of messages) { - this.onmessage?.(message); + this.onmessage?.(message, { authInfo }); } } else if (hasRequests) { // The default behavior is to use SSE streaming @@ -427,7 +430,7 @@ export class StreamableHTTPServerTransport implements Transport { // handle each message for (const message of messages) { - this.onmessage?.(message); + this.onmessage?.(message, { authInfo }); } // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses // This will be handled by the send() method when responses are ready