Skip to content

Commit

Permalink
fix(core): Refactor push sessionId validation, and add unit tests (no…
Browse files Browse the repository at this point in the history
…-changelog) (#8815)
  • Loading branch information
netroy committed Mar 6, 2024
1 parent b5ffb7d commit c03f08d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 23 deletions.
37 changes: 18 additions & 19 deletions packages/cli/src/push/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import type { PushResponse, SSEPushRequest, WebSocketPushRequest } from './types
import type { IPushDataType } from '@/Interfaces';
import type { User } from '@db/entities/User';
import { OnShutdown } from '@/decorators/OnShutdown';
import { BadRequestError } from '@/errors/response-errors/bad-request.error';

const useWebSockets = config.getEnv('push.backend') === 'websocket';

Expand All @@ -39,14 +40,24 @@ export class Push extends EventEmitter {

handleRequest(req: SSEPushRequest | WebSocketPushRequest, res: PushResponse) {
const {
userId,
user,
ws,
query: { sessionId },
} = req;

if (!sessionId) {
if (ws) {
ws.send('The query parameter "sessionId" is missing!');
ws.close(1008);
return;
}
throw new BadRequestError('The query parameter "sessionId" is missing!');
}

if (req.ws) {
(this.backend as WebSocketPush).add(sessionId, userId, req.ws);
(this.backend as WebSocketPush).add(sessionId, user.id, req.ws);
} else if (!useWebSockets) {
(this.backend as SSEPush).add(sessionId, userId, { req, res });
(this.backend as SSEPush).add(sessionId, user.id, { req, res });
} else {
res.status(401).send('Unauthorized');
return;
Expand Down Expand Up @@ -103,29 +114,17 @@ export const setupPushServer = (restEndpoint: string, server: Server, app: Appli
export const setupPushHandler = (restEndpoint: string, app: Application) => {
const endpoint = `/${restEndpoint}/push`;

const pushValidationMiddleware: RequestHandler = async (
const authMiddleware: RequestHandler = async (
req: SSEPushRequest | WebSocketPushRequest,
res,
next,
) => {
const ws = req.ws;

const { sessionId } = req.query;
if (sessionId === undefined) {
if (ws) {
ws.send('The query parameter "sessionId" is missing!');
ws.close(1008);
} else {
next(new Error('The query parameter "sessionId" is missing!'));
}
return;
}
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access
const authCookie: string = req.cookies?.[AUTH_COOKIE_NAME] ?? '';
const user = await resolveJwt(authCookie);
req.userId = user.id;
req.user = await resolveJwt(authCookie);
} catch (error) {
const ws = req.ws;
if (ws) {
ws.send(`Unauthorized: ${(error as Error).message}`);
ws.close(1008);
Expand All @@ -141,7 +140,7 @@ export const setupPushHandler = (restEndpoint: string, app: Application) => {
const push = Container.get(Push);
app.use(
endpoint,
pushValidationMiddleware,
authMiddleware,
(req: SSEPushRequest | WebSocketPushRequest, res: PushResponse) => push.handleRequest(req, res),
);
};
9 changes: 5 additions & 4 deletions packages/cli/src/push/types.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import type { AuthenticatedRequest } from '@/requests';
import type { User } from '@db/entities/User';
import type { Request, Response } from 'express';
import type { Response } from 'express';
import type { WebSocket } from 'ws';

// TODO: move all push related types here

export type PushRequest = Request<{}, {}, {}, { sessionId: string }>;
export type PushRequest = AuthenticatedRequest<{}, {}, {}, { sessionId: string }>;

export type SSEPushRequest = PushRequest & { ws: undefined; userId: User['id'] };
export type WebSocketPushRequest = PushRequest & { ws: WebSocket; userId: User['id'] };
export type SSEPushRequest = PushRequest & { ws: undefined };
export type WebSocketPushRequest = PushRequest & { ws: WebSocket };

export type PushResponse = Response & { req: PushRequest };

Expand Down
42 changes: 42 additions & 0 deletions packages/cli/test/unit/push/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import type { WebSocket } from 'ws';
import config from '@/config';
import type { User } from '@db/entities/User';
import { Push } from '@/push';
import { SSEPush } from '@/push/sse.push';
import { WebSocketPush } from '@/push/websocket.push';
import type { WebSocketPushRequest, SSEPushRequest } from '@/push/types';
import { mockInstance } from '../../shared/mocking';
import { mock } from 'jest-mock-extended';
import { BadRequestError } from '@/errors/response-errors/bad-request.error';

jest.unmock('@/push');

describe('Push', () => {
const user = mock<User>();

const sseBackend = mockInstance(SSEPush);
const wsBackend = mockInstance(WebSocketPush);

test('should validate sessionId on requests for websocket backend', () => {
config.set('push.backend', 'websocket');
const push = new Push();
const ws = mock<WebSocket>();
const request = mock<WebSocketPushRequest>({ user, ws });
request.query = { sessionId: '' };
push.handleRequest(request, mock());

expect(ws.send).toHaveBeenCalled();
expect(ws.close).toHaveBeenCalledWith(1008);
expect(wsBackend.add).not.toHaveBeenCalled();
});

test('should validate sessionId on requests for SSE backend', () => {
config.set('push.backend', 'sse');
const push = new Push();
const request = mock<SSEPushRequest>({ user, ws: undefined });
request.query = { sessionId: '' };
expect(() => push.handleRequest(request, mock())).toThrow(BadRequestError);

expect(sseBackend.add).not.toHaveBeenCalled();
});
});

0 comments on commit c03f08d

Please sign in to comment.