From 90a496463f5606ac2b2524bd1db498e3bf1a300b Mon Sep 17 00:00:00 2001 From: Adam Howard <91115+codeincontext@users.noreply.github.com> Date: Tue, 10 Sep 2024 10:21:36 +0200 Subject: [PATCH] Throw all expected errors up to the route catch handler --- apps/nextjs/src/app/api/chat/chatHandler.ts | 86 +----------- apps/nextjs/src/app/api/chat/config.ts | 13 -- .../src/app/api/chat/errorHandling.test.ts | 122 ++++++++++++++++++ apps/nextjs/src/app/api/chat/errorHandling.ts | 115 +++++++++++++++++ apps/nextjs/src/app/api/chat/route.test.ts | 11 +- apps/nextjs/src/app/api/chat/user.test.ts | 31 +---- apps/nextjs/src/app/api/chat/user.ts | 76 ++++------- .../src/app/api/chat/webActionsPlugin.test.ts | 2 +- .../moderation/moderationErrorHandling.ts | 2 +- 9 files changed, 280 insertions(+), 178 deletions(-) create mode 100644 apps/nextjs/src/app/api/chat/errorHandling.test.ts create mode 100644 apps/nextjs/src/app/api/chat/errorHandling.ts diff --git a/apps/nextjs/src/app/api/chat/chatHandler.ts b/apps/nextjs/src/app/api/chat/chatHandler.ts index 1ac0face2..247fd35db 100644 --- a/apps/nextjs/src/app/api/chat/chatHandler.ts +++ b/apps/nextjs/src/app/api/chat/chatHandler.ts @@ -1,11 +1,6 @@ -import { - Aila, - AilaAuthenticationError, - AilaThreatDetectionError, -} from "@oakai/aila"; +import { Aila } from "@oakai/aila"; import type { AilaOptions, AilaPublicChatOptions, Message } from "@oakai/aila"; import { LooseLessonPlan } from "@oakai/aila/src/protocol/schema"; -import { handleHeliconeError } from "@oakai/aila/src/utils/moderation/moderationErrorHandling"; import { TracingSpan, withTelemetry, @@ -16,7 +11,8 @@ import { NextRequest } from "next/server"; import invariant from "tiny-invariant"; import { Config } from "./config"; -import { streamingJSON } from "./protocol"; +import { handleChatException } from "./errorHandling"; +import { fetchAndCheckUser } from "./user"; export const maxDuration = 300; @@ -61,25 +57,6 @@ async function setupChatHandler(req: NextRequest) { ); } -function reportErrorTelemetry( - span: TracingSpan, - error: Error, - errorType: string, - statusMessage: string, - additionalAttributes: Record< - string, - string | number | boolean | undefined - > = {}, -) { - span.setTag("error", true); - span.setTag("error.type", errorType); - span.setTag("error.message", statusMessage); - span.setTag("error.stack", error.stack); - Object.entries(additionalAttributes).forEach(([key, value]) => { - span.setTag(key, value); - }); -} - function setTelemetryMetadata( span: TracingSpan, id: string, @@ -110,35 +87,6 @@ function handleConnectionAborted(req: NextRequest) { return abortController; } -async function handleThreatDetectionError( - span: TracingSpan, - e: AilaThreatDetectionError, - userId: string, - id: string, - prisma: PrismaClientWithAccelerate, -) { - const heliconeErrorMessage = await handleHeliconeError(userId, id, e, prisma); - reportErrorTelemetry(span, e, "AilaThreatDetectionError", "Threat detected"); - return streamingJSON(heliconeErrorMessage); -} - -async function handleAilaAuthenticationError( - span: TracingSpan, - e: AilaAuthenticationError, -) { - reportErrorTelemetry(span, e, "AilaAuthenticationError", "Unauthorized"); - return new Response("Unauthorized", { status: 401 }); -} - -async function handleGenericError(span: TracingSpan, e: Error) { - reportErrorTelemetry(span, e, e.name, e.message); - return streamingJSON({ - type: "error", - message: e.message, - value: `Sorry, an error occurred: ${e.message}`, - }); -} - async function generateChatStream( aila: Aila, abortController: AbortController, @@ -154,28 +102,6 @@ async function generateChatStream( ); } -async function handleChatException( - span: TracingSpan, - e: unknown, - userId: string | undefined, - chatId: string, - prisma: PrismaClientWithAccelerate, -): Promise { - if (e instanceof AilaAuthenticationError) { - return handleAilaAuthenticationError(span, e); - } - - if (e instanceof AilaThreatDetectionError && userId) { - return handleThreatDetectionError(span, e, userId, chatId, prisma); - } - - if (e instanceof Error) { - return handleGenericError(span, e); - } - - throw e; -} - export async function handleChatPostRequest( req: NextRequest, config: Config, @@ -190,11 +116,7 @@ export async function handleChatPostRequest( let aila: Aila | undefined; try { - const userLookup = await config.handleUserLookup(chatId); - // The user lookup can either return a userId or a response like a streaming protocol message - if ("failureResponse" in userLookup) { - return userLookup.failureResponse; - } + userId = await fetchAndCheckUser(chatId); span.setTag("user_id", userId); aila = await withTelemetry( diff --git a/apps/nextjs/src/app/api/chat/config.ts b/apps/nextjs/src/app/api/chat/config.ts index 74847945c..5c9b91f3f 100644 --- a/apps/nextjs/src/app/api/chat/config.ts +++ b/apps/nextjs/src/app/api/chat/config.ts @@ -5,27 +5,14 @@ import { } from "@oakai/db"; import { nanoid } from "ai"; -import { handleUserLookup as defaultHandleUserLookup } from "./user"; import { createWebActionsPlugin } from "./webActionsPlugin"; export interface Config { - shouldPerformUserLookup: boolean; - mockUserId?: string; - handleUserLookup: (chatId: string) => Promise< - | { - userId: string; - } - | { - failureResponse: Response; - } - >; prisma: PrismaClientWithAccelerate; createAila: (options: Partial) => Promise; } export const defaultConfig: Config = { - shouldPerformUserLookup: true, - handleUserLookup: defaultHandleUserLookup, prisma: globalPrisma, createAila: async (options) => { const webActionsPlugin = createWebActionsPlugin(globalPrisma); diff --git a/apps/nextjs/src/app/api/chat/errorHandling.test.ts b/apps/nextjs/src/app/api/chat/errorHandling.test.ts new file mode 100644 index 000000000..7dddba934 --- /dev/null +++ b/apps/nextjs/src/app/api/chat/errorHandling.test.ts @@ -0,0 +1,122 @@ +import { AilaAuthenticationError, AilaThreatDetectionError } from "@oakai/aila"; +import * as moderationErrorHandling from "@oakai/aila/src/utils/moderation/moderationErrorHandling"; +import { UserBannedError } from "@oakai/core/src/models/safetyViolations"; +import { TracingSpan } from "@oakai/core/src/tracing/serverTracing"; +import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter"; +import { PrismaClientWithAccelerate } from "@oakai/db"; + +import { consumeStream } from "@/utils/testHelpers/consumeStream"; + +import { handleChatException } from "./errorHandling"; + +describe("handleChatException", () => { + describe("AilaThreatDetectionError", () => { + it("should forward the message from handleHeliconeError", async () => { + jest + .spyOn(moderationErrorHandling, "handleHeliconeError") + .mockResolvedValue({ + type: "error", + value: "Threat detected", + message: "Threat was detected", + }); + + const span = { setTag: jest.fn() } as unknown as TracingSpan; + const error = new AilaThreatDetectionError("test error"); + const prisma = {} as unknown as PrismaClientWithAccelerate; + + const response = await handleChatException( + span, + error, + "test-user-id", + "test-chat-id", + prisma, + ); + + expect(response.status).toBe(200); + + const message = JSON.parse( + await consumeStream(response.body as ReadableStream), + ); + expect(message).toEqual({ + type: "error", + value: "Threat detected", + message: "Threat was detected", + }); + }); + }); + + describe("AilaAuthenticationError", () => { + it("should return an error chat message", async () => { + const span = { setTag: jest.fn() } as unknown as TracingSpan; + const error = new AilaAuthenticationError("test error"); + const prisma = {} as unknown as PrismaClientWithAccelerate; + + const response = await handleChatException( + span, + error, + "test-user-id", + "test-chat-id", + prisma, + ); + + expect(response.status).toBe(401); + + const message = await consumeStream(response.body as ReadableStream); + expect(message).toEqual("Unauthorized"); + }); + }); + + describe("RateLimitExceededError", () => { + it("should return an error chat message", async () => { + const span = { setTag: jest.fn() } as unknown as TracingSpan; + const error = new RateLimitExceededError(100, Date.now() + 3600 * 1000); + const prisma = {} as unknown as PrismaClientWithAccelerate; + + const response = await handleChatException( + span, + error, + "test-user-id", + "test-chat-id", + prisma, + ); + + expect(response.status).toBe(200); + + const message = JSON.parse( + await consumeStream(response.body as ReadableStream), + ); + expect(message).toEqual({ + type: "error", + value: "Rate limit exceeded", + message: + "**Unfortunately you’ve exceeded your fair usage limit for today.** Please come back in 1 hour. If you require a higher limit, please [make a request](https://forms.gle/tHsYMZJR367zydsG8).", + }); + }); + }); + + describe("UserBannedError", () => { + it("should return an error chat message", async () => { + const span = { setTag: jest.fn() } as unknown as TracingSpan; + const error = new UserBannedError("test error"); + const prisma = {} as unknown as PrismaClientWithAccelerate; + + const response = await handleChatException( + span, + error, + "test-user-id", + "test-chat-id", + prisma, + ); + + expect(response.status).toBe(200); + + const message = JSON.parse( + await consumeStream(response.body as ReadableStream), + ); + expect(message).toEqual({ + type: "action", + action: "SHOW_ACCOUNT_LOCKED", + }); + }); + }); +}); diff --git a/apps/nextjs/src/app/api/chat/errorHandling.ts b/apps/nextjs/src/app/api/chat/errorHandling.ts new file mode 100644 index 000000000..898f76942 --- /dev/null +++ b/apps/nextjs/src/app/api/chat/errorHandling.ts @@ -0,0 +1,115 @@ +import { AilaAuthenticationError, AilaThreatDetectionError } from "@oakai/aila"; +import { + ActionDocument, + ErrorDocument, +} from "@oakai/aila/src/protocol/jsonPatchProtocol"; +import { handleHeliconeError } from "@oakai/aila/src/utils/moderation/moderationErrorHandling"; +import { UserBannedError } from "@oakai/core/src/models/safetyViolations"; +import { TracingSpan } from "@oakai/core/src/tracing/serverTracing"; +import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter"; +import { PrismaClientWithAccelerate } from "@oakai/db"; + +import { streamingJSON } from "./protocol"; + +function reportErrorTelemetry( + span: TracingSpan, + error: Error, + errorType: string, + statusMessage: string, + additionalAttributes: Record< + string, + string | number | boolean | undefined + > = {}, +) { + span.setTag("error", true); + span.setTag("error.type", errorType); + span.setTag("error.message", statusMessage); + span.setTag("error.stack", error.stack); + Object.entries(additionalAttributes).forEach(([key, value]) => { + span.setTag(key, value); + }); +} + +async function handleThreatDetectionError( + span: TracingSpan, + e: AilaThreatDetectionError, + userId: string, + id: string, + prisma: PrismaClientWithAccelerate, +) { + const heliconeErrorMessage = await handleHeliconeError(userId, id, e, prisma); + reportErrorTelemetry(span, e, "AilaThreatDetectionError", "Threat detected"); + return streamingJSON(heliconeErrorMessage); +} + +async function handleAilaAuthenticationError( + span: TracingSpan, + e: AilaAuthenticationError, +) { + reportErrorTelemetry(span, e, "AilaAuthenticationError", "Unauthorized"); + return new Response("Unauthorized", { status: 401 }); +} + +export async function handleRateLimitError( + span: TracingSpan, + error: RateLimitExceededError, +) { + reportErrorTelemetry(span, error, "RateLimitExceededError", "Rate limited"); + + const timeRemainingHours = Math.ceil( + (error.reset - Date.now()) / 1000 / 60 / 60, + ); + const hours = timeRemainingHours === 1 ? "hour" : "hours"; + + return streamingJSON({ + type: "error", + value: error.message, + message: `**Unfortunately you’ve exceeded your fair usage limit for today.** Please come back in ${timeRemainingHours} ${hours}. If you require a higher limit, please [make a request](${process.env.RATELIMIT_FORM_URL}).`, + } as ErrorDocument); +} + +async function handleUserBannedError() { + return streamingJSON({ + type: "action", + action: "SHOW_ACCOUNT_LOCKED", + } as ActionDocument); +} + +async function handleGenericError(span: TracingSpan, e: Error) { + reportErrorTelemetry(span, e, e.name, e.message); + return streamingJSON({ + type: "error", + message: e.message, + value: `Sorry, an error occurred: ${e.message}`, + } as ErrorDocument); +} + +export async function handleChatException( + span: TracingSpan, + e: unknown, + userId: string | undefined, + chatId: string, + prisma: PrismaClientWithAccelerate, +): Promise { + if (e instanceof AilaAuthenticationError) { + return handleAilaAuthenticationError(span, e); + } + + if (e instanceof AilaThreatDetectionError && userId) { + return handleThreatDetectionError(span, e, userId, chatId, prisma); + } + + if (e instanceof RateLimitExceededError && userId) { + return handleRateLimitError(span, e); + } + + if (e instanceof UserBannedError) { + return handleUserBannedError(); + } + + if (e instanceof Error) { + return handleGenericError(span, e); + } + + throw e; +} diff --git a/apps/nextjs/src/app/api/chat/route.test.ts b/apps/nextjs/src/app/api/chat/route.test.ts index 53ee3fec1..18724f1d2 100644 --- a/apps/nextjs/src/app/api/chat/route.test.ts +++ b/apps/nextjs/src/app/api/chat/route.test.ts @@ -12,6 +12,10 @@ import { Config } from "./config"; const chatId = "test-chat-id"; const userId = "test-user-id"; +jest.mock("./user", () => ({ + fetchAndCheckUser: jest.fn().mockResolvedValue("test-user-id"), +})); + describe("Chat API Route", () => { let testConfig: Config; let mockLLMService: MockLLMService; @@ -32,9 +36,6 @@ describe("Chat API Route", () => { jest.spyOn(mockLLMService, "createChatCompletionStream"); testConfig = { - shouldPerformUserLookup: false, - handleUserLookup: jest.fn(), - mockUserId: userId, createAila: jest.fn().mockImplementation(async (options) => { const ailaConfig = { options: { @@ -92,7 +93,5 @@ describe("Chat API Route", () => { expectTracingSpan("chat-api").toHaveBeenExecutedWith({ chat_id: "test-chat-id", }); - - expect(testConfig.handleUserLookup).not.toHaveBeenCalled(); - }, 30000); + }); }); diff --git a/apps/nextjs/src/app/api/chat/user.test.ts b/apps/nextjs/src/app/api/chat/user.test.ts index b3a6ff30d..7b67d9b9e 100644 --- a/apps/nextjs/src/app/api/chat/user.test.ts +++ b/apps/nextjs/src/app/api/chat/user.test.ts @@ -2,7 +2,7 @@ import { inngest } from "@oakai/core"; import { posthogAiBetaServerClient } from "@oakai/core/src/analytics/posthogAiBetaServerClient"; import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter"; -import { handleRateLimitError } from "./user"; +import { reportRateLimitError } from "./user"; jest.mock("@oakai/core/src/client", () => ({ inngest: { @@ -12,7 +12,7 @@ jest.mock("@oakai/core/src/client", () => ({ })); describe("chat route user functions", () => { - describe("handleRateLimitError", () => { + describe("reportRateLimitError", () => { it("should report rate limit exceeded to PostHog when userId is provided", async () => { jest.spyOn(posthogAiBetaServerClient, "identify"); jest.spyOn(posthogAiBetaServerClient, "capture"); @@ -21,7 +21,7 @@ describe("chat route user functions", () => { const chatId = "testChatId"; const userId = "testUserId"; - await handleRateLimitError(error, userId, chatId); + await reportRateLimitError(error, userId, chatId); expect(posthogAiBetaServerClient.identify).toHaveBeenCalledWith({ distinctId: userId, @@ -52,7 +52,7 @@ describe("chat route user functions", () => { const chatId = "testChatId"; const userId = "testUserId"; - await handleRateLimitError(error, userId, chatId); + await reportRateLimitError(error, userId, chatId); expect(inngest.send).toHaveBeenCalledTimes(1); expect(inngest.send).toHaveBeenCalledWith({ @@ -66,28 +66,5 @@ describe("chat route user functions", () => { }, }); }); - - it("should return an error chat message", async () => { - const mockPosthogClient = { - identify: jest.fn(), - capture: jest.fn(), - shutdown: jest.fn().mockResolvedValue(undefined), - }; - jest.mock("@oakai/core/src/analytics/posthogAiBetaServerClient", () => ({ - posthogAiBetaServerClient: mockPosthogClient, - })); - const error = new RateLimitExceededError(100, Date.now() + 3600 * 1000); - const chatId = "testChatId"; - const userId = "testUserId"; - - const response = await handleRateLimitError(error, userId, chatId); - - expect(response).toEqual({ - type: "error", - value: "Rate limit exceeded", - message: - "**Unfortunately you've exceeded your fair usage limit for today.** Please come back in 1 hour. If you require a higher limit, please [make a request](https://forms.gle/tHsYMZJR367zydsG8).", - }); - }); }); }); diff --git a/apps/nextjs/src/app/api/chat/user.ts b/apps/nextjs/src/app/api/chat/user.ts index 2894fe53c..b39252298 100644 --- a/apps/nextjs/src/app/api/chat/user.ts +++ b/apps/nextjs/src/app/api/chat/user.ts @@ -1,13 +1,13 @@ import { auth, clerkClient } from "@clerk/nextjs/server"; +import { AilaAuthenticationError } from "@oakai/aila"; import { ErrorDocument } from "@oakai/aila/src/protocol/jsonPatchProtocol"; import { demoUsers, inngest } from "@oakai/core"; import { posthogAiBetaServerClient } from "@oakai/core/src/analytics/posthogAiBetaServerClient"; +import { UserBannedError } from "@oakai/core/src/models/safetyViolations"; import { withTelemetry } from "@oakai/core/src/tracing/serverTracing"; import { rateLimits } from "@oakai/core/src/utils/rateLimiting/rateLimit"; import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter"; -import { streamingJSON } from "./protocol"; - async function checkRateLimit( userId: string, isDemoUser: boolean, @@ -23,7 +23,17 @@ async function checkRateLimit( return null; } catch (e) { if (e instanceof RateLimitExceededError) { - return await handleRateLimitError(e, userId, chatId); + await reportRateLimitError(e, userId, chatId); + + const timeRemainingHours = Math.ceil( + (e.reset - Date.now()) / 1000 / 60 / 60, + ); + span.setTag("error", true); + span.setTag("error.type", "RateLimitExceeded"); + span.setTag("error.message", e.message); + span.setTag("rate_limit.reset_hours", timeRemainingHours); + + throw e; } span.setTag("error", true); if (e instanceof Error) { @@ -34,15 +44,15 @@ async function checkRateLimit( }); } -export async function handleRateLimitError( +export async function reportRateLimitError( error: RateLimitExceededError, userId: string, chatId: string, -): Promise { +): Promise { return withTelemetry( "handle-rate-limit-error", { chatId, userId }, - async (span) => { + async () => { posthogAiBetaServerClient.identify({ distinctId: userId, }); @@ -67,69 +77,39 @@ export async function handleRateLimitError( reset: new Date(error.reset), }, }); - - // Build user-friendly error message - const timeRemainingHours = Math.ceil( - (error.reset - Date.now()) / 1000 / 60 / 60, - ); - const hours = timeRemainingHours === 1 ? "hour" : "hours"; - const higherLimitMessage = process.env.RATELIMIT_FORM_URL - ? ` If you require a higher limit, please [make a request](${process.env.RATELIMIT_FORM_URL}).` - : ""; - - span.setTag("error", true); - span.setTag("error.type", "RateLimitExceeded"); - span.setTag("error.message", error.message); - span.setTag("rate_limit.reset_hours", timeRemainingHours); - - return { - type: "error", - value: error.message, - message: `**Unfortunately you've exceeded your fair usage limit for today.** Please come back in ${timeRemainingHours} ${hours}.${higherLimitMessage}`, - }; }, ); } -export async function fetchAndCheckUser( - chatId: string, -): Promise<{ userId: string } | { failureResponse: Response }> { +export async function fetchAndCheckUser(chatId: string): Promise { return withTelemetry("fetch-and-check-user", { chatId }, async (span) => { const userId = auth().userId; if (!userId) { span.setTag("error", true); span.setTag("error.message", "Unauthorized"); - return { - failureResponse: new Response("Unauthorized", { - status: 401, - }), - }; + throw new AilaAuthenticationError("No user id"); } const clerkUser = await clerkClient.users.getUser(userId); if (clerkUser.banned) { span.setTag("error", true); span.setTag("error.message", "Account locked"); - return { - failureResponse: streamingJSON({ - type: "action", - action: "SHOW_ACCOUNT_LOCKED", - }), - }; + throw new UserBannedError(userId); } const isDemoUser = demoUsers.isDemoUser(clerkUser); - const rateLimitedMessage = await checkRateLimit(userId, isDemoUser, chatId); - if (rateLimitedMessage) { - span.setTag("error", true); - span.setTag("error.message", "Rate limited"); - return { - failureResponse: streamingJSON(rateLimitedMessage), - }; + try { + await checkRateLimit(userId, isDemoUser, chatId); + } catch (e) { + if (e instanceof RateLimitExceededError) { + span.setTag("error", true); + span.setTag("error.message", "Rate limited"); + } + throw e; } span.setTag("user.id", userId); span.setTag("user.demo", isDemoUser); - return { userId }; + return userId; }); } diff --git a/apps/nextjs/src/app/api/chat/webActionsPlugin.test.ts b/apps/nextjs/src/app/api/chat/webActionsPlugin.test.ts index a43a43b62..57386619c 100644 --- a/apps/nextjs/src/app/api/chat/webActionsPlugin.test.ts +++ b/apps/nextjs/src/app/api/chat/webActionsPlugin.test.ts @@ -149,7 +149,7 @@ describe("onStreamError", () => { type: "error", value: "Threat detected", message: - "I wasn't able to process your request because a potentially malicious input was detected.", + "I wasn’t able to process your request because a potentially malicious input was detected.", }); }); diff --git a/packages/aila/src/utils/moderation/moderationErrorHandling.ts b/packages/aila/src/utils/moderation/moderationErrorHandling.ts index 9cb9556f1..82266b3b9 100644 --- a/packages/aila/src/utils/moderation/moderationErrorHandling.ts +++ b/packages/aila/src/utils/moderation/moderationErrorHandling.ts @@ -56,6 +56,6 @@ export async function handleHeliconeError( type: "error", value: "Threat detected", message: - "I wasn't able to process your request because a potentially malicious input was detected.", + "I wasn’t able to process your request because a potentially malicious input was detected.", }; }