diff --git a/examples/backend/src/index.ts b/examples/backend/src/index.ts index 44fdb01..5326757 100644 --- a/examples/backend/src/index.ts +++ b/examples/backend/src/index.ts @@ -28,6 +28,8 @@ async function main() { createSnapshot: createSnapshotDb, createUpdate: createUpdateDb, hasAccess: async () => true, + hasBroadcastAccess: async ({ websocketSessionKeys }) => + websocketSessionKeys.map(() => true), logging: "error", }) ); diff --git a/packages/secsync/src/server/createWebSocketConnection.test.ts b/packages/secsync/src/server/createWebSocketConnection.test.ts index 1a8e1e2..e2b8b02 100644 --- a/packages/secsync/src/server/createWebSocketConnection.test.ts +++ b/packages/secsync/src/server/createWebSocketConnection.test.ts @@ -29,12 +29,18 @@ it("should handle document error if URL is undefined", async () => { const mockCreateSnapshot = jest.fn(); const mockCreateUpdate = jest.fn(); const mockHasAccess = jest.fn().mockReturnValue(true); + const mockHasBroadcastAccess = jest + .fn() + .mockImplementation((websocketSessionKeys) => + websocketSessionKeys.map(() => true) + ); const connection = createWebSocketConnection({ getDocument: mockGetDocument, createSnapshot: mockCreateSnapshot, createUpdate: mockCreateUpdate, hasAccess: mockHasAccess, + hasBroadcastAccess: mockHasBroadcastAccess, }); await connection(mockWs, mockReq); @@ -45,20 +51,26 @@ it("should handle document error if URL is undefined", async () => { expect(mockWs.close).toHaveBeenCalledTimes(1); expect(removeConnection).toHaveBeenCalledWith({ documentId: "", - currentClientConnection: mockWs, + websocket: mockWs, }); }); it("should close connection if unauthorized for read access", async () => { - mockReq.url = "/test-document"; + mockReq.url = "/test-document?sessionKey=123"; const mockHasAccess = jest.fn().mockReturnValue(false); + const mockHasBroadcastAccess = jest + .fn() + .mockImplementation((websocketSessionKeys) => + websocketSessionKeys.map(() => true) + ); const connection = createWebSocketConnection({ getDocument: jest.fn(), createSnapshot: jest.fn(), createUpdate: jest.fn(), hasAccess: mockHasAccess, + hasBroadcastAccess: mockHasBroadcastAccess, }); await connection(mockWs, mockReq); @@ -66,6 +78,7 @@ it("should close connection if unauthorized for read access", async () => { expect(mockHasAccess).toHaveBeenCalledWith({ action: "read", documentId: "test-document", + websocketSessionKey: "123", }); expect(mockWs.send).toHaveBeenCalledWith( JSON.stringify({ type: "unauthorized" }) @@ -74,16 +87,22 @@ it("should close connection if unauthorized for read access", async () => { }); it("should close connection if document not found", async () => { - mockReq.url = "/test-document"; + mockReq.url = "/test-document?sessionKey=123"; const mockGetDocument = jest.fn().mockReturnValue(undefined); const mockHasAccess = jest.fn().mockReturnValue(true); + const mockHasBroadcastAccess = jest + .fn() + .mockImplementation((websocketSessionKeys) => + websocketSessionKeys.map(() => true) + ); const connection = createWebSocketConnection({ getDocument: mockGetDocument, createSnapshot: jest.fn(), createUpdate: jest.fn(), hasAccess: mockHasAccess, + hasBroadcastAccess: mockHasBroadcastAccess, }); await connection(mockWs, mockReq); @@ -95,7 +114,7 @@ it("should close connection if document not found", async () => { }); it("should add connection and send document if found", async () => { - mockReq.url = "/test-document"; + mockReq.url = "/test-document?sessionKey=123"; const mockDocument = { snapshot: {}, @@ -105,12 +124,18 @@ it("should add connection and send document if found", async () => { const mockGetDocument = jest.fn().mockReturnValue(mockDocument); const mockHasAccess = jest.fn().mockReturnValue(true); + const mockHasBroadcastAccess = jest + .fn() + .mockImplementation((websocketSessionKeys) => + websocketSessionKeys.map(() => true) + ); const connection = createWebSocketConnection({ getDocument: mockGetDocument, createSnapshot: jest.fn(), createUpdate: jest.fn(), hasAccess: mockHasAccess, + hasBroadcastAccess: mockHasBroadcastAccess, }); await connection(mockWs, mockReq); @@ -123,7 +148,8 @@ it("should add connection and send document if found", async () => { }); expect(addConnection).toHaveBeenCalledWith({ documentId: "test-document", - currentClientConnection: mockWs, + websocket: mockWs, + websocketSessionKey: "123", }); expect(mockWs.send).toHaveBeenCalledWith( JSON.stringify({ type: "document", ...mockDocument }) @@ -139,15 +165,22 @@ it("should properly parse and send through knownSnapshotId & knownSnapshotUpdate const mockGetDocument = jest.fn().mockReturnValue(mockDocument); const mockHasAccess = jest.fn().mockReturnValue(true); + const mockHasBroadcastAccess = jest + .fn() + .mockImplementation((websocketSessionKeys) => + websocketSessionKeys.map(() => true) + ); const connection = createWebSocketConnection({ getDocument: mockGetDocument, createSnapshot: jest.fn(), createUpdate: jest.fn(), hasAccess: mockHasAccess, + hasBroadcastAccess: mockHasBroadcastAccess, }); - mockReq.url = "/test-document?knownSnapshotId=123"; + mockReq.url = "/test-document?sessionKey=123&knownSnapshotId=123"; + await connection(mockWs, mockReq); expect(mockGetDocument).toHaveBeenCalledWith({ @@ -157,7 +190,7 @@ it("should properly parse and send through knownSnapshotId & knownSnapshotUpdate mode: "complete", }); - mockReq.url = "/test-document?knownSnapshotId=555"; + mockReq.url = "/test-document?sessionKey=123&knownSnapshotId=555"; await connection(mockWs, mockReq); expect(mockGetDocument).toHaveBeenCalledWith({ @@ -171,7 +204,7 @@ it("should properly parse and send through knownSnapshotId & knownSnapshotUpdate const knownSnapshotUpdateClocksQuery = encodeURIComponent( JSON.stringify(knownSnapshotUpdateClocks) ); - mockReq.url = `/test-document?knownSnapshotId=42&knownSnapshotUpdateClocks=${knownSnapshotUpdateClocksQuery}`; + mockReq.url = `/test-document?sessionKey=123&knownSnapshotId=42&knownSnapshotUpdateClocks=${knownSnapshotUpdateClocksQuery}`; await connection(mockWs, mockReq); expect(mockGetDocument).toHaveBeenCalledWith({ diff --git a/packages/secsync/src/server/createWebSocketConnection.ts b/packages/secsync/src/server/createWebSocketConnection.ts index bde60e1..111736b 100644 --- a/packages/secsync/src/server/createWebSocketConnection.ts +++ b/packages/secsync/src/server/createWebSocketConnection.ts @@ -16,6 +16,7 @@ import { CreateUpdateParams, GetDocumentParams, HasAccessParams, + HasBroadcastAccessParams, Snapshot, SnapshotProofInfo, SnapshotUpdateClocks, @@ -39,6 +40,9 @@ type WebsocketConnectionParams = { createSnapshot(createSnapshotParams: CreateSnapshotParams): Promise; createUpdate(createUpdateParams: CreateUpdateParams): Promise; hasAccess(hasAccessParams: HasAccessParams): Promise; + hasBroadcastAccess( + hasBroadcastAccessParams: HasBroadcastAccessParams + ): Promise; additionalAuthenticationDataValidations?: AdditionalAuthenticationDataValidations; /** default: "off" */ logging?: "off" | "error"; @@ -50,6 +54,7 @@ export const createWebSocketConnection = createSnapshot, createUpdate, hasAccess, + hasBroadcastAccess, additionalAuthenticationDataValidations, logging: loggingParam, }: WebsocketConnectionParams) => @@ -60,7 +65,10 @@ export const createWebSocketConnection = const handleDocumentError = () => { connection.send(JSON.stringify({ type: "document-error" })); connection.close(); - removeConnection({ documentId, currentClientConnection: connection }); + removeConnection({ + documentId, + websocket: connection, + }); }; try { @@ -75,6 +83,12 @@ export const createWebSocketConnection = ? urlParts.query.sessionKey[0] : urlParts.query.sessionKey; + // invalid connection without a sessionKey + if (websocketSessionKey === undefined) { + handleDocumentError(); + return; + } + const getDocumentModeString = Array.isArray(urlParts.query.mode) ? urlParts.query.mode[0] : urlParts.query.mode; @@ -127,7 +141,7 @@ export const createWebSocketConnection = return; } - addConnection({ documentId, currentClientConnection: connection }); + addConnection({ documentId, websocket: connection, websocketSessionKey }); connection.send(JSON.stringify({ type: "document", ...doc })); connection.on("message", async function message(messageContent) { @@ -195,7 +209,8 @@ export const createWebSocketConnection = type: "snapshot", snapshot: snapshotMsgForOtherClients, }, - currentClientConnection: connection, + currentWebsocket: connection, + hasBroadcastAccess, }); } catch (error) { if (logging === "error") { @@ -345,7 +360,8 @@ export const createWebSocketConnection = broadcastMessage({ documentId, message: { ...savedUpdate, type: "update" }, - currentClientConnection: connection, + currentWebsocket: connection, + hasBroadcastAccess, }); } catch (err) { if (logging === "error") { @@ -408,7 +424,8 @@ export const createWebSocketConnection = ...ephemeralMessageMessage, type: "ephemeral-message", }, - currentClientConnection: connection, + currentWebsocket: connection, + hasBroadcastAccess, }); } catch (err) { console.error("Ephemeral message failed due:", err); @@ -417,7 +434,7 @@ export const createWebSocketConnection = }); connection.on("close", function () { - removeConnection({ documentId, currentClientConnection: connection }); + removeConnection({ documentId, websocket: connection }); }); } catch (error) { if (logging === "error") { diff --git a/packages/secsync/src/server/store.ts b/packages/secsync/src/server/store.ts index 62a9179..6b72537 100644 --- a/packages/secsync/src/server/store.ts +++ b/packages/secsync/src/server/store.ts @@ -1,58 +1,94 @@ -type DocumentStoreEntry = { - connections: Set; -}; +import WebSocket from "ws"; +import { HasBroadcastAccessParams } from "../types"; + +type ConnectionEntry = { websocketSessionKey: string; websocket: WebSocket }; -const documents: { [key: string]: DocumentStoreEntry } = {}; +const documents: { [key: string]: ConnectionEntry[] } = {}; +const messageQueues: { [key: string]: BroadcastMessageParams[] } = {}; export type BroadcastMessageParams = { documentId: string; message: any; - currentClientConnection: any; + currentWebsocket: any; + hasBroadcastAccess: (params: HasBroadcastAccessParams) => Promise; }; -export const broadcastMessage = ({ - documentId, - message, - currentClientConnection, -}: BroadcastMessageParams) => { - documents[documentId]?.connections?.forEach((conn) => { - if (currentClientConnection !== conn) { - conn.send(JSON.stringify(message)); - } - // for debugging purposes - // conn.send(JSON.stringify(update)); +export const broadcastMessage = async (params: BroadcastMessageParams) => { + const { documentId } = params; + + if (!messageQueues[documentId]) { + messageQueues[documentId] = []; + } + + messageQueues[documentId].push(params); + + // only start processing if this is the only message in the queue to avoid overlapping calls + if (messageQueues[documentId].length === 1) { + processMessageQueue(documentId); + } +}; + +const processMessageQueue = async (documentId: string) => { + if (!documents[documentId] || messageQueues[documentId].length === 0) return; + + const { hasBroadcastAccess } = messageQueues[documentId][0]; + + const websocketSessionKeys = documents[documentId].map( + ({ websocketSessionKey }) => websocketSessionKey + ); + + const accessResults = await hasBroadcastAccess({ + documentId, + websocketSessionKeys, }); + + documents[documentId] = documents[documentId].filter( + (_, index) => accessResults[index] + ); + + // Send all the messages in the queue to the allowed connections + messageQueues[documentId].forEach(({ message, currentWebsocket }) => { + documents[documentId].forEach(({ websocket }) => { + if (websocket !== currentWebsocket) { + websocket.send(JSON.stringify(message)); + } + }); + }); + + // clear the message queue after it's broadcasted + messageQueues[documentId] = []; }; export type AddConnectionParams = { documentId: string; - currentClientConnection: any; + websocket: WebSocket; + websocketSessionKey: string; }; export const addConnection = ({ documentId, - currentClientConnection, + websocket, + websocketSessionKey, }: AddConnectionParams) => { if (documents[documentId]) { - documents[documentId].connections.add(currentClientConnection); + documents[documentId].push({ websocket, websocketSessionKey }); } else { - documents[documentId] = { - connections: new Set(), - }; - documents[documentId].connections.add(currentClientConnection); + documents[documentId] = [{ websocket, websocketSessionKey }]; } }; export type RemoveConnectionParams = { documentId: string; - currentClientConnection: any; + websocket: WebSocket; }; export const removeConnection = ({ documentId, - currentClientConnection, + websocket: currentWebsocket, }: RemoveConnectionParams) => { if (documents[documentId]) { - documents[documentId].connections.delete(currentClientConnection); + documents[documentId] = documents[documentId].filter( + ({ websocket }) => websocket !== currentWebsocket + ); } }; diff --git a/packages/secsync/src/types.ts b/packages/secsync/src/types.ts index d239890..0ad134e 100644 --- a/packages/secsync/src/types.ts +++ b/packages/secsync/src/types.ts @@ -186,6 +186,11 @@ export type HasAccessParams = websocketSessionKey: string | undefined; }; +export type HasBroadcastAccessParams = { + documentId: string; + websocketSessionKeys: string[]; +}; + export type ValidSessions = { [authorPublicKey: string]: { sessionId: string; sessionCounter: number }; };