From eee2b1efff8aa71e59333b3d0780a826286c34d5 Mon Sep 17 00:00:00 2001
From: Adam Howard <91115+codeincontext@users.noreply.github.com>
Date: Fri, 30 Aug 2024 15:26:03 +0200
Subject: [PATCH 1/2] fix: replace Inngest app/generation.requested with an async worker function (#34)

---
 packages/api/src/router/generations.ts        |   8 +-
 packages/core/index.ts                        |   2 -
 packages/core/src/functions/event-types.ts    |   2 -
 .../functions/generation/requestGeneration.ts | 411 -----------------
 .../generations}/requestGeneration.schema.ts  |   0
 .../workers/generations/requestGeneration.ts  | 423 ++++++++++++++++++
 6 files changed, 428 insertions(+), 418 deletions(-)
 delete mode 100644 packages/core/src/functions/generation/requestGeneration.ts
 rename packages/core/src/{functions/generation => workers/generations}/requestGeneration.schema.ts (100%)
 create mode 100644 packages/core/src/workers/generations/requestGeneration.ts

diff --git a/packages/api/src/router/generations.ts b/packages/api/src/router/generations.ts
index dc8a7e1d4..5a415f3e8 100644
--- a/packages/api/src/router/generations.ts
+++ b/packages/api/src/router/generations.ts
@@ -1,4 +1,4 @@
-import { LessonSummaries, Lessons, Snippets, inngest } from "@oakai/core";
+import { LessonSummaries, Lessons, Snippets } from "@oakai/core";
 import { Feedback } from "@oakai/core/src/models/feedback";
 import { Generations } from "@oakai/core/src/models/generations";
 import { Prompts } from "@oakai/core/src/models/prompts";
@@ -12,9 +12,11 @@ import {
   generationPartUserTweakedSchema,
 } from "@oakai/core/src/types";
 import { sendQuizFeedbackEmail } from "@oakai/core/src/utils/sendQuizFeedbackEmail";
+import { requestGenerationWorker } from "@oakai/core/src/workers/generations/requestGeneration";
 import logger from "@oakai/logger";
 import { TRPCError } from "@trpc/server";
 import { Redis } from "@upstash/redis";
+import { waitUntil } from "@vercel/functions";
 import { uniq } from "remeda";
 import { z } from "zod";
 
