Skip to content

test(server): add more tests for SSEServerTransport class #391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

153 changes: 152 additions & 1 deletion src/server/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,44 @@ const createMockResponse = () => {
writeHead: jest.fn<http.ServerResponse['writeHead']>(),
write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true),
on: jest.fn<http.ServerResponse['on']>(),
end: jest.fn<http.ServerResponse['end']>(),
};
res.writeHead.mockReturnThis();
res.on.mockReturnThis();

return res as unknown as http.ServerResponse;
};

const createMockRequest = ({ headers = {}, body }: { headers?: Record<string, string>, body?: string } = {}) => {
const mockReq = {
headers,
body: body ? body : undefined,
auth: {
token: 'test-token',
},
on: jest.fn<http.IncomingMessage['on']>().mockImplementation((event, listener) => {
const mockListener = listener as unknown as (...args: unknown[]) => void;
if (event === 'data') {
mockListener(Buffer.from(body || '') as unknown as Error);
}
if (event === 'error') {
mockListener(new Error('test'));
}
if (event === 'end') {
mockListener();
}
if (event === 'close') {
setTimeout(listener, 100);
}
return mockReq;
}),
listeners: jest.fn<http.IncomingMessage['listeners']>(),
removeListener: jest.fn<http.IncomingMessage['removeListener']>(),
} as unknown as http.IncomingMessage;

return mockReq;
};

describe('SSEServerTransport', () => {
describe('start method', () => {
it('should correctly append sessionId to a simple relative endpoint', async () => {
Expand Down Expand Up @@ -106,4 +137,124 @@ describe('SSEServerTransport', () => {
);
});
});
});

describe('handlePostMessage method', () => {
it('should return 500 if server has not started', async () => {
const mockReq = createMockRequest();
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);

const error = 'SSE connection not established';
await expect(transport.handlePostMessage(mockReq, mockRes))
.rejects.toThrow(error);
expect(mockRes.writeHead).toHaveBeenCalledWith(500);
expect(mockRes.end).toHaveBeenCalledWith(error);
});

it('should return 400 if content-type is not application/json', async () => {
const mockReq = createMockRequest({ headers: { 'content-type': 'text/plain' } });
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);
await transport.start();

transport.onerror = jest.fn();
const error = 'Unsupported content-type: text/plain';
await expect(transport.handlePostMessage(mockReq, mockRes))
.resolves.toBe(undefined);
expect(mockRes.writeHead).toHaveBeenCalledWith(400);
expect(mockRes.end).toHaveBeenCalledWith(expect.stringContaining(error));
expect(transport.onerror).toHaveBeenCalledWith(new Error(error));
});

it('should return 400 if message has not a valid schema', async () => {
const invalidMessage = JSON.stringify({
// missing jsonrpc field
method: 'call',
params: [1, 2, 3],
id: 1,
})
const mockReq = createMockRequest({
headers: { 'content-type': 'application/json' },
body: invalidMessage,
});
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);
await transport.start();

transport.onmessage = jest.fn();
await transport.handlePostMessage(mockReq, mockRes);
expect(mockRes.writeHead).toHaveBeenCalledWith(400);
expect(transport.onmessage).not.toHaveBeenCalled();
expect(mockRes.end).toHaveBeenCalledWith(`Invalid message: ${invalidMessage}`);
});

it('should return 202 if message has a valid schema', async () => {
const validMessage = JSON.stringify({
jsonrpc: "2.0",
method: 'call',
params: {
a: 1,
b: 2,
c: 3,
},
id: 1
})
const mockReq = createMockRequest({
headers: { 'content-type': 'application/json' },
body: validMessage,
});
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);
await transport.start();

transport.onmessage = jest.fn();
await transport.handlePostMessage(mockReq, mockRes);
expect(mockRes.writeHead).toHaveBeenCalledWith(202);
expect(mockRes.end).toHaveBeenCalledWith('Accepted');
expect(transport.onmessage).toHaveBeenCalledWith({
jsonrpc: "2.0",
method: 'call',
params: {
a: 1,
b: 2,
c: 3,
},
id: 1
}, {
authInfo: {
token: 'test-token',
}
});
});
});

describe('close method', () => {
it('should call onclose', async () => {
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);
await transport.start();
transport.onclose = jest.fn();
await transport.close();
expect(transport.onclose).toHaveBeenCalled();
});
});

describe('send method', () => {
it('should call onsend', async () => {
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);
await transport.start();
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
expect.stringContaining('event: endpoint'));
expect(mockRes.write).toHaveBeenCalledWith(
expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`));
});
});
});
2 changes: 1 addition & 1 deletion src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ export class SSEServerTransport implements Transport {
try {
const ct = contentType.parse(req.headers["content-type"] ?? "");
if (ct.type !== "application/json") {
throw new Error(`Unsupported content-type: ${ct}`);
throw new Error(`Unsupported content-type: ${ct.type}`);
}

body = parsedBody ?? await getRawBody(req, {
Expand Down