From 5c5f1b3e18927eda694cade0c0934845e842be1e Mon Sep 17 00:00:00 2001 From: Stef Lewandowski Date: Wed, 28 Aug 2024 20:19:02 +0100 Subject: [PATCH] feat: aila categoriser feature with chat ID and user ID (#12) --- packages/aila/src/core/Aila.test.ts | 34 +++++++++ packages/aila/src/core/Aila.ts | 60 ++++++---------- packages/aila/src/core/AilaServices.ts | 1 + packages/aila/src/core/lesson/AilaLesson.ts | 41 ++++++++++- .../builders/AilaLessonPromptBuilder.ts | 4 +- packages/aila/src/core/types.ts | 4 ++ .../categorisers/AilaCategorisation.ts | 70 +++++++++++++++++++ .../categorisers/MockCategoriser.ts | 16 +++++ .../aila/src/features/categorisation/index.ts | 1 + packages/aila/src/features/rag/AilaRag.ts | 9 ++- packages/aila/src/features/types.ts | 7 ++ .../utils/lessonPlan/fetchCategorisedInput.ts | 42 ----------- .../aila/src/utils/rag/fetchRagContent.ts | 6 +- packages/core/src/models/lessonPlans.ts | 2 +- packages/core/src/models/lessonSummaries.ts | 2 +- packages/core/src/rag/index.ts | 21 ++++-- 16 files changed, 225 insertions(+), 95 deletions(-) create mode 100644 packages/aila/src/features/categorisation/categorisers/AilaCategorisation.ts create mode 100644 packages/aila/src/features/categorisation/categorisers/MockCategoriser.ts create mode 100644 packages/aila/src/features/categorisation/index.ts delete mode 100644 packages/aila/src/utils/lessonPlan/fetchCategorisedInput.ts diff --git a/packages/aila/src/core/Aila.test.ts b/packages/aila/src/core/Aila.test.ts index 801bd7b2d..1d9bd8cea 100644 --- a/packages/aila/src/core/Aila.test.ts +++ b/packages/aila/src/core/Aila.test.ts @@ -1,6 +1,7 @@ import { Aila } from "."; import { MockLLMService } from "../../tests/mocks/MockLLMService"; import { setupPolly } from "../../tests/mocks/setupPolly"; +import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser"; import { AilaAuthenticationError } from "./AilaError"; describe("Aila", () => { @@ -285,4 +286,37 @@ describe("Aila", () => { expect(ailaInstance.lesson.plan.title).toBe(newTitle); }, 20000); }); + + describe("categorisation", () => { + it("should use the provided MockCategoriser", async () => { + const mockedLessonPlan = { + title: "Mocked Lesson Plan", + subject: "Mocked Subject", + keyStage: "key-stage-3", + }; + + const mockCategoriser = new MockCategoriser({ mockedLessonPlan }); + + const ailaInstance = new Aila({ + lessonPlan: {}, + chat: { id: "123", userId: "user123" }, + options: { + usePersistence: false, + useRag: false, + useAnalytics: false, + useModeration: false, + }, + services: { + chatCategoriser: mockCategoriser, + }, + plugins: [], + }); + + await ailaInstance.initialise(); + + expect(ailaInstance.lesson.plan.title).toBe("Mocked Lesson Plan"); + expect(ailaInstance.lesson.plan.subject).toBe("Mocked Subject"); + expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-3"); + }); + }); }); diff --git a/packages/aila/src/core/Aila.ts b/packages/aila/src/core/Aila.ts index 9de7541c3..9c51cba16 100644 --- a/packages/aila/src/core/Aila.ts +++ b/packages/aila/src/core/Aila.ts @@ -6,6 +6,7 @@ import { DEFAULT_TEMPERATURE, DEFAULT_RAG_LESSON_PLANS, } from "../constants"; +import { AilaCategorisation } from "../features/categorisation"; import { AilaAnalyticsFeature, AilaErrorReportingFeature, @@ -13,7 +14,6 @@ import { AilaPersistenceFeature, AilaThreatDetectionFeature, } from "../features/types"; -import { fetchCategorisedInput } from "../utils/lessonPlan/fetchCategorisedInput"; import { AilaAuthenticationError, AilaGenerationError } from "./AilaError"; import { AilaFeatureFactory } from "./AilaFeatureFactory"; import { @@ -43,8 +43,12 @@ export class Aila implements AilaServices { private _threatDetection?: AilaThreatDetectionFeature; private _prisma: PrismaClientWithAccelerate; private _plugins: AilaPlugin[]; + private _userId!: string | undefined; + private _chatId!: string; constructor(options: AilaInitializationOptions) { + this._userId = options.chat.userId; + this._chatId = options.chat.id; this._options = this.initialiseOptions(options.options); this._chat = new AilaChat({ @@ -53,9 +57,21 @@ export class Aila implements AilaServices { promptBuilder: options.promptBuilder, }); - this._lesson = new AilaLesson({ lessonPlan: options.lessonPlan ?? {} }); this._prisma = options.prisma ?? globalPrisma; + this._lesson = new AilaLesson({ + aila: this, + lessonPlan: options.lessonPlan ?? {}, + categoriser: + options.services?.chatCategoriser ?? + new AilaCategorisation({ + aila: this, + prisma: this._prisma, + chatId: this._chatId, + userId: this._userId, + }), + }); + this._analytics = AilaFeatureFactory.createAnalytics(this, this._options); this._moderation = AilaFeatureFactory.createModeration(this, this._options); this._persistence = AilaFeatureFactory.createPersistence( @@ -81,7 +97,7 @@ export class Aila implements AilaServices { // Initialization methods public async initialise() { this.checkUserIdPresentIfPersisting(); - await this.setUpInitialLessonPlan(); + await this._lesson.setUpInitialLessonPlan(this._chat.messages); } private initialiseOptions(options?: AilaOptions) { @@ -128,11 +144,11 @@ export class Aila implements AilaServices { } public get chatId() { - return this._chat.id; + return this._chatId; } public get userId() { - return this._chat.userId; + return this._userId; } public get messages() { @@ -168,40 +184,6 @@ export class Aila implements AilaServices { } } - // Setup methods - - // #TODO this is in the wrong place and should be - // moved to be hook into the initialisation of the lesson - // or chat - public async setUpInitialLessonPlan() { - const shouldRequestInitialState = Boolean( - !this.lesson.plan.subject && - !this.lesson.plan.keyStage && - !this.lesson.plan.title, - ); - - if (shouldRequestInitialState) { - const { title, subject, keyStage, topic } = this.lesson.plan; - const input = this.chat.messages.map((i) => i.content).join("\n\n"); - const categorisationInput = [title, subject, keyStage, topic, input] - .filter((i) => i) - .join(" "); - - const result = await fetchCategorisedInput({ - input: categorisationInput, - prisma: this._prisma, - chatMeta: { - userId: this._chat.userId, - chatId: this._chat.id, - }, - }); - - if (result) { - this.lesson.initialise(result); - } - } - } - // Generation methods public async generateSync(opts: AilaGenerateLessonPlanOptions) { const stream = await this.generate(opts); diff --git a/packages/aila/src/core/AilaServices.ts b/packages/aila/src/core/AilaServices.ts index a954c1214..cb79a06d2 100644 --- a/packages/aila/src/core/AilaServices.ts +++ b/packages/aila/src/core/AilaServices.ts @@ -23,6 +23,7 @@ export interface AilaLessonService { readonly hasSetInitialState: boolean; applyPatches(patches: string): void; initialise(plan: LooseLessonPlan): void; + setUpInitialLessonPlan(messages: Message[]): Promise; } export interface AilaChatService { diff --git a/packages/aila/src/core/lesson/AilaLesson.ts b/packages/aila/src/core/lesson/AilaLesson.ts index cfa4004c2..a64e8bcc0 100644 --- a/packages/aila/src/core/lesson/AilaLesson.ts +++ b/packages/aila/src/core/lesson/AilaLesson.ts @@ -1,21 +1,42 @@ import { deepClone } from "fast-json-patch"; +import { AilaCategorisation } from "../../features/categorisation/categorisers/AilaCategorisation"; +import { AilaCategorisationFeature } from "../../features/types"; import { PatchDocument, applyLessonPlanPatch, extractPatches, } from "../../protocol/jsonPatchProtocol"; import { LooseLessonPlan } from "../../protocol/schema"; -import { AilaLessonService } from "../AilaServices"; +import { AilaLessonService, AilaServices } from "../AilaServices"; +import { Message } from "../chat"; export class AilaLesson implements AilaLessonService { + private _aila: AilaServices; private _plan: LooseLessonPlan; private _hasSetInitialState = false; private _appliedPatches: PatchDocument[] = []; private _invalidPatches: PatchDocument[] = []; + private _categoriser: AilaCategorisationFeature; - constructor({ lessonPlan }: { lessonPlan?: LooseLessonPlan }) { + constructor({ + aila, + lessonPlan, + categoriser, + }: { + aila: AilaServices; + lessonPlan?: LooseLessonPlan; + categoriser?: AilaCategorisationFeature; + }) { + this._aila = aila; this._plan = lessonPlan ?? {}; + this._categoriser = + categoriser ?? + new AilaCategorisation({ + aila, + userId: aila.userId, + chatId: aila.chatId, + }); } public get plan(): LooseLessonPlan { @@ -74,4 +95,20 @@ export class AilaLesson implements AilaLessonService { this._plan = workingLessonPlan; } + + public async setUpInitialLessonPlan(messages: Message[]) { + const shouldCategoriseBasedOnInitialMessages = Boolean( + !this._plan.subject && !this._plan.keyStage && !this._plan.title, + ); + + // The initial lesson plan is blank, so we take the first messages + // and attempt to deduce the lesson plan key stage, subject, title and topic + if (shouldCategoriseBasedOnInitialMessages) { + const result = await this._categoriser.categorise(messages, this._plan); + + if (result) { + this.initialise(result); + } + } + } } diff --git a/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts b/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts index 371b51116..ef13b0ae0 100644 --- a/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts +++ b/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts @@ -42,7 +42,7 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder { private async fetchRelevantLessonPlans(): Promise { const noRelevantLessonPlans = "None"; - const chatId = this._aila?.chatId; + const { chatId, userId } = this._aila; if (!this._aila?.options.useRag) { return noRelevantLessonPlans; } @@ -63,6 +63,8 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder { this._aila?.options.numberOfLessonPlansInRag ?? DEFAULT_RAG_LESSON_PLANS, prisma: globalPrisma, + chatId, + userId, }); }, "Did not fetch RAG content. Continuing"); diff --git a/packages/aila/src/core/types.ts b/packages/aila/src/core/types.ts index aa7081da1..5f6ed42d5 100644 --- a/packages/aila/src/core/types.ts +++ b/packages/aila/src/core/types.ts @@ -5,6 +5,7 @@ import { AilaPersistence } from "../features/persistence"; import { AilaThreatDetector } from "../features/threatDetection"; import { AilaAnalyticsFeature, + AilaCategorisationFeature, AilaErrorReportingFeature, AilaModerationFeature, AilaThreatDetectionFeature, @@ -67,4 +68,7 @@ export type AilaInitializationOptions = { errorReporter?: AilaErrorReportingFeature; promptBuilder?: AilaPromptBuilder; plugins: AilaPlugin[]; + services?: { + chatCategoriser?: AilaCategorisationFeature; + }; }; diff --git a/packages/aila/src/features/categorisation/categorisers/AilaCategorisation.ts b/packages/aila/src/features/categorisation/categorisers/AilaCategorisation.ts new file mode 100644 index 000000000..4b098c2ce --- /dev/null +++ b/packages/aila/src/features/categorisation/categorisers/AilaCategorisation.ts @@ -0,0 +1,70 @@ +import { RAG } from "@oakai/core/src/rag"; +import { + type PrismaClientWithAccelerate, + prisma as globalPrisma, +} from "@oakai/db"; + +import { AilaServices, Message } from "../../../core"; +import { LooseLessonPlan } from "../../../protocol/schema"; +import { AilaCategorisationFeature } from "../../types"; + +export class AilaCategorisation implements AilaCategorisationFeature { + private _aila: AilaServices; + private _prisma: PrismaClientWithAccelerate; + private _chatId: string; + private _userId: string | undefined; + constructor({ + aila, + prisma, + chatId, + userId, + }: { + aila: AilaServices; + prisma?: PrismaClientWithAccelerate; + chatId: string; + userId?: string; + }) { + this._aila = aila; + this._prisma = prisma ?? globalPrisma; + this._chatId = chatId; + this._userId = userId; + } + public async categorise( + messages: Message[], + lessonPlan: LooseLessonPlan, + ): Promise { + const { title, subject, keyStage, topic } = lessonPlan; + const input = messages.map((i) => i.content).join("\n\n"); + const categorisationInput = [title, subject, keyStage, topic, input] + .filter((i) => i) + .join(" "); + + const result = await this.fetchCategorisedInput( + categorisationInput, + this._prisma, + ); + return result; + } + + private async fetchCategorisedInput( + input: string, + prisma: PrismaClientWithAccelerate, + ): Promise { + const rag = new RAG(prisma, { + chatId: this._chatId, + userId: this._userId, + }); + const parsedCategorisation = await rag.categoriseKeyStageAndSubject(input, { + chatId: this._chatId, + userId: this._userId, + }); + const { keyStage, subject, title, topic } = parsedCategorisation; + const plan: LooseLessonPlan = { + keyStage: keyStage ?? undefined, + subject: subject ?? undefined, + title: title ?? undefined, + topic: topic ?? undefined, + }; + return plan; + } +} diff --git a/packages/aila/src/features/categorisation/categorisers/MockCategoriser.ts b/packages/aila/src/features/categorisation/categorisers/MockCategoriser.ts new file mode 100644 index 000000000..ab51ba369 --- /dev/null +++ b/packages/aila/src/features/categorisation/categorisers/MockCategoriser.ts @@ -0,0 +1,16 @@ +import { LooseLessonPlan } from "../../../protocol/schema"; +import { AilaCategorisationFeature } from "../../types"; + +export class MockCategoriser implements AilaCategorisationFeature { + private _mockedLessonPlan: LooseLessonPlan | undefined; + constructor({ + mockedLessonPlan, + }: { + mockedLessonPlan: LooseLessonPlan | undefined; + }) { + this._mockedLessonPlan = mockedLessonPlan; + } + public async categorise(): Promise { + return this._mockedLessonPlan; + } +} diff --git a/packages/aila/src/features/categorisation/index.ts b/packages/aila/src/features/categorisation/index.ts new file mode 100644 index 000000000..8dc53a87a --- /dev/null +++ b/packages/aila/src/features/categorisation/index.ts @@ -0,0 +1 @@ +export { AilaCategorisation } from "./categorisers/AilaCategorisation"; diff --git a/packages/aila/src/features/rag/AilaRag.ts b/packages/aila/src/features/rag/AilaRag.ts index 5d64b0fea..0756e913e 100644 --- a/packages/aila/src/features/rag/AilaRag.ts +++ b/packages/aila/src/features/rag/AilaRag.ts @@ -8,7 +8,7 @@ import { LooseLessonPlan } from "../../protocol/schema"; import { minifyLessonPlanForRelevantLessons } from "../../utils/lessonPlan/minifyLessonPlanForRelevantLessons"; export class AilaRag { - private _aila?: AilaServices; + private _aila: AilaServices; private _rag: RAG; private _prisma: PrismaClientWithAccelerate; @@ -16,12 +16,15 @@ export class AilaRag { aila, prisma, }: { - aila?: AilaServices; + aila: AilaServices; prisma?: PrismaClientWithAccelerate; }) { this._aila = aila; this._prisma = prisma ?? globalPrisma; - this._rag = new RAG(this._prisma); + this._rag = new RAG(this._prisma, { + userId: aila.userId, + chatId: aila.chatId, + }); } public async fetchRagContent({ diff --git a/packages/aila/src/features/types.ts b/packages/aila/src/features/types.ts index de1876e56..18b08d101 100644 --- a/packages/aila/src/features/types.ts +++ b/packages/aila/src/features/types.ts @@ -53,3 +53,10 @@ export interface AilaErrorReportingFeature { breadcrumbs?: { category: string; message: string }, ): T | null; } + +export interface AilaCategorisationFeature { + categorise( + messages: Message[], + lessonPlan: LooseLessonPlan, + ): Promise; +} diff --git a/packages/aila/src/utils/lessonPlan/fetchCategorisedInput.ts b/packages/aila/src/utils/lessonPlan/fetchCategorisedInput.ts deleted file mode 100644 index 65dd97901..000000000 --- a/packages/aila/src/utils/lessonPlan/fetchCategorisedInput.ts +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Fetches the categorised key stage and subject from the RAG API. - * @param input The input to categorise. - * @returns The categorised key stage and subject. - * @throws {Error} If the categorisation fails. - * @example - * const input = "This is a lesson plan about algebra for KS3 students."; - * const categorised = await fetchCategorisedInput(input); - * console.log(categorised); - * // Output: { keyStage: "KS3", subject: "Maths", title: "Algebra" } - */ -import { RAG } from "@oakai/core/src/rag"; -import { PrismaClientWithAccelerate } from "@oakai/db"; - -import { LooseLessonPlan } from "../../protocol/schema"; - -export async function fetchCategorisedInput({ - input, - prisma, - chatMeta, -}: { - input: string; - prisma: PrismaClientWithAccelerate; - chatMeta: { - userId: string | undefined; - chatId: string; - }; -}): Promise { - const rag = new RAG(prisma); - const parsedCategorisation = await rag.categoriseKeyStageAndSubject( - input, - chatMeta, - ); - const { keyStage, subject, title, topic } = parsedCategorisation; - const plan: LooseLessonPlan = { - keyStage: keyStage ?? undefined, - subject: subject ?? undefined, - title: title ?? undefined, - topic: topic ?? undefined, - }; - return plan; -} diff --git a/packages/aila/src/utils/rag/fetchRagContent.ts b/packages/aila/src/utils/rag/fetchRagContent.ts index 2ad8b3f9f..1c66a6c7a 100644 --- a/packages/aila/src/utils/rag/fetchRagContent.ts +++ b/packages/aila/src/utils/rag/fetchRagContent.ts @@ -12,6 +12,8 @@ export async function fetchRagContent({ id, k = 5, prisma, + chatId, + userId, }: { title: string; subject?: string; @@ -20,10 +22,12 @@ export async function fetchRagContent({ id: string; k: number; prisma: PrismaClientWithAccelerate; + chatId: string; + userId?: string; }) { let content = "[]"; - const rag = new RAG(prisma); + const rag = new RAG(prisma, { chatId, userId }); const ragLessonPlans = await tryWithErrorReporting( () => { return title && keyStage && subject diff --git a/packages/core/src/models/lessonPlans.ts b/packages/core/src/models/lessonPlans.ts index 772078e1f..bfabf076b 100644 --- a/packages/core/src/models/lessonPlans.ts +++ b/packages/core/src/models/lessonPlans.ts @@ -67,7 +67,7 @@ export class LessonPlans { private _prisma: PrismaClientWithAccelerate; constructor(private readonly prisma: PrismaClientWithAccelerate) { this._prisma = prisma; - this._rag = new RAG(this._prisma); + this._rag = new RAG(this._prisma, { chatId: "none" }); } async embedAllParts(): Promise { diff --git a/packages/core/src/models/lessonSummaries.ts b/packages/core/src/models/lessonSummaries.ts index 2d51e8eff..104f7be50 100644 --- a/packages/core/src/models/lessonSummaries.ts +++ b/packages/core/src/models/lessonSummaries.ts @@ -28,7 +28,7 @@ export class LessonSummaries { private _prisma: PrismaClientWithAccelerate; constructor(private readonly prisma: PrismaClientWithAccelerate) { this._prisma = prisma; - this._rag = new RAG(this._prisma); + this._rag = new RAG(this._prisma, { chatId: "none" }); } async embedAll(): Promise { diff --git a/packages/core/src/rag/index.ts b/packages/core/src/rag/index.ts index 4bbbc4c76..834ddbd44 100644 --- a/packages/core/src/rag/index.ts +++ b/packages/core/src/rag/index.ts @@ -66,13 +66,18 @@ export type CategorisedKeyStageAndSubject = z.infer< export class RAG { prisma: PrismaClientWithAccelerate; - constructor(prisma: PrismaClientWithAccelerate) { + private _chatMeta: OpenAICompletionWithLoggingOptions; + constructor( + prisma: PrismaClientWithAccelerate, + chatMeta: OpenAICompletionWithLoggingOptions, + ) { this.prisma = prisma; + this._chatMeta = chatMeta; } async categoriseKeyStageAndSubject( input: string, - chatMeta?: OpenAICompletionWithLoggingOptions, + chatMeta: OpenAICompletionWithLoggingOptions, ) { console.log("Categorise input", JSON.stringify(input)); @@ -394,7 +399,7 @@ Thank you and happy classifying!`; let plans: LessonPlan[] = []; try { - const rag = new RAG(this.prisma); + const rag = new RAG(this.prisma, { chatId }); plans = await rag.searchLessonPlans({ title, keyStage, @@ -605,7 +610,10 @@ Thank you and happy classifying!`; }, }); if (!foundKeyStage) { - const categorisation = await this.categoriseKeyStageAndSubject(keyStage); + const categorisation = await this.categoriseKeyStageAndSubject( + keyStage, + this._chatMeta, + ); if (categorisation.keyStage) { foundKeyStage = await this.prisma.subject.findFirst({ where: { @@ -661,7 +669,10 @@ Thank you and happy classifying!`; // console.log( // "No subject found. Categorise the input to try to work out what it is using categoriseKeyStageAndSubject", // ); - const categorisation = await this.categoriseKeyStageAndSubject(subject); + const categorisation = await this.categoriseKeyStageAndSubject( + subject, + this._chatMeta, + ); if (categorisation.subject) { foundSubject = await this.prisma.subject.findFirst({ where: {