Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: aila categoriser feature with chat ID and user ID #12

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions packages/aila/src/core/Aila.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Aila } from ".";
import { MockLLMService } from "../../tests/mocks/MockLLMService";
import { setupPolly } from "../../tests/mocks/setupPolly";
import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser";
import { AilaAuthenticationError } from "./AilaError";

describe("Aila", () => {
Expand Down Expand Up @@ -285,4 +286,37 @@ describe("Aila", () => {
expect(ailaInstance.lesson.plan.title).toBe(newTitle);
}, 20000);
});

describe("categorisation", () => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just adding one example test here to show how we can pass in the mock categoriser during Aila initialisation

it("should use the provided MockCategoriser", async () => {
const mockedLessonPlan = {
title: "Mocked Lesson Plan",
subject: "Mocked Subject",
keyStage: "key-stage-3",
};

const mockCategoriser = new MockCategoriser({ mockedLessonPlan });

const ailaInstance = new Aila({
lessonPlan: {},
chat: { id: "123", userId: "user123" },
options: {
usePersistence: false,
useRag: false,
useAnalytics: false,
useModeration: false,
},
services: {
chatCategoriser: mockCategoriser,
},
plugins: [],
});

await ailaInstance.initialise();

expect(ailaInstance.lesson.plan.title).toBe("Mocked Lesson Plan");
expect(ailaInstance.lesson.plan.subject).toBe("Mocked Subject");
expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-3");
});
});
});
60 changes: 21 additions & 39 deletions packages/aila/src/core/Aila.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import {
DEFAULT_TEMPERATURE,
DEFAULT_RAG_LESSON_PLANS,
} from "../constants";
import { AilaCategorisation } from "../features/categorisation";
import {
AilaAnalyticsFeature,
AilaErrorReportingFeature,
AilaModerationFeature,
AilaPersistenceFeature,
AilaThreatDetectionFeature,
} from "../features/types";
import { fetchCategorisedInput } from "../utils/lessonPlan/fetchCategorisedInput";
import { AilaAuthenticationError, AilaGenerationError } from "./AilaError";
import { AilaFeatureFactory } from "./AilaFeatureFactory";
import {
Expand Down Expand Up @@ -43,8 +43,12 @@ export class Aila implements AilaServices {
private _threatDetection?: AilaThreatDetectionFeature;
private _prisma: PrismaClientWithAccelerate;
private _plugins: AilaPlugin[];
private _userId!: string | undefined;
private _chatId!: string;

constructor(options: AilaInitializationOptions) {
this._userId = options.chat.userId;
this._chatId = options.chat.id;
this._options = this.initialiseOptions(options.options);

this._chat = new AilaChat({
Expand All @@ -53,9 +57,21 @@ export class Aila implements AilaServices {
promptBuilder: options.promptBuilder,
});

this._lesson = new AilaLesson({ lessonPlan: options.lessonPlan ?? {} });
this._prisma = options.prisma ?? globalPrisma;

this._lesson = new AilaLesson({
aila: this,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have seen some fiddly bugs arise from passing aila: this as a prop, as the types aren't necessarily accurate (due to a quasi race condition on instantiation)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not saying don't do this, just highlighting because I think you were away when it was discovered

Copy link
Contributor Author

@stefl stefl Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay that's interesting! I see. So if you've not yet set some property on the instance before you pass it, then it could cause a problem.

lessonPlan: options.lessonPlan ?? {},
categoriser:
options.services?.chatCategoriser ??
new AilaCategorisation({
aila: this,
prisma: this._prisma,
chatId: this._chatId,
userId: this._userId,
}),
});

this._analytics = AilaFeatureFactory.createAnalytics(this, this._options);
this._moderation = AilaFeatureFactory.createModeration(this, this._options);
this._persistence = AilaFeatureFactory.createPersistence(
Expand All @@ -81,7 +97,7 @@ export class Aila implements AilaServices {
// Initialization methods
public async initialise() {
this.checkUserIdPresentIfPersisting();
await this.setUpInitialLessonPlan();
await this._lesson.setUpInitialLessonPlan(this._chat.messages);
}

private initialiseOptions(options?: AilaOptions) {
Expand Down Expand Up @@ -128,11 +144,11 @@ export class Aila implements AilaServices {
}

public get chatId() {
return this._chat.id;
return this._chatId;
}

public get userId() {
return this._chat.userId;
return this._userId;
}

public get messages() {
Expand Down Expand Up @@ -168,40 +184,6 @@ export class Aila implements AilaServices {
}
}

// Setup methods

// #TODO this is in the wrong place and should be
// moved to be hook into the initialisation of the lesson
// or chat
public async setUpInitialLessonPlan() {
const shouldRequestInitialState = Boolean(
!this.lesson.plan.subject &&
!this.lesson.plan.keyStage &&
!this.lesson.plan.title,
);

if (shouldRequestInitialState) {
const { title, subject, keyStage, topic } = this.lesson.plan;
const input = this.chat.messages.map((i) => i.content).join("\n\n");
const categorisationInput = [title, subject, keyStage, topic, input]
.filter((i) => i)
.join(" ");

const result = await fetchCategorisedInput({
input: categorisationInput,
prisma: this._prisma,
chatMeta: {
userId: this._chat.userId,
chatId: this._chat.id,
},
});

if (result) {
this.lesson.initialise(result);
}
}
}

// Generation methods
public async generateSync(opts: AilaGenerateLessonPlanOptions) {
const stream = await this.generate(opts);
Expand Down
1 change: 1 addition & 0 deletions packages/aila/src/core/AilaServices.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export interface AilaLessonService {
readonly hasSetInitialState: boolean;
applyPatches(patches: string): void;
initialise(plan: LooseLessonPlan): void;
setUpInitialLessonPlan(messages: Message[]): Promise<void>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moves the initial lesson plan setup to the Aila Lesson which makes more sense - resolves the TODO in the Aila class

}

export interface AilaChatService {
Expand Down
41 changes: 39 additions & 2 deletions packages/aila/src/core/lesson/AilaLesson.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
import { deepClone } from "fast-json-patch";

import { AilaCategorisation } from "../../features/categorisation/categorisers/AilaCategorisation";
import { AilaCategorisationFeature } from "../../features/types";
import {
PatchDocument,
applyLessonPlanPatch,
extractPatches,
} from "../../protocol/jsonPatchProtocol";
import { LooseLessonPlan } from "../../protocol/schema";
import { AilaLessonService } from "../AilaServices";
import { AilaLessonService, AilaServices } from "../AilaServices";
import { Message } from "../chat";

export class AilaLesson implements AilaLessonService {
private _aila: AilaServices;
private _plan: LooseLessonPlan;
private _hasSetInitialState = false;
private _appliedPatches: PatchDocument[] = [];
private _invalidPatches: PatchDocument[] = [];
private _categoriser: AilaCategorisationFeature;

constructor({ lessonPlan }: { lessonPlan?: LooseLessonPlan }) {
constructor({
aila,
lessonPlan,
categoriser,
}: {
aila: AilaServices;
lessonPlan?: LooseLessonPlan;
categoriser?: AilaCategorisationFeature;
}) {
this._aila = aila;
this._plan = lessonPlan ?? {};
this._categoriser =
categoriser ??
new AilaCategorisation({
aila,
userId: aila.userId,
chatId: aila.chatId,
});
}

public get plan(): LooseLessonPlan {
Expand Down Expand Up @@ -74,4 +95,20 @@ export class AilaLesson implements AilaLessonService {

this._plan = workingLessonPlan;
}

public async setUpInitialLessonPlan(messages: Message[]) {
const shouldCategoriseBasedOnInitialMessages = Boolean(
!this._plan.subject && !this._plan.keyStage && !this._plan.title,
);

// The initial lesson plan is blank, so we take the first messages
// and attempt to deduce the lesson plan key stage, subject, title and topic
if (shouldCategoriseBasedOnInitialMessages) {
const result = await this._categoriser.categorise(messages, this._plan);

if (result) {
this.initialise(result);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder {

private async fetchRelevantLessonPlans(): Promise<string> {
const noRelevantLessonPlans = "None";
const chatId = this._aila?.chatId;
const { chatId, userId } = this._aila;
if (!this._aila?.options.useRag) {
return noRelevantLessonPlans;
}
Expand All @@ -59,6 +59,8 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder {
this._aila?.options.numberOfLessonPlansInRag ??
DEFAULT_RAG_LESSON_PLANS,
prisma: globalPrisma,
chatId,
userId,
});
}, "Did not fetch RAG content. Continuing");

Expand Down
4 changes: 4 additions & 0 deletions packages/aila/src/core/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { AilaPersistence } from "../features/persistence";
import { AilaThreatDetector } from "../features/threatDetection";
import {
AilaAnalyticsFeature,
AilaCategorisationFeature,
AilaErrorReportingFeature,
AilaModerationFeature,
AilaThreatDetectionFeature,
Expand Down Expand Up @@ -67,4 +68,7 @@ export type AilaInitializationOptions = {
errorReporter?: AilaErrorReportingFeature;
promptBuilder?: AilaPromptBuilder;
plugins: AilaPlugin[];
services?: {
chatCategoriser?: AilaCategorisationFeature;
};
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { RAG } from "@oakai/core/src/rag";
import {
type PrismaClientWithAccelerate,
prisma as globalPrisma,
} from "@oakai/db";

import { AilaServices, Message } from "../../../core";
import { LooseLessonPlan } from "../../../protocol/schema";
import { AilaCategorisationFeature } from "../../types";

export class AilaCategorisation implements AilaCategorisationFeature {
private _aila: AilaServices;
private _prisma: PrismaClientWithAccelerate;
private _chatId: string;
private _userId: string | undefined;
constructor({
aila,
prisma,
chatId,
userId,
}: {
aila: AilaServices;
prisma?: PrismaClientWithAccelerate;
chatId: string;
userId?: string;
}) {
this._aila = aila;
this._prisma = prisma ?? globalPrisma;
this._chatId = chatId;
this._userId = userId;
}
public async categorise(
messages: Message[],
lessonPlan: LooseLessonPlan,
): Promise<LooseLessonPlan | undefined> {
const { title, subject, keyStage, topic } = lessonPlan;
const input = messages.map((i) => i.content).join("\n\n");
const categorisationInput = [title, subject, keyStage, topic, input]
.filter((i) => i)
.join(" ");

const result = await this.fetchCategorisedInput(
categorisationInput,
this._prisma,
);
return result;
}

private async fetchCategorisedInput(
input: string,
prisma: PrismaClientWithAccelerate,
): Promise<LooseLessonPlan | undefined> {
const rag = new RAG(prisma, {
chatId: this._chatId,
userId: this._userId,
});
const parsedCategorisation = await rag.categoriseKeyStageAndSubject(input, {
chatId: this._chatId,
userId: this._userId,
});
const { keyStage, subject, title, topic } = parsedCategorisation;
const plan: LooseLessonPlan = {
keyStage: keyStage ?? undefined,
subject: subject ?? undefined,
title: title ?? undefined,
topic: topic ?? undefined,
};
return plan;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { LooseLessonPlan } from "../../../protocol/schema";
import { AilaCategorisationFeature } from "../../types";

export class MockCategoriser implements AilaCategorisationFeature {
private _mockedLessonPlan: LooseLessonPlan | undefined;
constructor({
mockedLessonPlan,
}: {
mockedLessonPlan: LooseLessonPlan | undefined;
}) {
this._mockedLessonPlan = mockedLessonPlan;
}
public async categorise(): Promise<LooseLessonPlan | undefined> {
return this._mockedLessonPlan;
}
}
1 change: 1 addition & 0 deletions packages/aila/src/features/categorisation/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export { AilaCategorisation } from "./categorisers/AilaCategorisation";
9 changes: 6 additions & 3 deletions packages/aila/src/features/rag/AilaRag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,23 @@ import { LooseLessonPlan } from "../../protocol/schema";
import { minifyLessonPlanForRelevantLessons } from "../../utils/lessonPlan/minifyLessonPlanForRelevantLessons";

export class AilaRag {
private _aila?: AilaServices;
private _aila: AilaServices;
private _rag: RAG;
private _prisma: PrismaClientWithAccelerate;

constructor({
aila,
prisma,
}: {
aila?: AilaServices;
aila: AilaServices;
prisma?: PrismaClientWithAccelerate;
}) {
this._aila = aila;
this._prisma = prisma ?? globalPrisma;
this._rag = new RAG(this._prisma);
this._rag = new RAG(this._prisma, {
userId: aila.userId,
chatId: aila.chatId,
});
}

public async fetchRagContent({
Expand Down
7 changes: 7 additions & 0 deletions packages/aila/src/features/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,10 @@ export interface AilaErrorReportingFeature {
breadcrumbs?: { category: string; message: string },
): T | null;
}

export interface AilaCategorisationFeature {
categorise(
messages: Message[],
lessonPlan: LooseLessonPlan,
): Promise<LooseLessonPlan | undefined>;
}
Loading
Loading