Skip to content

Commit

Permalink
fix: only categorise initial user input once (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefl authored Nov 8, 2024
1 parent 7a34686 commit dd5bf71
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 2 deletions.
4 changes: 3 additions & 1 deletion apps/nextjs/src/app/api/chat/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export const defaultConfig: Config = {
prisma: globalPrisma,
createAila: async (options) => {
const webActionsPlugin = createWebActionsPlugin(globalPrisma);
return new Aila({
const createdAila = new Aila({
...options,
plugins: [...(options.plugins || []), webActionsPlugin],
prisma: options.prisma ?? globalPrisma,
Expand All @@ -26,5 +26,7 @@ export const defaultConfig: Config = {
userId: undefined,
},
});
await createdAila.initialise();
return createdAila;
},
};
4 changes: 3 additions & 1 deletion apps/nextjs/src/app/api/chat/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ describe("Chat API Route", () => {
chatCategoriser: mockChatCategoriser,
},
};
return new Aila(ailaConfig);
const ailaInstance = new Aila(ailaConfig);
await ailaInstance.initialise();
return ailaInstance;
}),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
prisma: {} as any,
Expand Down
95 changes: 95 additions & 0 deletions packages/aila/src/core/Aila.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { Polly } from "@pollyjs/core";

import { setupPolly } from "../../tests/mocks/setupPolly";
import type { AilaCategorisation } from "../features/categorisation";
import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser";
import { Aila } from "./Aila";
import { AilaAuthenticationError } from "./AilaError";
Expand Down Expand Up @@ -76,6 +77,96 @@ describe("Aila", () => {
expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-2");
});

it("should use the categoriser to determine the lesson plan from user input if the lesson plan is not already set up", async () => {
const mockCategoriser = {
categorise: jest.fn().mockResolvedValue({
keyStage: "key-stage-2",
subject: "history",
title: "Roman Britain",
topic: "The Roman Empire",
}),
};

const ailaInstance = new Aila({
lessonPlan: {},
chat: {
id: "123",
userId: "user123",
messages: [
{
id: "1",
role: "user",
content:
"Create a lesson about Roman Britain for Key Stage 2 History",
},
],
},
options: {
usePersistence: false,
useRag: false,
useAnalytics: false,
useModeration: false,
},
plugins: [],
services: {
chatCategoriser: mockCategoriser as unknown as AilaCategorisation,
},
});

await ailaInstance.initialise();

expect(mockCategoriser.categorise).toHaveBeenCalledTimes(1);
expect(ailaInstance.lesson.plan.title).toBe("Roman Britain");
expect(ailaInstance.lesson.plan.subject).toBe("history");
expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-2");
});

it("should not use the categoriser to determine the lesson plan from user input if the lesson plan is already set up", async () => {
const mockCategoriser = {
categorise: jest.fn().mockResolvedValue({
keyStage: "key-stage-2",
subject: "history",
title: "Roman Britain",
topic: "The Roman Empire",
}),
};
const ailaInstance = new Aila({
lessonPlan: {
title: "Roman Britain",
subject: "history",
keyStage: "key-stage-2",
},
chat: {
id: "123",
userId: "user123",
messages: [
{
id: "1",
role: "user",
content:
"Create a lesson about Roman Britain for Key Stage 2 History",
},
],
},
options: {
usePersistence: false,
useRag: false,
useAnalytics: false,
useModeration: false,
},
plugins: [],
services: {
chatCategoriser: mockCategoriser as unknown as AilaCategorisation,
},
});

await ailaInstance.initialise();
expect(mockCategoriser.categorise).toHaveBeenCalledTimes(0);
expect(ailaInstance.lesson.plan.title).toBe("Roman Britain");
expect(ailaInstance.lesson.plan.subject).toBe("history");
expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-2");
});

// Calling initialise method successfully initializes the Aila instance
it("should successfully initialize the Aila instance when calling the initialise method, and by default not set the lesson plan to initial values", async () => {
const ailaInstance = new Aila({
Expand Down Expand Up @@ -226,6 +317,8 @@ describe("Aila", () => {
expect(ailaInstance.lesson.plan.subject).not.toBeDefined();
expect(ailaInstance.lesson.plan.keyStage).not.toBeDefined();

await ailaInstance.initialise();

await ailaInstance.generateSync({
input: "Glaciation",
});
Expand Down Expand Up @@ -295,6 +388,8 @@ describe("Aila", () => {
},
});

await ailaInstance.initialise();

await ailaInstance.generateSync({
input:
"Change the title to 'This should be ignored by the mocked service'",
Expand Down
18 changes: 18 additions & 0 deletions packages/aila/src/core/Aila.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import type {

const log = aiLogger("aila");
export class Aila implements AilaServices {
private _initialised: boolean = false; // We have a separate flag for this because we have an async initialise method which cannot be called in the constructor
private _analytics?: AilaAnalyticsFeature;
private _chat: AilaChatService;
private _errorReporter?: AilaErrorReportingFeature;
Expand Down Expand Up @@ -123,15 +124,30 @@ export class Aila implements AilaServices {
this._plugins = options.plugins;
}

private checkInitialised() {
if (!this._initialised) {
log.warn(
"Aila instance has not been initialised. Please call the initialise method before using the instance.",
);
throw new Error("Aila instance has not been initialised.");
}
}

// Initialization methods
public async initialise() {
if (this._initialised) {
log.info("Aila - already initialised");
return;
}
log.info("Aila - initialise");
this.checkUserIdPresentIfPersisting();
await this.loadChatIfPersisting();
const persistedLessonPlan = this._chat.persistedChat?.lessonPlan;
if (persistedLessonPlan) {
this._lesson.setPlan(persistedLessonPlan);
}
await this._lesson.setUpInitialLessonPlan(this._chat.messages);
this._initialised = true;
}

private initialiseOptions(options?: AilaOptions) {
Expand Down Expand Up @@ -246,6 +262,7 @@ export class Aila implements AilaServices {

// Generation methods
public async generateSync(opts: AilaGenerateLessonPlanOptions) {
this.checkInitialised();
const stream = await this.generate(opts);

const reader = stream.getReader();
Expand Down Expand Up @@ -273,6 +290,7 @@ export class Aila implements AilaServices {
keyStage,
topic,
}: AilaGenerateLessonPlanOptions) {
this.checkInitialised();
if (this._isShutdown) {
throw new AilaGenerationError(
"This Aila instance has been shut down and cannot be reused.",
Expand Down

0 comments on commit dd5bf71

Please sign in to comment.