From dd5bf71a21421ac6e0beb60b4bab560cb159d877 Mon Sep 17 00:00:00 2001 From: Stef Lewandowski Date: Fri, 8 Nov 2024 16:35:08 +0000 Subject: [PATCH] fix: only categorise initial user input once (#348) --- apps/nextjs/src/app/api/chat/config.ts | 4 +- apps/nextjs/src/app/api/chat/route.test.ts | 4 +- packages/aila/src/core/Aila.test.ts | 95 ++++++++++++++++++++++ packages/aila/src/core/Aila.ts | 18 ++++ 4 files changed, 119 insertions(+), 2 deletions(-) diff --git a/apps/nextjs/src/app/api/chat/config.ts b/apps/nextjs/src/app/api/chat/config.ts index c3ba5b90d..4f32fe969 100644 --- a/apps/nextjs/src/app/api/chat/config.ts +++ b/apps/nextjs/src/app/api/chat/config.ts @@ -17,7 +17,7 @@ export const defaultConfig: Config = { prisma: globalPrisma, createAila: async (options) => { const webActionsPlugin = createWebActionsPlugin(globalPrisma); - return new Aila({ + const createdAila = new Aila({ ...options, plugins: [...(options.plugins || []), webActionsPlugin], prisma: options.prisma ?? globalPrisma, @@ -26,5 +26,7 @@ export const defaultConfig: Config = { userId: undefined, }, }); + await createdAila.initialise(); + return createdAila; }, }; diff --git a/apps/nextjs/src/app/api/chat/route.test.ts b/apps/nextjs/src/app/api/chat/route.test.ts index 1d976909c..74314022f 100644 --- a/apps/nextjs/src/app/api/chat/route.test.ts +++ b/apps/nextjs/src/app/api/chat/route.test.ts @@ -57,7 +57,9 @@ describe("Chat API Route", () => { chatCategoriser: mockChatCategoriser, }, }; - return new Aila(ailaConfig); + const ailaInstance = new Aila(ailaConfig); + await ailaInstance.initialise(); + return ailaInstance; }), // eslint-disable-next-line @typescript-eslint/no-explicit-any prisma: {} as any, diff --git a/packages/aila/src/core/Aila.test.ts b/packages/aila/src/core/Aila.test.ts index 12f725f3d..4f791bc7b 100644 --- a/packages/aila/src/core/Aila.test.ts +++ b/packages/aila/src/core/Aila.test.ts @@ -1,6 +1,7 @@ import type { Polly } from "@pollyjs/core"; import { setupPolly } from "../../tests/mocks/setupPolly"; +import type { AilaCategorisation } from "../features/categorisation"; import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser"; import { Aila } from "./Aila"; import { AilaAuthenticationError } from "./AilaError"; @@ -76,6 +77,96 @@ describe("Aila", () => { expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-2"); }); + it("should use the categoriser to determine the lesson plan from user input if the lesson plan is not already set up", async () => { + const mockCategoriser = { + categorise: jest.fn().mockResolvedValue({ + keyStage: "key-stage-2", + subject: "history", + title: "Roman Britain", + topic: "The Roman Empire", + }), + }; + + const ailaInstance = new Aila({ + lessonPlan: {}, + chat: { + id: "123", + userId: "user123", + messages: [ + { + id: "1", + role: "user", + content: + "Create a lesson about Roman Britain for Key Stage 2 History", + }, + ], + }, + options: { + usePersistence: false, + useRag: false, + useAnalytics: false, + useModeration: false, + }, + plugins: [], + services: { + chatCategoriser: mockCategoriser as unknown as AilaCategorisation, + }, + }); + + await ailaInstance.initialise(); + + expect(mockCategoriser.categorise).toHaveBeenCalledTimes(1); + expect(ailaInstance.lesson.plan.title).toBe("Roman Britain"); + expect(ailaInstance.lesson.plan.subject).toBe("history"); + expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-2"); + }); + + it("should not use the categoriser to determine the lesson plan from user input if the lesson plan is already set up", async () => { + const mockCategoriser = { + categorise: jest.fn().mockResolvedValue({ + keyStage: "key-stage-2", + subject: "history", + title: "Roman Britain", + topic: "The Roman Empire", + }), + }; + const ailaInstance = new Aila({ + lessonPlan: { + title: "Roman Britain", + subject: "history", + keyStage: "key-stage-2", + }, + chat: { + id: "123", + userId: "user123", + messages: [ + { + id: "1", + role: "user", + content: + "Create a lesson about Roman Britain for Key Stage 2 History", + }, + ], + }, + options: { + usePersistence: false, + useRag: false, + useAnalytics: false, + useModeration: false, + }, + plugins: [], + services: { + chatCategoriser: mockCategoriser as unknown as AilaCategorisation, + }, + }); + + await ailaInstance.initialise(); + expect(mockCategoriser.categorise).toHaveBeenCalledTimes(0); + expect(ailaInstance.lesson.plan.title).toBe("Roman Britain"); + expect(ailaInstance.lesson.plan.subject).toBe("history"); + expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-2"); + }); + // Calling initialise method successfully initializes the Aila instance it("should successfully initialize the Aila instance when calling the initialise method, and by default not set the lesson plan to initial values", async () => { const ailaInstance = new Aila({ @@ -226,6 +317,8 @@ describe("Aila", () => { expect(ailaInstance.lesson.plan.subject).not.toBeDefined(); expect(ailaInstance.lesson.plan.keyStage).not.toBeDefined(); + await ailaInstance.initialise(); + await ailaInstance.generateSync({ input: "Glaciation", }); @@ -295,6 +388,8 @@ describe("Aila", () => { }, }); + await ailaInstance.initialise(); + await ailaInstance.generateSync({ input: "Change the title to 'This should be ignored by the mocked service'", diff --git a/packages/aila/src/core/Aila.ts b/packages/aila/src/core/Aila.ts index 886bc6089..da30dc549 100644 --- a/packages/aila/src/core/Aila.ts +++ b/packages/aila/src/core/Aila.ts @@ -43,6 +43,7 @@ import type { const log = aiLogger("aila"); export class Aila implements AilaServices { + private _initialised: boolean = false; // We have a separate flag for this because we have an async initialise method which cannot be called in the constructor private _analytics?: AilaAnalyticsFeature; private _chat: AilaChatService; private _errorReporter?: AilaErrorReportingFeature; @@ -123,8 +124,22 @@ export class Aila implements AilaServices { this._plugins = options.plugins; } + private checkInitialised() { + if (!this._initialised) { + log.warn( + "Aila instance has not been initialised. Please call the initialise method before using the instance.", + ); + throw new Error("Aila instance has not been initialised."); + } + } + // Initialization methods public async initialise() { + if (this._initialised) { + log.info("Aila - already initialised"); + return; + } + log.info("Aila - initialise"); this.checkUserIdPresentIfPersisting(); await this.loadChatIfPersisting(); const persistedLessonPlan = this._chat.persistedChat?.lessonPlan; @@ -132,6 +147,7 @@ export class Aila implements AilaServices { this._lesson.setPlan(persistedLessonPlan); } await this._lesson.setUpInitialLessonPlan(this._chat.messages); + this._initialised = true; } private initialiseOptions(options?: AilaOptions) { @@ -246,6 +262,7 @@ export class Aila implements AilaServices { // Generation methods public async generateSync(opts: AilaGenerateLessonPlanOptions) { + this.checkInitialised(); const stream = await this.generate(opts); const reader = stream.getReader(); @@ -273,6 +290,7 @@ export class Aila implements AilaServices { keyStage, topic, }: AilaGenerateLessonPlanOptions) { + this.checkInitialised(); if (this._isShutdown) { throw new AilaGenerationError( "This Aila instance has been shut down and cannot be reused.",