@@ -189,8 +191,7 @@ export const generationRouter = router({
         ctx.auth.userId,
       );
 
-      await inngest.send({
-        name: "app/generation.requested",
+      const { pending } = requestGenerationWorker({
         data: {
           appId,
           promptId,
@@ -202,6 +203,7 @@ export const generationRouter = router({
           external_id: ctx.auth.userId,
         },
       });
+      waitUntil(pending);
 
       /**
        * Track if a generation is a re-generation
diff --git a/packages/core/index.ts b/packages/core/index.ts
index 49531e5d8..ef3c66798 100644
--- a/packages/core/index.ts
+++ b/packages/core/index.ts
@@ -1,5 +1,4 @@
 import { populateDemoStatuses } from "./src/functions/demo/populateDemoStatuses";
-import { requestGeneration } from "./src/functions/generation/requestGeneration";
 import { generateLessonQuizEmbeddings } from "./src/functions/lesson/generateLessonQuizEmbeddings";
 import { generatePlanForLesson } from "./src/functions/lesson/generatePlan";
 import { summariseLesson } from "./src/functions/lesson/summarise";
@@ -38,7 +37,6 @@ export * from "./src/models";
 //export * from "./src/models/promptVariants";
 
 export const functions = [
-  requestGeneration,
   generateTranscriptEmbeddings,
   generateSnippetEmbeddings,
   generateQuizQuestionEmbeddings,
diff --git a/packages/core/src/functions/event-types.ts b/packages/core/src/functions/event-types.ts
index a4e77c9af..c21295aed 100644
--- a/packages/core/src/functions/event-types.ts
+++ b/packages/core/src/functions/event-types.ts
@@ -2,7 +2,6 @@ import { ZodEventSchemas } from "inngest";
 import { z } from "zod";
 
 import { populateDemoStatusesSchema } from "./demo/populateDemoStatuses.schema";
"./demo/populateDemoStatuses.schema"; -import { requestGenerationSchema } from "./generation/requestGeneration.schema"; import { generateLessonQuizEmbeddingsSchema } from "./lesson/generateLessonQuizEmbeddings.schema"; import { generatePlanForLessonSchema } from "./lesson/generatePlan.schema"; import { summariseLessonSchema } from "./lesson/summarise.schema"; @@ -37,7 +36,6 @@ import { generateTranscriptEmbeddingsSchema } from "./transcript/generateTranscr const schemas = { "app/healthcheck": { data: z.any() }, - "app/generation.requested": requestGenerationSchema, "app/transcript.embed": generateTranscriptEmbeddingsSchema, "app/snippet.embed": embedSnippetSchema, "app/quizQuestion.embed": embedQuizQuestionSchema, diff --git a/packages/core/src/functions/generation/requestGeneration.ts b/packages/core/src/functions/generation/requestGeneration.ts deleted file mode 100644 index 68aeecd40..000000000 --- a/packages/core/src/functions/generation/requestGeneration.ts +++ /dev/null @@ -1,411 +0,0 @@ -import { GenerationStatus, ModerationType, Prisma, prisma } from "@oakai/db"; -import baseLogger from "@oakai/logger"; -import { Redis } from "@upstash/redis"; -import { NonRetriableError } from "inngest"; - -import { inngest } from "../../client"; -import { createOpenAIModerationsClient } from "../../llm/openai"; -import { SafetyViolations } from "../../models"; -import { Generations } from "../../models/generations"; -import { - CompletionResult, - Json, - LLMCompletionError, - LLMRefusalError, - Prompts, -} from "../../models/prompts"; -import { - checkEnglishLanguageScores, - doOpenAIModeration, - moderationConfig, -} from "../../utils/moderation"; -import { requestGenerationSchema } from "./requestGeneration.schema"; - -const openaiModeration = createOpenAIModerationsClient(); - -const redis = new Redis({ - url: process.env.KV_REST_API_URL as string, - token: process.env.KV_REST_API_TOKEN as string, -}); - -// NOTE: MODERATE_ENGLISH_LANGUAGE isn't in doppler envs anywhere -const MODERATE_LANGUAGE = Boolean(process.env.MODERATE_ENGLISH_LANGUAGE); - -export const requestGeneration = inngest.createFunction( - { - name: "Request AI generation", - id: "app-generation-request", - retries: 0, - onFailure: async ({ error, event, logger }) => { - const generations = new Generations(prisma, logger); - const eventData = event.data.event.data; - - logger.error( - { err: error, generationId: eventData?.generationId }, - "Failed generation (inngest onFailure called), generationId=%s", - eventData?.generationId, - ); - - /** - * Look up the generation, which might not even exist - * if we've ended up in onFailure. If it's in progress, - * attempt to mark it as failed (unless failing it was - * what caused us to get here in the first place!) - */ - const generationRecord = await generations.byId(eventData.generationId); - - if (generationRecord && !generations.isFinished(generationRecord)) { - try { - await generations.failGeneration( - eventData.generationId, - error.message, - ); - } catch (err) { - logger.error({ err }, "Could not mark generation as failed"); - } - } - }, - }, - { event: "app/generation.requested" }, - async ({ event }) => { - baseLogger.info( - `Requesting generation for promptId %s`, - event.data?.promptId ?? 
"Unknown prompt", - ); - baseLogger.debug({ eventData: event.data }, "Event data for generation"); - - /** - * --------------------- Input validation --------------------- - */ - - const eventData = requestGenerationSchema.data.safeParse(event.data); - const eventUser = requestGenerationSchema.user.safeParse(event.user); - - if (!eventData.success) { - throw new NonRetriableError("event.data failed validation", { - cause: eventData.error, - }); - } else if (!eventUser.success) { - throw new NonRetriableError("event.user failed validation", { - cause: eventUser.error, - }); - } - - const { appId, promptId, generationId, promptInputs, streamCompletion } = - eventData.data; - // const { external_id: userId } = eventUser.data; - - // Create a child logger which has the context of our current generation applied - const logger = baseLogger.child({ - appId, - promptId, - generationId, - }); - - const prompts = new Prompts(prisma, logger); - const generations = new Generations(prisma, logger); - const safetyViolations = new SafetyViolations(prisma, logger); - - logger.info("Running step: Check generation exists"); - const generationRecord = await generations.byId(generationId); - - if (!generationRecord) { - throw new NonRetriableError("Generation does not exist"); - } - - if (generationRecord.status !== GenerationStatus.REQUESTED) { - throw new NonRetriableError("Generation has already been processed"); - } - - await generations.setStatus(generationRecord.id, GenerationStatus.PENDING); - - baseLogger.info("Running step: Lookup prompt"); - const prompt = await prompts.get(promptId, appId); - - if (!prompt) { - throw new NonRetriableError("Prompt does not exist"); - } - - let promptBody: string; - - baseLogger.info("Running step: Format prompt template"); - - try { - promptBody = await prompts.formatPrompt(prompt.template, promptInputs); - - const promptInputsHash = - generations.generatePromptInputsHash(promptInputs); - await generations.update(generationId, { - promptText: promptBody, - promptInputs, - promptInputsHash, - }); - } catch (err) { - logger.error( - err, - "Error formatting or saving prompt template: %s", - err instanceof Error ? err.message : err, - ); - - throw new NonRetriableError( - "Error formatting or saving prompt template", - { - cause: err, - }, - ); - } - - /** - * --------------------- Begin moderation --------------------- - */ - if (moderationConfig.MODERATION_ENABLED === true) { - await generations.setStatus( - generationRecord.id, - GenerationStatus.MODERATING, - ); - - /** - * These are quiz-specific keys which we should ignore, as they're - * almost certainly going to fail checks (e.g. numbers) - * @TODO: Come up with a more robust solution for this - */ - const ignoredInputKeys = [ - "subject", - "ageRange", - "distractorToRegenerate", - "numberOfCorrectAnswers", - "numberOfDistractors", - "knowledge", - "transcript", - "fact", - "sessionId", - ]; - - const userInput = Object.entries(promptInputs) - .filter(([key, value]) => !!value && !ignoredInputKeys.includes(key)) - .map(([, value]) => JSON.stringify(value)); - - logger.info("Running step: Detect languages"); - const [englishScores, isEnglish] = checkEnglishLanguageScores(userInput); - - if (!isEnglish && MODERATE_LANGUAGE) { - logger.info("Running step: Save flagged generation (not english)"); - await generations.flagGeneration( - generationId, - "That looks like non-English text. 
-          ModerationType.NOT_ENGLISH,
-          {
-            moderationMeta: { englishScores },
-          },
-        );
-        // Return early and don't process generation further
-        return;
-      }
-
-      //logger.info("Running step: Detect profanity");
-      //const [isProfane, profanity] = detectProfanity(userInput);
-
-      // if (isProfane) {
-      //   logger.info("Running step: Save flagged generation (profanity)");
-      //   console.log(JSON.stringify(Object.keys(process.env)));
-      //   await generations.flagGeneration(
-      //     generationId,
-      //     `It looks like your input contains profanity ${
-      //       process.env.DOPPLER_ENVIRONMENT
-      //     } ${
-      //       process.env.DOPPLER_ENVIRONMENT === "dev" ? ` - ${profanity}` : "."
-      //     }`,
-      //     ModerationType.PROFANITY,
-      //     {
-      //       moderationMeta: { profanity },
-      //     },
-      //   );
-
-      //   // Return early and don't process generation further
-      //   return;
-      // }
-
-      logger.info("Running step: Moderate inputs");
-      const { moderationResults, isFlagged, isOverModerationThreshold } =
-        await doOpenAIModeration(openaiModeration, userInput);
-
-      if (isFlagged) {
-        logger.info(
-          "Running step: Save flagged generation (flagged by moderation)",
-        );
-        await generations.flagGeneration(
-          generationId,
-          "Inputs were flagged by OpenAI moderation as against their terms of service",
-          ModerationType.OPENAI_FLAGGED,
-          {
-            moderationMeta: {
-              moderationResults,
-            } as unknown as Prisma.InputJsonObject,
-          },
-        );
-        await safetyViolations.recordViolation(
-          eventUser.data.external_id,
-          "QUIZ_GENERATION",
-          "OPENAI",
-          "GENERATION",
-          generationId,
-        );
-        // Return early and don't process generation further
-        return;
-      } else {
-        logger.info("Running step: Save moderation result");
-        await generations.update(generationId, {
-          moderationMeta: {
-            moderationResults,
-          } as unknown as Prisma.InputJsonObject,
-        });
-      }
-
-      if (isOverModerationThreshold) {
-        logger.info(
-          "Running step: Save flagged generation (over moderation threshold)",
-        );
-        await generations.flagGeneration(
-          generationId,
-          "Inputs were flagged by our moderation filter as exceeding our threshold for one or more categories",
-          ModerationType.OPENAI_OVER_THRESHOLD,
-          {
-            moderationMeta: {
-              moderationResults,
-            } as unknown as Prisma.InputJsonObject,
-          },
-        );
-
-        // Return early and don't process generation further
-        return;
-      }
-    }
-
-    /**
-     * --------------------- Get prompt completion ---------------------
-     */
-    await generations.setStatus(
-      generationRecord.id,
-      GenerationStatus.GENERATING,
-    );
-
-    logger.info("Running step: Get OpenAI completion");
-    const promptRecord = await prompts.get(promptId, appId);
-
-    if (!promptRecord) {
-      throw new NonRetriableError("Prompt does not exist");
-    }
-
-    let completion: CompletionResult | undefined = undefined;
-    try {
-      logger.info(`Requesting completion for generationId=%s`, generationId);
-
-      /**
-       * Stream partial response JSON to redis as the new tokens come in,
-       * to work around streaming limitations with netlify
-       */
-      const onNewToken = async (partialJson: string) => {
-        console.log("onNewToken", partialJson);
-        try {
-          await redis.set(`partial-generation-${generationId}`, partialJson, {
-            // Expire after 10 minutes
-            ex: 60 * 10,
-          });
-        } catch (err) {
-          console.log("Failed to write to redis");
-          logger.error(err, "Error caching generation stream");
-        }
-      };
-
-      if (process.env.PROMPT_PLAYBACK_ENABLED === "true") {
-        const priorSuccessfulGeneration =
-          await generations.getPriorSuccessfulGeneration(
-            promptId,
-            promptInputs,
-          );
-        if (priorSuccessfulGeneration) {
-          const {
-            llmTimeTaken,
-            promptTokensUsed,
-            completionTokensUsed,
-            response,
-          } = priorSuccessfulGeneration;
-          if (response && typeof response === "object") {
-            logger.info("Request chat completion from prior generation");
-            completion = await prompts.requestChatCompletionFromPriorGeneration(
-              {
-                timeTaken: llmTimeTaken ?? 0,
-                promptTokensUsed,
-                completionTokensUsed,
-                resultText: JSON.stringify(response),
-                //TODO potentially casting here is dangerous
-                result: response as unknown as Json,
-              },
-            );
-          }
-        }
-      }
-
-      if (!completion) {
-        logger.info("Request chat completion from LLM");
-        completion = await prompts.requestChatCompletion(
-          promptBody,
-          streamCompletion ? onNewToken : undefined,
-        );
-      }
-    } catch (err) {
-      const errorMessage =
-        err instanceof Error ? err.message : `Unknown generation error`;
-
-      logger.error(err, errorMessage);
-
-      if (err instanceof LLMRefusalError) {
-        const { completionMeta } = err;
-
-        await generations.flagGeneration(
-          generationId,
-          errorMessage,
-          ModerationType.LLM_REFUSAL,
-          {
-            promptTokensUsed: completionMeta.promptTokensUsed,
-            completionTokensUsed: completionMeta.completionTokensUsed,
-            llmTimeTaken: completionMeta.timeTaken,
-          },
-        );
-
-        return null;
-      } else if (err instanceof LLMCompletionError) {
-        const { completionMeta } = err;
-
-        await generations.failGeneration(generationId, errorMessage, {
-          promptTokensUsed: completionMeta.promptTokensUsed,
-          completionTokensUsed: completionMeta.completionTokensUsed,
-          llmTimeTaken: completionMeta.timeTaken,
-        });
-        throw new NonRetriableError(errorMessage);
-      } else {
-        throw new NonRetriableError(errorMessage);
-      }
-    }
-
-    if (completion?.result) {
-      logger.info("Running step: Save successful generation");
-      await generations.completeGeneration(
-        generationId,
-        completion.result satisfies Prisma.InputJsonObject,
-        {
-          promptTokensUsed: completion.promptTokensUsed,
-          completionTokensUsed: completion.completionTokensUsed,
-          llmTimeTaken: completion.timeTaken,
-        },
-      );
-
-      logger.info(
-        `Successfully completed generation, generationId=%s`,
-        generationId,
-      );
-    }
-
-    return completion;
-  },
-);
diff --git a/packages/core/src/functions/generation/requestGeneration.schema.ts b/packages/core/src/workers/generations/requestGeneration.schema.ts
similarity index 100%
rename from packages/core/src/functions/generation/requestGeneration.schema.ts
rename to packages/core/src/workers/generations/requestGeneration.schema.ts
diff --git a/packages/core/src/workers/generations/requestGeneration.ts b/packages/core/src/workers/generations/requestGeneration.ts
new file mode 100644
index 000000000..8978f0ff5
--- /dev/null
+++ b/packages/core/src/workers/generations/requestGeneration.ts
@@ -0,0 +1,423 @@
+import { GenerationStatus, ModerationType, Prisma, prisma } from "@oakai/db";
+import baseLogger, { Logger } from "@oakai/logger";
+import { Redis } from "@upstash/redis";
+import { NonRetriableError } from "inngest";
+import { z } from "zod";
+
+import { createOpenAIModerationsClient } from "../../llm/openai";
+import { SafetyViolations } from "../../models";
+import { Generations } from "../../models/generations";
+import {
+  CompletionResult,
+  Json,
+  LLMCompletionError,
+  LLMRefusalError,
+  Prompts,
+} from "../../models/prompts";
+import {
+  checkEnglishLanguageScores,
+  doOpenAIModeration,
+  moderationConfig,
+} from "../../utils/moderation";
+import { requestGenerationSchema } from "./requestGeneration.schema";
+
+/**
+ * Worker converted from an Inngest function
+ */
+
+const openaiModeration = createOpenAIModerationsClient();
+
+const redis = new Redis({
+  url: process.env.KV_REST_API_URL as string,
+  token: process.env.KV_REST_API_TOKEN as string,
+});
+
+// NOTE: MODERATE_ENGLISH_LANGUAGE isn't in doppler envs anywhere
+const MODERATE_LANGUAGE = Boolean(process.env.MODERATE_ENGLISH_LANGUAGE);
+
+type RequestGenerationArgs = {
+  data: z.infer<(typeof requestGenerationSchema)["data"]>;
+  user: z.infer<(typeof requestGenerationSchema)["user"]>;
+};
+
+type WorkerResponse = {
+  /**
+   * Async workers perform work in the background.
+   * On Vercel Edge or Cloudflare workers, you need to explicitly handle the pending Promise like this:
+   *
+   * ```ts
+   * const { pending } = requestGenerationWorker({ ... });
+   * context.waitUntil(pending)
+   * ```
+   *
+   * See `waitUntil` documentation in
+   * [Cloudflare](https://developers.cloudflare.com/workers/runtime-apis/handlers/fetch/#contextwaituntil)
+   * and [Vercel](https://vercel.com/docs/functions/edge-middleware/middleware-api#waituntil)
+   * for more details.
+   */
+  pending: Promise<void>;
+};
+
+export function requestGenerationWorker({
+  data,
+  user,
+}: RequestGenerationArgs): WorkerResponse {
+  const promise = (async () => {
+    try {
+      await invoke({ data, user });
+    } catch (e) {
+      await onFailure({
+        error: e as Error,
+        event: { data: { event: { data } } },
+        logger: baseLogger,
+      });
+    }
+  })();
+
+  return { pending: promise };
+}
+
+async function invoke({ data, user }: RequestGenerationArgs) {
+  baseLogger.info(
+    `Requesting generation for promptId %s`,
+    data?.promptId ?? "Unknown prompt",
+  );
+  baseLogger.debug({ eventData: data }, "Event data for generation");
+
+  /**
+   * --------------------- Input validation ---------------------
+   */
+
+  const eventData = requestGenerationSchema.data.safeParse(data);
+  const eventUser = requestGenerationSchema.user.safeParse(user);
+
+  if (!eventData.success) {
+    throw new NonRetriableError("event.data failed validation", {
+      cause: eventData.error,
+    });
+  } else if (!eventUser.success) {
+    throw new NonRetriableError("event.user failed validation", {
+      cause: eventUser.error,
+    });
+  }
+
+  const { appId, promptId, generationId, promptInputs, streamCompletion } =
+    data;
+  // const { external_id: userId } = eventUser.data;
+
+  // Create a child logger which has the context of our current generation applied
+  const logger = baseLogger.child({
+    appId,
+    promptId,
+    generationId,
+  });
+
+  const prompts = new Prompts(prisma, logger);
+  const generations = new Generations(prisma, logger);
+  const safetyViolations = new SafetyViolations(prisma, logger);
+
+  logger.info("Running step: Check generation exists");
+  const generationRecord = await generations.byId(generationId);
+
+  if (!generationRecord) {
+    throw new NonRetriableError("Generation does not exist");
+  }
+
+  if (generationRecord.status !== GenerationStatus.REQUESTED) {
+    throw new NonRetriableError("Generation has already been processed");
+  }
+
+  await generations.setStatus(generationRecord.id, GenerationStatus.PENDING);
+
+  baseLogger.info("Running step: Lookup prompt");
+  const prompt = await prompts.get(promptId, appId);
+
+  if (!prompt) {
+    throw new NonRetriableError("Prompt does not exist");
+  }
+
+  let promptBody: string;
+
+  baseLogger.info("Running step: Format prompt template");
+
+  try {
+    promptBody = await prompts.formatPrompt(prompt.template, promptInputs);
+
+    const promptInputsHash = generations.generatePromptInputsHash(promptInputs);
+    await generations.update(generationId, {
+      promptText: promptBody,
+      promptInputs,
+      promptInputsHash,
+    });
+  } catch (err) {
+    logger.error(
+      err,
+      "Error formatting or saving prompt template: %s",
+      err instanceof Error ? err.message : err,
+    );
+
+    throw new NonRetriableError("Error formatting or saving prompt template", {
+      cause: err,
+    });
+  }
+
+  /**
+   * --------------------- Begin moderation ---------------------
+   */
+  if (moderationConfig.MODERATION_ENABLED === true) {
+    await generations.setStatus(
+      generationRecord.id,
+      GenerationStatus.MODERATING,
+    );
+
+    /**
+     * These are quiz-specific keys which we should ignore, as they're
+     * almost certainly going to fail checks (e.g. numbers)
+     * @TODO: Come up with a more robust solution for this
+     */
+    const ignoredInputKeys = [
+      "subject",
+      "ageRange",
+      "distractorToRegenerate",
+      "numberOfCorrectAnswers",
+      "numberOfDistractors",
+      "knowledge",
+      "transcript",
+      "fact",
+      "sessionId",
+    ];
+
+    const userInput = Object.entries(promptInputs)
+      .filter(([key, value]) => !!value && !ignoredInputKeys.includes(key))
+      .map(([, value]) => JSON.stringify(value));
+
+    logger.info("Running step: Detect languages");
+    const [englishScores, isEnglish] = checkEnglishLanguageScores(userInput);
+
+    if (!isEnglish && MODERATE_LANGUAGE) {
+      logger.info("Running step: Save flagged generation (not english)");
+      await generations.flagGeneration(
+        generationId,
+        "That looks like non-English text. We currently only support English.",
+        ModerationType.NOT_ENGLISH,
+        {
+          moderationMeta: { englishScores },
+        },
+      );
+      // Return early and don't process generation further
+      return;
+    }
+
+    logger.info("Running step: Moderate inputs");
+    const { moderationResults, isFlagged, isOverModerationThreshold } =
+      await doOpenAIModeration(openaiModeration, userInput);
+
+    if (isFlagged) {
+      logger.info(
+        "Running step: Save flagged generation (flagged by moderation)",
+      );
+      await generations.flagGeneration(
+        generationId,
+        "Inputs were flagged by OpenAI moderation as against their terms of service",
+        ModerationType.OPENAI_FLAGGED,
+        {
+          moderationMeta: {
+            moderationResults,
+          } as unknown as Prisma.InputJsonObject,
+        },
+      );
+      await safetyViolations.recordViolation(
+        eventUser.data.external_id,
+        "QUIZ_GENERATION",
+        "OPENAI",
+        "GENERATION",
+        generationId,
+      );
+      // Return early and don't process generation further
+      return;
+    } else {
+      logger.info("Running step: Save moderation result");
+      await generations.update(generationId, {
+        moderationMeta: {
+          moderationResults,
+        } as unknown as Prisma.InputJsonObject,
+      });
+    }
+
+    if (isOverModerationThreshold) {
+      logger.info(
+        "Running step: Save flagged generation (over moderation threshold)",
+      );
+      await generations.flagGeneration(
+        generationId,
+        "Inputs were flagged by our moderation filter as exceeding our threshold for one or more categories",
+        ModerationType.OPENAI_OVER_THRESHOLD,
+        {
+          moderationMeta: {
+            moderationResults,
+          } as unknown as Prisma.InputJsonObject,
+        },
+      );
+
+      // Return early and don't process generation further
+      return;
+    }
+  }
+
+  /**
+   * --------------------- Get prompt completion ---------------------
+   */
+  await generations.setStatus(generationRecord.id, GenerationStatus.GENERATING);
+
+  logger.info("Running step: Get OpenAI completion");
+  const promptRecord = await prompts.get(promptId, appId);
+
+  if (!promptRecord) {
+    throw new NonRetriableError("Prompt does not exist");
+  }
+
+  let completion: CompletionResult | undefined = undefined;
+  try {
+    logger.info(`Requesting completion for generationId=%s`, generationId);
+
+    /**
+     * Stream partial response JSON to redis as the new tokens come in,
+     * to work around streaming limitations with netlify
+     */
+    const onNewToken = async (partialJson: string) => {
+      console.log("onNewToken", partialJson);
+      try {
+        await redis.set(`partial-generation-${generationId}`, partialJson, {
+          // Expire after 10 minutes
+          ex: 60 * 10,
+        });
+      } catch (err) {
+        console.log("Failed to write to redis");
+        logger.error(err, "Error caching generation stream");
+      }
+    };
+
+    if (process.env.PROMPT_PLAYBACK_ENABLED === "true") {
+      const priorSuccessfulGeneration =
+        await generations.getPriorSuccessfulGeneration(promptId, promptInputs);
+      if (priorSuccessfulGeneration) {
+        const {
+          llmTimeTaken,
+          promptTokensUsed,
+          completionTokensUsed,
+          response,
+        } = priorSuccessfulGeneration;
+        if (response && typeof response === "object") {
+          logger.info("Request chat completion from prior generation");
+          completion = await prompts.requestChatCompletionFromPriorGeneration({
+            timeTaken: llmTimeTaken ?? 0,
+            promptTokensUsed,
+            completionTokensUsed,
+            resultText: JSON.stringify(response),
+            //TODO potentially casting here is dangerous
+            result: response as unknown as Json,
+          });
+        }
+      }
+    }
+
+    if (!completion) {
+      logger.info("Request chat completion from LLM");
+      completion = await prompts.requestChatCompletion(
+        promptBody,
+        streamCompletion ? onNewToken : undefined,
+      );
+    }
+  } catch (err) {
+    const errorMessage =
+      err instanceof Error ? err.message : `Unknown generation error`;
+
+    logger.error(err, errorMessage);
+
+    if (err instanceof LLMRefusalError) {
+      const { completionMeta } = err;
+
+      await generations.flagGeneration(
+        generationId,
+        errorMessage,
+        ModerationType.LLM_REFUSAL,
+        {
+          promptTokensUsed: completionMeta.promptTokensUsed,
+          completionTokensUsed: completionMeta.completionTokensUsed,
+          llmTimeTaken: completionMeta.timeTaken,
+        },
+      );
+
+      return null;
+    } else if (err instanceof LLMCompletionError) {
+      const { completionMeta } = err;
+
+      await generations.failGeneration(generationId, errorMessage, {
+        promptTokensUsed: completionMeta.promptTokensUsed,
+        completionTokensUsed: completionMeta.completionTokensUsed,
+        llmTimeTaken: completionMeta.timeTaken,
+      });
+      throw new NonRetriableError(errorMessage);
+    } else {
+      throw new NonRetriableError(errorMessage);
+    }
+  }
+
+  if (completion?.result) {
+    logger.info("Running step: Save successful generation");
+    await generations.completeGeneration(
+      generationId,
+      completion.result satisfies Prisma.InputJsonObject,
+      {
+        promptTokensUsed: completion.promptTokensUsed,
+        completionTokensUsed: completion.completionTokensUsed,
+        llmTimeTaken: completion.timeTaken,
+      },
+    );
+
+    logger.info(
+      `Successfully completed generation, generationId=%s`,
+      generationId,
+    );
+  }
+
+  return completion;
+}
+
+type OnFailureArgs = {
+  error: Error;
+  event: {
+    data: {
+      event: { data: z.infer<(typeof requestGenerationSchema)["data"]> };
+    };
+  };
+  logger: Logger;
+};
+
+async function onFailure({ error, event, logger }: OnFailureArgs) {
+  const generations = new Generations(prisma, logger);
+  const eventData = event.data.event.data;
+
+  logger.error(
+    { err: error, generationId: eventData?.generationId },
+    "Failed generation (inngest onFailure called), generationId=%s",
+    eventData?.generationId,
+  );
+
+  /**
+   * Look up the generation, which might not even exist
+   * if we've ended up in onFailure. If it's in progress,
+   * attempt to mark it as failed (unless failing it was
+   * what caused us to get here in the first place!)
+   */
+  const generationRecord = await generations.byId(eventData.generationId);
+
+  if (generationRecord && !generations.isFinished(generationRecord)) {
+    try {
+      await generations.failGeneration(eventData.generationId, error.message);
+    } catch (err) {
+      logger.error({ err }, "Could not mark generation as failed");
+    }
+  }
+}

From 82fe63b9fdbaee4270757863c38b3cee8f8f1add Mon Sep 17 00:00:00 2001
From: semantic-release-bot
Date: Fri, 30 Aug 2024 13:26:48 +0000
Subject: [PATCH 2/2] build(release v1.2.2): See CHANGE_LOG.md

---
 CHANGE_LOG.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md
index bf9f4b5cb..3702420c8 100644
--- a/CHANGE_LOG.md
+++ b/CHANGE_LOG.md
@@ -1,3 +1,10 @@
+## [1.2.2](https://github.com/oaknational/oak-ai-lesson-assistant/compare/v1.2.1...v1.2.2) (2024-08-30)
+
+
+### Bug Fixes
+
+* replace Inngest app/generation.requested with an async worker function ([#34](https://github.com/oaknational/oak-ai-lesson-assistant/issues/34)) ([eee2b1e](https://github.com/oaknational/oak-ai-lesson-assistant/commit/eee2b1efff8aa71e59333b3d0780a826286c34d5))
+
 ## [1.2.1](https://github.com/oaknational/oak-ai-lesson-assistant/compare/v1.2.0...v1.2.1) (2024-08-29)