Skip to content

[FEAT] added support for AuthInfo in extra for StreamableHTTPServerTransport #399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 157 additions & 1 deletion src/server/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from "./
import { McpServer } from "./mcp.js";
import { CallToolResult, JSONRPCMessage } from "../types.js";
import { z } from "zod";
import { AuthInfo } from "./auth/types.js";

/**
* Test server configuration for StreamableHTTPServerTransport tests
Expand Down Expand Up @@ -70,6 +71,61 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator:
return { server, transport, mcpServer, baseUrl };
}

/**
* Helper to create and start authenticated test HTTP server with MCP setup
*/
async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: (() => randomUUID()) }): Promise<{
server: Server;
transport: StreamableHTTPServerTransport;
mcpServer: McpServer;
baseUrl: URL;
}> {
const mcpServer = new McpServer(
{ name: "test-server", version: "1.0.0" },
{ capabilities: { logging: {} } }
);

mcpServer.tool(
"profile",
"A user profile data tool",
{ active: z.boolean().describe("Profile status") },
async ({ active }, { authInfo }): Promise<CallToolResult> => {
return { content: [{ type: "text", text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] };
}
);

const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: config.sessionIdGenerator,
enableJsonResponse: config.enableJsonResponse ?? false,
eventStore: config.eventStore
});

await mcpServer.connect(transport);

const server = createServer(async (req: IncomingMessage & { auth?: AuthInfo }, res) => {
try {
if (config.customRequestHandler) {
await config.customRequestHandler(req, res);
} else {
req.auth = { token: req.headers["authorization"]?.split(" ")[1] } as AuthInfo;
await transport.handleRequest(req, res);
}
} catch (error) {
console.error("Error handling request:", error);
if (!res.headersSent) res.writeHead(500).end();
}
});

const baseUrl = await new Promise<URL>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
resolve(new URL(`http://127.0.0.1:${addr.port}`));
});
});

return { server, transport, mcpServer, baseUrl };
}

/**
* Helper to stop test server
*/
Expand Down Expand Up @@ -120,10 +176,11 @@ async function readSSEEvent(response: Response): Promise<string> {
/**
* Helper to send JSON-RPC request
*/
async function sendPostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string): Promise<Response> {
async function sendPostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record<string, string>): Promise<Response> {
const headers: Record<string, string> = {
"Content-Type": "application/json",
Accept: "application/json, text/event-stream",
...extraHeaders
};

if (sessionId) {
Expand Down Expand Up @@ -673,6 +730,105 @@ describe("StreamableHTTPServerTransport", () => {
});
});

describe("StreamableHTTPServerTransport with AuthInfo", () => {
let server: Server;
let transport: StreamableHTTPServerTransport;
let baseUrl: URL;
let sessionId: string;

beforeEach(async () => {
const result = await createTestAuthServer();
server = result.server;
transport = result.transport;
baseUrl = result.baseUrl;
});

afterEach(async () => {
await stopTestServer({ server, transport });
});

async function initializeServer(): Promise<string> {
const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize);

expect(response.status).toBe(200);
const newSessionId = response.headers.get("mcp-session-id");
expect(newSessionId).toBeDefined();
return newSessionId as string;
}

it("should call a tool with authInfo", async () => {
sessionId = await initializeServer();

const toolCallMessage: JSONRPCMessage = {
jsonrpc: "2.0",
method: "tools/call",
params: {
name: "profile",
arguments: {active: true},
},
id: "call-1",
};

const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, {'authorization': 'Bearer test-token'});
expect(response.status).toBe(200);

const text = await readSSEEvent(response);
const eventLines = text.split("\n");
const dataLine = eventLines.find(line => line.startsWith("data:"));
expect(dataLine).toBeDefined();

const eventData = JSON.parse(dataLine!.substring(5));
expect(eventData).toMatchObject({
jsonrpc: "2.0",
result: {
content: [
{
type: "text",
text: "Active profile from token: test-token!",
},
],
},
id: "call-1",
});
});

it("should calls tool without authInfo when it is optional", async () => {
sessionId = await initializeServer();

const toolCallMessage: JSONRPCMessage = {
jsonrpc: "2.0",
method: "tools/call",
params: {
name: "profile",
arguments: {active: false},
},
id: "call-1",
};

const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId);
expect(response.status).toBe(200);

const text = await readSSEEvent(response);
const eventLines = text.split("\n");
const dataLine = eventLines.find(line => line.startsWith("data:"));
expect(dataLine).toBeDefined();

const eventData = JSON.parse(dataLine!.substring(5));
expect(eventData).toMatchObject({
jsonrpc: "2.0",
result: {
content: [
{
type: "text",
text: "Inactive profile from token: undefined!",
},
],
},
id: "call-1",
});
});
});

// Test JSON Response Mode
describe("StreamableHTTPServerTransport with JSON Response Mode", () => {
let server: Server;
Expand Down
13 changes: 8 additions & 5 deletions src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCRespons
import getRawBody from "raw-body";
import contentType from "content-type";
import { randomUUID } from "node:crypto";
import { AuthInfo } from "./auth/types.js";

const MAXIMUM_MESSAGE_SIZE = "4mb";

Expand Down Expand Up @@ -112,7 +113,7 @@ export class StreamableHTTPServerTransport implements Transport {
sessionId?: string | undefined;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;

constructor(options: StreamableHTTPServerTransportOptions) {
this.sessionIdGenerator = options.sessionIdGenerator;
Expand All @@ -135,7 +136,7 @@ export class StreamableHTTPServerTransport implements Transport {
/**
* Handles an incoming HTTP request, whether GET or POST
*/
async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise<void> {
async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise<void> {
if (req.method === "POST") {
await this.handlePostRequest(req, res, parsedBody);
} else if (req.method === "GET") {
Expand Down Expand Up @@ -286,7 +287,7 @@ export class StreamableHTTPServerTransport implements Transport {
/**
* Handles POST requests containing JSON-RPC messages
*/
private async handlePostRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise<void> {
private async handlePostRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise<void> {
try {
// Validate the Accept header
const acceptHeader = req.headers.accept;
Expand Down Expand Up @@ -316,6 +317,8 @@ export class StreamableHTTPServerTransport implements Transport {
return;
}

const authInfo: AuthInfo | undefined = req.auth;

let rawMessage;
if (parsedBody !== undefined) {
rawMessage = parsedBody;
Expand Down Expand Up @@ -392,7 +395,7 @@ export class StreamableHTTPServerTransport implements Transport {

// handle each message
for (const message of messages) {
this.onmessage?.(message);
this.onmessage?.(message, { authInfo });
}
} else if (hasRequests) {
// The default behavior is to use SSE streaming
Expand Down Expand Up @@ -427,7 +430,7 @@ export class StreamableHTTPServerTransport implements Transport {

// handle each message
for (const message of messages) {
this.onmessage?.(message);
this.onmessage?.(message, { authInfo });
}
// The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses
// This will be handled by the send() method when responses are ready
Expand Down