Skip to content

Commit

Permalink
Throw all expected errors up to the route catch handler
Browse files Browse the repository at this point in the history
  • Loading branch information
codeincontext committed Sep 10, 2024
1 parent e001c7d commit 90a4964
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 178 deletions.
86 changes: 4 additions & 82 deletions apps/nextjs/src/app/api/chat/chatHandler.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import {
Aila,
AilaAuthenticationError,
AilaThreatDetectionError,
} from "@oakai/aila";
import { Aila } from "@oakai/aila";
import type { AilaOptions, AilaPublicChatOptions, Message } from "@oakai/aila";
import { LooseLessonPlan } from "@oakai/aila/src/protocol/schema";
import { handleHeliconeError } from "@oakai/aila/src/utils/moderation/moderationErrorHandling";
import {
TracingSpan,
withTelemetry,
Expand All @@ -16,7 +11,8 @@ import { NextRequest } from "next/server";
import invariant from "tiny-invariant";

import { Config } from "./config";
import { streamingJSON } from "./protocol";
import { handleChatException } from "./errorHandling";
import { fetchAndCheckUser } from "./user";

export const maxDuration = 300;

Expand Down Expand Up @@ -61,25 +57,6 @@ async function setupChatHandler(req: NextRequest) {
);
}

function reportErrorTelemetry(
span: TracingSpan,
error: Error,
errorType: string,
statusMessage: string,
additionalAttributes: Record<
string,
string | number | boolean | undefined
> = {},
) {
span.setTag("error", true);
span.setTag("error.type", errorType);
span.setTag("error.message", statusMessage);
span.setTag("error.stack", error.stack);
Object.entries(additionalAttributes).forEach(([key, value]) => {
span.setTag(key, value);
});
}

