From 2deffb6c194154a2fb99dfcc3e53443debfda3a6 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 23 Apr 2025 22:10:12 -0700 Subject: [PATCH] test(server): add more tests forSSEServerTransport class --- package-lock.json | 4 +- src/server/sse.test.ts | 153 ++++++++++++++++++++++++++++++++++++++++- src/server/sse.ts | 2 +- 3 files changed, 155 insertions(+), 4 deletions(-) diff --git a/package-lock.json b/package-lock.json index 1165b751..3c6e2d90 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.9.0", + "version": "1.10.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.9.0", + "version": "1.10.2", "license": "MIT", "dependencies": { "content-type": "^1.0.5", diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 2fd2c042..11705fe4 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -7,6 +7,7 @@ const createMockResponse = () => { writeHead: jest.fn(), write: jest.fn().mockReturnValue(true), on: jest.fn(), + end: jest.fn(), }; res.writeHead.mockReturnThis(); res.on.mockReturnThis(); @@ -14,6 +15,36 @@ const createMockResponse = () => { return res as unknown as http.ServerResponse; }; +const createMockRequest = ({ headers = {}, body }: { headers?: Record, body?: string } = {}) => { + const mockReq = { + headers, + body: body ? body : undefined, + auth: { + token: 'test-token', + }, + on: jest.fn().mockImplementation((event, listener) => { + const mockListener = listener as unknown as (...args: unknown[]) => void; + if (event === 'data') { + mockListener(Buffer.from(body || '') as unknown as Error); + } + if (event === 'error') { + mockListener(new Error('test')); + } + if (event === 'end') { + mockListener(); + } + if (event === 'close') { + setTimeout(listener, 100); + } + return mockReq; + }), + listeners: jest.fn(), + removeListener: jest.fn(), + } as unknown as http.IncomingMessage; + + return mockReq; +}; + describe('SSEServerTransport', () => { describe('start method', () => { it('should correctly append sessionId to a simple relative endpoint', async () => { @@ -106,4 +137,124 @@ describe('SSEServerTransport', () => { ); }); }); -}); + + describe('handlePostMessage method', () => { + it('should return 500 if server has not started', async () => { + const mockReq = createMockRequest(); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + + const error = 'SSE connection not established'; + await expect(transport.handlePostMessage(mockReq, mockRes)) + .rejects.toThrow(error); + expect(mockRes.writeHead).toHaveBeenCalledWith(500); + expect(mockRes.end).toHaveBeenCalledWith(error); + }); + + it('should return 400 if content-type is not application/json', async () => { + const mockReq = createMockRequest({ headers: { 'content-type': 'text/plain' } }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onerror = jest.fn(); + const error = 'Unsupported content-type: text/plain'; + await expect(transport.handlePostMessage(mockReq, mockRes)) + .resolves.toBe(undefined); + expect(mockRes.writeHead).toHaveBeenCalledWith(400); + expect(mockRes.end).toHaveBeenCalledWith(expect.stringContaining(error)); + expect(transport.onerror).toHaveBeenCalledWith(new Error(error)); + }); + + it('should return 400 if message has not a valid schema', async () => { + const invalidMessage = JSON.stringify({ + // missing jsonrpc field + method: 'call', + params: [1, 2, 3], + id: 1, + }) + const mockReq = createMockRequest({ + headers: { 'content-type': 'application/json' }, + body: invalidMessage, + }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onmessage = jest.fn(); + await transport.handlePostMessage(mockReq, mockRes); + expect(mockRes.writeHead).toHaveBeenCalledWith(400); + expect(transport.onmessage).not.toHaveBeenCalled(); + expect(mockRes.end).toHaveBeenCalledWith(`Invalid message: ${invalidMessage}`); + }); + + it('should return 202 if message has a valid schema', async () => { + const validMessage = JSON.stringify({ + jsonrpc: "2.0", + method: 'call', + params: { + a: 1, + b: 2, + c: 3, + }, + id: 1 + }) + const mockReq = createMockRequest({ + headers: { 'content-type': 'application/json' }, + body: validMessage, + }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onmessage = jest.fn(); + await transport.handlePostMessage(mockReq, mockRes); + expect(mockRes.writeHead).toHaveBeenCalledWith(202); + expect(mockRes.end).toHaveBeenCalledWith('Accepted'); + expect(transport.onmessage).toHaveBeenCalledWith({ + jsonrpc: "2.0", + method: 'call', + params: { + a: 1, + b: 2, + c: 3, + }, + id: 1 + }, { + authInfo: { + token: 'test-token', + } + }); + }); + }); + + describe('close method', () => { + it('should call onclose', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + transport.onclose = jest.fn(); + await transport.close(); + expect(transport.onclose).toHaveBeenCalled(); + }); + }); + + describe('send method', () => { + it('should call onsend', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + expect.stringContaining('event: endpoint')); + expect(mockRes.write).toHaveBeenCalledWith( + expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`)); + }); + }); +}); \ No newline at end of file diff --git a/src/server/sse.ts b/src/server/sse.ts index 03f6fefc..164780ef 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -92,7 +92,7 @@ export class SSEServerTransport implements Transport { try { const ct = contentType.parse(req.headers["content-type"] ?? ""); if (ct.type !== "application/json") { - throw new Error(`Unsupported content-type: ${ct}`); + throw new Error(`Unsupported content-type: ${ct.type}`); } body = parsedBody ?? await getRawBody(req, {