Skip to content

Commit

Permalink
feat: introduce hasBroadcastAccess
Browse files Browse the repository at this point in the history
  • Loading branch information
nikgraf committed Sep 28, 2023
1 parent 195bbe0 commit 390306b
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 40 deletions.
2 changes: 2 additions & 0 deletions examples/backend/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ async function main() {
createSnapshot: createSnapshotDb,
createUpdate: createUpdateDb,
hasAccess: async () => true,
hasBroadcastAccess: async ({ websocketSessionKeys }) =>
websocketSessionKeys.map(() => true),
logging: "error",
})
);
Expand Down
49 changes: 41 additions & 8 deletions packages/secsync/src/server/createWebSocketConnection.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ it("should handle document error if URL is undefined", async () => {
const mockCreateSnapshot = jest.fn();
const mockCreateUpdate = jest.fn();
const mockHasAccess = jest.fn().mockReturnValue(true);
const mockHasBroadcastAccess = jest
.fn()
.mockImplementation((websocketSessionKeys) =>
websocketSessionKeys.map(() => true)
);

const connection = createWebSocketConnection({
getDocument: mockGetDocument,
createSnapshot: mockCreateSnapshot,
createUpdate: mockCreateUpdate,
hasAccess: mockHasAccess,
hasBroadcastAccess: mockHasBroadcastAccess,
});

await connection(mockWs, mockReq);
Expand All @@ -45,27 +51,34 @@ it("should handle document error if URL is undefined", async () => {
expect(mockWs.close).toHaveBeenCalledTimes(1);
expect(removeConnection).toHaveBeenCalledWith({
documentId: "",
currentClientConnection: mockWs,
websocket: mockWs,
});
});

it("should close connection if unauthorized for read access", async () => {
mockReq.url = "/test-document";
mockReq.url = "/test-document?sessionKey=123";

const mockHasAccess = jest.fn().mockReturnValue(false);
const mockHasBroadcastAccess = jest
.fn()
.mockImplementation((websocketSessionKeys) =>
websocketSessionKeys.map(() => true)
);

const connection = createWebSocketConnection({
getDocument: jest.fn(),
createSnapshot: jest.fn(),
createUpdate: jest.fn(),
hasAccess: mockHasAccess,
hasBroadcastAccess: mockHasBroadcastAccess,
});

await connection(mockWs, mockReq);

expect(mockHasAccess).toHaveBeenCalledWith({
action: "read",
documentId: "test-document",
websocketSessionKey: "123",
});
expect(mockWs.send).toHaveBeenCalledWith(
JSON.stringify({ type: "unauthorized" })
Expand All @@ -74,16 +87,22 @@ it("should close connection if unauthorized for read access", async () => {
});

it("should close connection if document not found", async () => {
mockReq.url = "/test-document";
mockReq.url = "/test-document?sessionKey=123";

const mockGetDocument = jest.fn().mockReturnValue(undefined);
const mockHasAccess = jest.fn().mockReturnValue(true);
const mockHasBroadcastAccess = jest
.fn()
.mockImplementation((websocketSessionKeys) =>
websocketSessionKeys.map(() => true)
);

const connection = createWebSocketConnection({
getDocument: mockGetDocument,
createSnapshot: jest.fn(),
createUpdate: jest.fn(),
hasAccess: mockHasAccess,
hasBroadcastAccess: mockHasBroadcastAccess,
});

await connection(mockWs, mockReq);
Expand All @@ -95,7 +114,7 @@ it("should close connection if document not found", async () => {
});

it("should add connection and send document if found", async () => {
mockReq.url = "/test-document";
mockReq.url = "/test-document?sessionKey=123";

const mockDocument = {
snapshot: {},
Expand All @@ -105,12 +124,18 @@ it("should add connection and send document if found", async () => {

const mockGetDocument = jest.fn().mockReturnValue(mockDocument);
const mockHasAccess = jest.fn().mockReturnValue(true);
const mockHasBroadcastAccess = jest
.fn()
.mockImplementation((websocketSessionKeys) =>
websocketSessionKeys.map(() => true)
);

const connection = createWebSocketConnection({
getDocument: mockGetDocument,
createSnapshot: jest.fn(),
createUpdate: jest.fn(),
hasAccess: mockHasAccess,
hasBroadcastAccess: mockHasBroadcastAccess,
});

await connection(mockWs, mockReq);
Expand All @@ -123,7 +148,8 @@ it("should add connection and send document if found", async () => {
});
expect(addConnection).toHaveBeenCalledWith({
documentId: "test-document",
currentClientConnection: mockWs,
websocket: mockWs,
websocketSessionKey: "123",
});
expect(mockWs.send).toHaveBeenCalledWith(
JSON.stringify({ type: "document", ...mockDocument })
Expand All @@ -139,15 +165,22 @@ it("should properly parse and send through knownSnapshotId & knownSnapshotUpdate

const mockGetDocument = jest.fn().mockReturnValue(mockDocument);
const mockHasAccess = jest.fn().mockReturnValue(true);
const mockHasBroadcastAccess = jest
.fn()
.mockImplementation((websocketSessionKeys) =>
websocketSessionKeys.map(() => true)
);

const connection = createWebSocketConnection({
getDocument: mockGetDocument,
createSnapshot: jest.fn(),
createUpdate: jest.fn(),
hasAccess: mockHasAccess,
hasBroadcastAccess: mockHasBroadcastAccess,
});

mockReq.url = "/test-document?knownSnapshotId=123";
mockReq.url = "/test-document?sessionKey=123&knownSnapshotId=123";

await connection(mockWs, mockReq);

expect(mockGetDocument).toHaveBeenCalledWith({
Expand All @@ -157,7 +190,7 @@ it("should properly parse and send through knownSnapshotId & knownSnapshotUpdate
mode: "complete",
});

mockReq.url = "/test-document?knownSnapshotId=555";
mockReq.url = "/test-document?sessionKey=123&knownSnapshotId=555";
await connection(mockWs, mockReq);

expect(mockGetDocument).toHaveBeenCalledWith({
Expand All @@ -171,7 +204,7 @@ it("should properly parse and send through knownSnapshotId & knownSnapshotUpdate
const knownSnapshotUpdateClocksQuery = encodeURIComponent(
JSON.stringify(knownSnapshotUpdateClocks)
);
mockReq.url = `/test-document?knownSnapshotId=42&knownSnapshotUpdateClocks=${knownSnapshotUpdateClocksQuery}`;
mockReq.url = `/test-document?sessionKey=123&knownSnapshotId=42&knownSnapshotUpdateClocks=${knownSnapshotUpdateClocksQuery}`;
await connection(mockWs, mockReq);

expect(mockGetDocument).toHaveBeenCalledWith({
Expand Down
29 changes: 23 additions & 6 deletions packages/secsync/src/server/createWebSocketConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
CreateUpdateParams,
GetDocumentParams,
HasAccessParams,
HasBroadcastAccessParams,
Snapshot,
SnapshotProofInfo,
SnapshotUpdateClocks,
Expand All @@ -39,6 +40,9 @@ type WebsocketConnectionParams = {
createSnapshot(createSnapshotParams: CreateSnapshotParams): Promise<Snapshot>;
createUpdate(createUpdateParams: CreateUpdateParams): Promise<Update>;
hasAccess(hasAccessParams: HasAccessParams): Promise<boolean>;
hasBroadcastAccess(
hasBroadcastAccessParams: HasBroadcastAccessParams
): Promise<boolean[]>;
additionalAuthenticationDataValidations?: AdditionalAuthenticationDataValidations;
/** default: "off" */
logging?: "off" | "error";
Expand All @@ -50,6 +54,7 @@ export const createWebSocketConnection =
createSnapshot,
createUpdate,
hasAccess,
hasBroadcastAccess,
additionalAuthenticationDataValidations,
logging: loggingParam,
}: WebsocketConnectionParams) =>
Expand All @@ -60,7 +65,10 @@ export const createWebSocketConnection =
const handleDocumentError = () => {
connection.send(JSON.stringify({ type: "document-error" }));
connection.close();
removeConnection({ documentId, currentClientConnection: connection });
removeConnection({
documentId,
websocket: connection,
});
};

try {
Expand All @@ -75,6 +83,12 @@ export const createWebSocketConnection =
? urlParts.query.sessionKey[0]
: urlParts.query.sessionKey;

// invalid connection without a sessionKey
if (websocketSessionKey === undefined) {
handleDocumentError();
return;
}

const getDocumentModeString = Array.isArray(urlParts.query.mode)
? urlParts.query.mode[0]
: urlParts.query.mode;
Expand Down Expand Up @@ -127,7 +141,7 @@ export const createWebSocketConnection =
return;
}

addConnection({ documentId, currentClientConnection: connection });
addConnection({ documentId, websocket: connection, websocketSessionKey });
connection.send(JSON.stringify({ type: "document", ...doc }));

connection.on("message", async function message(messageContent) {
Expand Down Expand Up @@ -195,7 +209,8 @@ export const createWebSocketConnection =
type: "snapshot",
snapshot: snapshotMsgForOtherClients,
},
currentClientConnection: connection,
currentWebsocket: connection,
hasBroadcastAccess,
});
} catch (error) {
if (logging === "error") {
Expand Down Expand Up @@ -345,7 +360,8 @@ export const createWebSocketConnection =
broadcastMessage({
documentId,
message: { ...savedUpdate, type: "update" },
currentClientConnection: connection,
currentWebsocket: connection,
hasBroadcastAccess,
});
} catch (err) {
if (logging === "error") {
Expand Down Expand Up @@ -408,7 +424,8 @@ export const createWebSocketConnection =
...ephemeralMessageMessage,
type: "ephemeral-message",
},
currentClientConnection: connection,
currentWebsocket: connection,
hasBroadcastAccess,
});
} catch (err) {
console.error("Ephemeral message failed due:", err);
Expand All @@ -417,7 +434,7 @@ export const createWebSocketConnection =
});

connection.on("close", function () {
removeConnection({ documentId, currentClientConnection: connection });
removeConnection({ documentId, websocket: connection });
});
} catch (error) {
if (logging === "error") {
Expand Down
88 changes: 62 additions & 26 deletions packages/secsync/src/server/store.ts
Original file line number Diff line number Diff line change
@@ -1,58 +1,94 @@
type DocumentStoreEntry = {
connections: Set<any>;
};
import WebSocket from "ws";
import { HasBroadcastAccessParams } from "../types";

type ConnectionEntry = { websocketSessionKey: string; websocket: WebSocket };

const documents: { [key: string]: DocumentStoreEntry } = {};
const documents: { [key: string]: ConnectionEntry[] } = {};
const messageQueues: { [key: string]: BroadcastMessageParams[] } = {};

export type BroadcastMessageParams = {
documentId: string;
message: any;
currentClientConnection: any;
currentWebsocket: any;
hasBroadcastAccess: (params: HasBroadcastAccessParams) => Promise<boolean[]>;
};

export const broadcastMessage = ({
documentId,
message,
currentClientConnection,
}: BroadcastMessageParams) => {
documents[documentId]?.connections?.forEach((conn) => {
if (currentClientConnection !== conn) {
conn.send(JSON.stringify(message));
}
// for debugging purposes
// conn.send(JSON.stringify(update));
export const broadcastMessage = async (params: BroadcastMessageParams) => {
const { documentId } = params;

if (!messageQueues[documentId]) {
messageQueues[documentId] = [];
}

messageQueues[documentId].push(params);

// only start processing if this is the only message in the queue to avoid overlapping calls
if (messageQueues[documentId].length === 1) {
processMessageQueue(documentId);
}
};

const processMessageQueue = async (documentId: string) => {
if (!documents[documentId] || messageQueues[documentId].length === 0) return;

const { hasBroadcastAccess } = messageQueues[documentId][0];

const websocketSessionKeys = documents[documentId].map(
({ websocketSessionKey }) => websocketSessionKey
);

const accessResults = await hasBroadcastAccess({
documentId,
websocketSessionKeys,
});

documents[documentId] = documents[documentId].filter(
(_, index) => accessResults[index]
);

// Send all the messages in the queue to the allowed connections
messageQueues[documentId].forEach(({ message, currentWebsocket }) => {
documents[documentId].forEach(({ websocket }) => {
if (websocket !== currentWebsocket) {
websocket.send(JSON.stringify(message));
}
});
});

// clear the message queue after it's broadcasted
messageQueues[documentId] = [];
};

export type AddConnectionParams = {
documentId: string;
currentClientConnection: any;
websocket: WebSocket;
websocketSessionKey: string;
};

export const addConnection = ({
documentId,
currentClientConnection,
websocket,
websocketSessionKey,
}: AddConnectionParams) => {
if (documents[documentId]) {
documents[documentId].connections.add(currentClientConnection);
documents[documentId].push({ websocket, websocketSessionKey });
} else {
documents[documentId] = {
connections: new Set<any>(),
};
documents[documentId].connections.add(currentClientConnection);
documents[documentId] = [{ websocket, websocketSessionKey }];
}
};

export type RemoveConnectionParams = {
documentId: string;
currentClientConnection: any;
websocket: WebSocket;
};

export const removeConnection = ({
documentId,
currentClientConnection,
websocket: currentWebsocket,
}: RemoveConnectionParams) => {
if (documents[documentId]) {
documents[documentId].connections.delete(currentClientConnection);
documents[documentId] = documents[documentId].filter(
({ websocket }) => websocket !== currentWebsocket
);
}
};
Loading

0 comments on commit 390306b

Please sign in to comment.