function setTelemetryMetadata(
span: TracingSpan,
id: string,
Expand Down Expand Up @@ -110,35 +87,6 @@ function handleConnectionAborted(req: NextRequest) {
return abortController;
}

async function handleThreatDetectionError(
span: TracingSpan,
e: AilaThreatDetectionError,
userId: string,
id: string,
prisma: PrismaClientWithAccelerate,
) {
const heliconeErrorMessage = await handleHeliconeError(userId, id, e, prisma);
reportErrorTelemetry(span, e, "AilaThreatDetectionError", "Threat detected");
return streamingJSON(heliconeErrorMessage);
}

async function handleAilaAuthenticationError(
span: TracingSpan,
e: AilaAuthenticationError,
) {
reportErrorTelemetry(span, e, "AilaAuthenticationError", "Unauthorized");
return new Response("Unauthorized", { status: 401 });
}

async function handleGenericError(span: TracingSpan, e: Error) {
reportErrorTelemetry(span, e, e.name, e.message);
return streamingJSON({
type: "error",
message: e.message,
value: `Sorry, an error occurred: ${e.message}`,
});
}

async function generateChatStream(
aila: Aila,
abortController: AbortController,
Expand All @@ -154,28 +102,6 @@ async function generateChatStream(
);
}

async function handleChatException(
span: TracingSpan,
e: unknown,
userId: string | undefined,
chatId: string,
prisma: PrismaClientWithAccelerate,
): Promise<Response> {
if (e instanceof AilaAuthenticationError) {
return handleAilaAuthenticationError(span, e);
}

if (e instanceof AilaThreatDetectionError && userId) {
return handleThreatDetectionError(span, e, userId, chatId, prisma);
}

if (e instanceof Error) {
return handleGenericError(span, e);
}

throw e;
}

export async function handleChatPostRequest(
req: NextRequest,
config: Config,
Expand All @@ -190,11 +116,7 @@ export async function handleChatPostRequest(
let aila: Aila | undefined;

try {
const userLookup = await config.handleUserLookup(chatId);
// The user lookup can either return a userId or a response like a streaming protocol message
if ("failureResponse" in userLookup) {
return userLookup.failureResponse;
}
userId = await fetchAndCheckUser(chatId);

span.setTag("user_id", userId);
aila = await withTelemetry(
Expand Down
13 changes: 0 additions & 13 deletions apps/nextjs/src/app/api/chat/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,14 @@ import {
} from "@oakai/db";
import { nanoid } from "ai";

import { handleUserLookup as defaultHandleUserLookup } from "./user";
import { createWebActionsPlugin } from "./webActionsPlugin";

export interface Config {
shouldPerformUserLookup: boolean;
mockUserId?: string;
handleUserLookup: (chatId: string) => Promise<
| {
userId: string;
}
| {
failureResponse: Response;
}
>;
prisma: PrismaClientWithAccelerate;
createAila: (options: Partial<AilaInitializationOptions>) => Promise<Aila>;
}

export const defaultConfig: Config = {
shouldPerformUserLookup: true,
handleUserLookup: defaultHandleUserLookup,
prisma: globalPrisma,
createAila: async (options) => {
const webActionsPlugin = createWebActionsPlugin(globalPrisma);
Expand Down
122 changes: 122 additions & 0 deletions apps/nextjs/src/app/api/chat/errorHandling.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import { AilaAuthenticationError, AilaThreatDetectionError } from "@oakai/aila";
import * as moderationErrorHandling from "@oakai/aila/src/utils/moderation/moderationErrorHandling";
import { UserBannedError } from "@oakai/core/src/models/safetyViolations";
import { TracingSpan } from "@oakai/core/src/tracing/serverTracing";
import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter";
import { PrismaClientWithAccelerate } from "@oakai/db";

import { consumeStream } from "@/utils/testHelpers/consumeStream";

import { handleChatException } from "./errorHandling";

describe("handleChatException", () => {
describe("AilaThreatDetectionError", () => {
it("should forward the message from handleHeliconeError", async () => {
jest
.spyOn(moderationErrorHandling, "handleHeliconeError")
.mockResolvedValue({
type: "error",
value: "Threat detected",
message: "Threat was detected",
});

const span = { setTag: jest.fn() } as unknown as TracingSpan;
const error = new AilaThreatDetectionError("test error");
const prisma = {} as unknown as PrismaClientWithAccelerate;

const response = await handleChatException(
span,
error,
"test-user-id",
"test-chat-id",
prisma,
);

expect(response.status).toBe(200);

const message = JSON.parse(
await consumeStream(response.body as ReadableStream),
);
expect(message).toEqual({
type: "error",
value: "Threat detected",
message: "Threat was detected",
});
});
});

describe("AilaAuthenticationError", () => {
it("should return an error chat message", async () => {
const span = { setTag: jest.fn() } as unknown as TracingSpan;
const error = new AilaAuthenticationError("test error");
const prisma = {} as unknown as PrismaClientWithAccelerate;

const response = await handleChatException(
span,
error,
"test-user-id",
"test-chat-id",
prisma,
);

expect(response.status).toBe(401);

const message = await consumeStream(response.body as ReadableStream);
expect(message).toEqual("Unauthorized");
});
});

describe("RateLimitExceededError", () => {
it("should return an error chat message", async () => {
const span = { setTag: jest.fn() } as unknown as TracingSpan;
const error = new RateLimitExceededError(100, Date.now() + 3600 * 1000);
const prisma = {} as unknown as PrismaClientWithAccelerate;

const response = await handleChatException(
span,
error,
"test-user-id",
"test-chat-id",
prisma,
);

expect(response.status).toBe(200);

const message = JSON.parse(
await consumeStream(response.body as ReadableStream),
);
expect(message).toEqual({
type: "error",
value: "Rate limit exceeded",
message:
"**Unfortunately you’ve exceeded your fair usage limit for today.** Please come back in 1 hour. If you require a higher limit, please [make a request](https://forms.gle/tHsYMZJR367zydsG8).",
});
});
});

describe("UserBannedError", () => {
it("should return an error chat message", async () => {
const span = { setTag: jest.fn() } as unknown as TracingSpan;
const error = new UserBannedError("test error");
const prisma = {} as unknown as PrismaClientWithAccelerate;

const response = await handleChatException(
span,
error,
"test-user-id",
"test-chat-id",
prisma,
);

expect(response.status).toBe(200);

const message = JSON.parse(
await consumeStream(response.body as ReadableStream),
);
expect(message).toEqual({
type: "action",
action: "SHOW_ACCOUNT_LOCKED",
});
});
});
});
115 changes: 115 additions & 0 deletions apps/nextjs/src/app/api/chat/errorHandling.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import { AilaAuthenticationError, AilaThreatDetectionError } from "@oakai/aila";
import {
ActionDocument,
ErrorDocument,
} from "@oakai/aila/src/protocol/jsonPatchProtocol";
import { handleHeliconeError } from "@oakai/aila/src/utils/moderation/moderationErrorHandling";
import { UserBannedError } from "@oakai/core/src/models/safetyViolations";
import { TracingSpan } from "@oakai/core/src/tracing/serverTracing";
import { RateLimitExceededError } from "@oakai/core/src/utils/rateLimiting/userBasedRateLimiter";
import { PrismaClientWithAccelerate } from "@oakai/db";

import { streamingJSON } from "./protocol";

function reportErrorTelemetry(
span: TracingSpan,
error: Error,
errorType: string,
statusMessage: string,
additionalAttributes: Record<
string,
string | number | boolean | undefined
> = {},
) {
span.setTag("error", true);
span.setTag("error.type", errorType);
span.setTag("error.message", statusMessage);
span.setTag("error.stack", error.stack);
Object.entries(additionalAttributes).forEach(([key, value]) => {
span.setTag(key, value);
});
}

async function handleThreatDetectionError(
span: TracingSpan,
e: AilaThreatDetectionError,
userId: string,
id: string,
prisma: PrismaClientWithAccelerate,
) {
const heliconeErrorMessage = await handleHeliconeError(userId, id, e, prisma);
reportErrorTelemetry(span, e, "AilaThreatDetectionError", "Threat detected");
return streamingJSON(heliconeErrorMessage);
}

async function handleAilaAuthenticationError(
span: TracingSpan,
e: AilaAuthenticationError,
) {
reportErrorTelemetry(span, e, "AilaAuthenticationError", "Unauthorized");
return new Response("Unauthorized", { status: 401 });
}

export async function handleRateLimitError(
span: TracingSpan,
error: RateLimitExceededError,
) {
reportErrorTelemetry(span, error, "RateLimitExceededError", "Rate limited");

const timeRemainingHours = Math.ceil(
(error.reset - Date.now()) / 1000 / 60 / 60,
);
const hours = timeRemainingHours === 1 ? "hour" : "hours";

return streamingJSON({
type: "error",
value: error.message,
message: `**Unfortunately you’ve exceeded your fair usage limit for today.** Please come back in ${timeRemainingHours} ${hours}. If you require a higher limit, please [make a request](${process.env.RATELIMIT_FORM_URL}).`,
} as ErrorDocument);
}

async function handleUserBannedError() {
return streamingJSON({
type: "action",
action: "SHOW_ACCOUNT_LOCKED",
} as ActionDocument);
}

async function handleGenericError(span: TracingSpan, e: Error) {
reportErrorTelemetry(span, e, e.name, e.message);
return streamingJSON({
type: "error",
message: e.message,
value: `Sorry, an error occurred: ${e.message}`,
} as ErrorDocument);
}

export async function handleChatException(
span: TracingSpan,
e: unknown,
userId: string | undefined,
chatId: string,
prisma: PrismaClientWithAccelerate,
): Promise<Response> {
if (e instanceof AilaAuthenticationError) {
return handleAilaAuthenticationError(span, e);
}

if (e instanceof AilaThreatDetectionError && userId) {
return handleThreatDetectionError(span, e, userId, chatId, prisma);
}

if (e instanceof RateLimitExceededError && userId) {
return handleRateLimitError(span, e);
}

if (e instanceof UserBannedError) {
return handleUserBannedError();
}

if (e instanceof Error) {
return handleGenericError(span, e);
}

throw e;
}
Loading

0 comments on commit 90a4964

Please sign in to comment.