diff --git a/apps/nextjs/src/app/api/chat/errorHandling.test.ts b/apps/nextjs/src/app/api/chat/errorHandling.test.ts index d4fccacf8..573c8550c 100644 --- a/apps/nextjs/src/app/api/chat/errorHandling.test.ts +++ b/apps/nextjs/src/app/api/chat/errorHandling.test.ts @@ -4,8 +4,12 @@ import { UserBannedError } from "@oakai/core/src/models/safetyViolations"; import { TracingSpan } from "@oakai/core/src/tracing/serverTracing"; import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter"; import { PrismaClientWithAccelerate } from "@oakai/db"; +import invariant from "tiny-invariant"; -import { consumeStream } from "@/utils/testHelpers/consumeStream"; +import { + consumeStream, + extractStreamMessage, +} from "@/utils/testHelpers/consumeStream"; import { handleChatException } from "./errorHandling"; @@ -33,9 +37,9 @@ describe("handleChatException", () => { expect(response.status).toBe(200); - const message = JSON.parse( - await consumeStream(response.body as ReadableStream), - ); + invariant(response.body instanceof ReadableStream); + const message = extractStreamMessage(await consumeStream(response.body)); + expect(message).toEqual({ type: "error", value: "Threat detected", @@ -83,9 +87,9 @@ describe("handleChatException", () => { expect(response.status).toBe(200); - const message = JSON.parse( - await consumeStream(response.body as ReadableStream), - ); + const consumed = await consumeStream(response.body as ReadableStream); + const message = extractStreamMessage(consumed); + expect(message).toEqual({ type: "error", value: "Rate limit exceeded", @@ -110,7 +114,7 @@ describe("handleChatException", () => { expect(response.status).toBe(200); - const message = JSON.parse( + const message = extractStreamMessage( await consumeStream(response.body as ReadableStream), ); expect(message).toEqual({ @@ -119,4 +123,4 @@ describe("handleChatException", () => { }); }); }); -}); +}); \ No newline at end of file diff --git a/apps/nextjs/src/app/api/chat/protocol.ts b/apps/nextjs/src/app/api/chat/protocol.ts index 6fd50d38d..322d43d6a 100644 --- a/apps/nextjs/src/app/api/chat/protocol.ts +++ b/apps/nextjs/src/app/api/chat/protocol.ts @@ -5,7 +5,9 @@ import { import { StreamingTextResponse } from "ai"; export function streamingJSON(message: ErrorDocument | ActionDocument) { - const errorMessage = JSON.stringify(message); + const jsonContent = JSON.stringify(message); + const errorMessage = `0:"${jsonContent.replace(/"/g, '\\"')}"`; + const errorEncoder = new TextEncoder(); return new StreamingTextResponse( diff --git a/apps/nextjs/src/utils/testHelpers/consumeStream.ts b/apps/nextjs/src/utils/testHelpers/consumeStream.ts index 63e59122e..5e9ca8a62 100644 --- a/apps/nextjs/src/utils/testHelpers/consumeStream.ts +++ b/apps/nextjs/src/utils/testHelpers/consumeStream.ts @@ -12,3 +12,12 @@ export async function consumeStream(stream: ReadableStream): Promise { return result; } + +export function extractStreamMessage(streamedText: string) { + const content = streamedText.match(/0:"(.*)"/); + if (!content?.[1]) { + throw new Error("No message found in streamed text"); + } + const strippedContent = content[1].replace(/\\"/g, '"'); + return JSON.parse(strippedContent); +}