Skip to content

Commit

Permalink
Aila Categoriser feature with Chat ID and User ID
Browse files Browse the repository at this point in the history
  • Loading branch information
stefl committed Aug 28, 2024
1 parent d30461a commit 5951882
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 49 deletions.
34 changes: 34 additions & 0 deletions packages/aila/src/core/Aila.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Aila } from ".";
import { MockLLMService } from "../../tests/mocks/MockLLMService";
import { setupPolly } from "../../tests/mocks/setupPolly";
import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser";
import { AilaAuthenticationError } from "./AilaError";

describe("Aila", () => {
Expand Down Expand Up @@ -285,4 +286,37 @@ describe("Aila", () => {
expect(ailaInstance.lesson.plan.title).toBe(newTitle);
}, 20000);
});

describe("categorisation", () => {
it("should use the provided MockCategoriser", async () => {
const mockedLessonPlan = {
title: "Mocked Lesson Plan",
subject: "Mocked Subject",
keyStage: "key-stage-3",
};

const mockCategoriser = new MockCategoriser({ mockedLessonPlan });

const ailaInstance = new Aila({
lessonPlan: {},
chat: { id: "123", userId: "user123" },
options: {
usePersistence: false,
useRag: false,
useAnalytics: false,
useModeration: false,
},
services: {
chatCategoriser: mockCategoriser,
},
plugins: [],
});

await ailaInstance.initialise();

expect(ailaInstance.lesson.plan.title).toBe("Mocked Lesson Plan");
expect(ailaInstance.lesson.plan.subject).toBe("Mocked Subject");
expect(ailaInstance.lesson.plan.keyStage).toBe("key-stage-3");
});
});
});
59 changes: 21 additions & 38 deletions packages/aila/src/core/Aila.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
DEFAULT_TEMPERATURE,
DEFAULT_RAG_LESSON_PLANS,
} from "../constants";
import { AilaCategorisation } from "../features/categorisation";
import {
AilaAnalyticsFeature,
AilaErrorReportingFeature,
Expand Down Expand Up @@ -43,8 +44,12 @@ export class Aila implements AilaServices {
private _threatDetection?: AilaThreatDetectionFeature;
private _prisma: PrismaClientWithAccelerate;
private _plugins: AilaPlugin[];
private _userId!: string | undefined;
private _chatId!: string;

constructor(options: AilaInitializationOptions) {
this._userId = options.chat.userId;
this._chatId = options.chat.id;
this._options = this.initialiseOptions(options.options);

this._chat = new AilaChat({
Expand All @@ -53,9 +58,21 @@ export class Aila implements AilaServices {
promptBuilder: options.promptBuilder,
});

this._lesson = new AilaLesson({ lessonPlan: options.lessonPlan ?? {} });
this._prisma = options.prisma ?? globalPrisma;

this._lesson = new AilaLesson({
aila: this,
lessonPlan: options.lessonPlan ?? {},
categoriser:
options.services?.chatCategoriser ??
new AilaCategorisation({
aila: this,
prisma: this._prisma,
chatId: this._chatId,
userId: this._userId,
}),
});

this._analytics = AilaFeatureFactory.createAnalytics(this, this._options);
this._moderation = AilaFeatureFactory.createModeration(this, this._options);
this._persistence = AilaFeatureFactory.createPersistence(
Expand All @@ -81,7 +98,7 @@ export class Aila implements AilaServices {
// Initialization methods
public async initialise() {
this.checkUserIdPresentIfPersisting();
await this.setUpInitialLessonPlan();
await this._lesson.setUpInitialLessonPlan(this._chat.messages);
}

private initialiseOptions(options?: AilaOptions) {
Expand Down Expand Up @@ -128,11 +145,11 @@ export class Aila implements AilaServices {
}

public get chatId() {
return this._chat.id;
return this._chatId;
}

public get userId() {
return this._chat.userId;
return this._userId;
}

public get messages() {
Expand Down Expand Up @@ -168,40 +185,6 @@ export class Aila implements AilaServices {
}
}

// Setup methods

// #TODO this is in the wrong place and should be
// moved to be hook into the initialisation of the lesson
// or chat
public async setUpInitialLessonPlan() {
const shouldRequestInitialState = Boolean(
!this.lesson.plan.subject &&
!this.lesson.plan.keyStage &&
!this.lesson.plan.title,
);

if (shouldRequestInitialState) {
const { title, subject, keyStage, topic } = this.lesson.plan;
const input = this.chat.messages.map((i) => i.content).join("\n\n");
const categorisationInput = [title, subject, keyStage, topic, input]
.filter((i) => i)
.join(" ");

const result = await fetchCategorisedInput({
input: categorisationInput,
prisma: this._prisma,
chatMeta: {
userId: this._chat.userId,
chatId: this._chat.id,
},
});

if (result) {
this.lesson.initialise(result);
}
}
}

// Generation methods
public async generateSync(opts: AilaGenerateLessonPlanOptions) {
const stream = await this.generate(opts);
Expand Down
1 change: 1 addition & 0 deletions packages/aila/src/core/AilaServices.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export interface AilaLessonService {
readonly hasSetInitialState: boolean;
applyPatches(patches: string): void;
initialise(plan: LooseLessonPlan): void;
setUpInitialLessonPlan(messages: Message[]): Promise<void>;
}

export interface AilaChatService {
Expand Down
41 changes: 39 additions & 2 deletions packages/aila/src/core/lesson/AilaLesson.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
import { deepClone } from "fast-json-patch";

import { AilaCategorisation } from "../../features/categorisation/categorisers/AilaCategorisation";
import { AilaCategorisationFeature } from "../../features/types";
import {
PatchDocument,
applyLessonPlanPatch,
extractPatches,
} from "../../protocol/jsonPatchProtocol";
import { LooseLessonPlan } from "../../protocol/schema";
import { AilaLessonService } from "../AilaServices";
import { AilaLessonService, AilaServices } from "../AilaServices";
import { Message } from "../chat";

export class AilaLesson implements AilaLessonService {
private _aila: AilaServices;
private _plan: LooseLessonPlan;
private _hasSetInitialState = false;
private _appliedPatches: PatchDocument[] = [];
private _invalidPatches: PatchDocument[] = [];
private _categoriser: AilaCategorisationFeature;

constructor({ lessonPlan }: { lessonPlan?: LooseLessonPlan }) {
constructor({
aila,
lessonPlan,
categoriser,
}: {
aila: AilaServices;
lessonPlan?: LooseLessonPlan;
categoriser?: AilaCategorisationFeature;
}) {
this._aila = aila;
this._plan = lessonPlan ?? {};
this._categoriser =
categoriser ??
new AilaCategorisation({
aila,
userId: aila.userId,
chatId: aila.chatId,
});
}

public get plan(): LooseLessonPlan {
Expand Down Expand Up @@ -74,4 +95,20 @@ export class AilaLesson implements AilaLessonService {

this._plan = workingLessonPlan;
}

public async setUpInitialLessonPlan(messages: Message[]) {
const shouldCategoriseBasedOnInitialMessages = Boolean(
!this._plan.subject && !this._plan.keyStage && !this._plan.title,
);

// The initial lesson plan is blank, so we take the first messages
// and attempt to deduce the lesson plan key stage, subject, title and topic
if (shouldCategoriseBasedOnInitialMessages) {
const result = await this._categoriser.categorise(messages, this._plan);

if (result) {
this.initialise(result);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder {

private async fetchRelevantLessonPlans(): Promise<string> {
const noRelevantLessonPlans = "None";
const chatId = this._aila?.chatId;
const { chatId, userId } = this._aila;
if (!this._aila?.options.useRag) {
return noRelevantLessonPlans;
}
Expand All @@ -59,6 +59,8 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder {
this._aila?.options.numberOfLessonPlansInRag ??
DEFAULT_RAG_LESSON_PLANS,
prisma: globalPrisma,
chatId,
userId,
});
}, "Did not fetch RAG content. Continuing");

Expand Down
4 changes: 4 additions & 0 deletions packages/aila/src/core/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { AilaPersistence } from "../features/persistence";
import { AilaThreatDetector } from "../features/threatDetection";
import {
AilaAnalyticsFeature,
AilaCategorisationFeature,
AilaErrorReportingFeature,
AilaModerationFeature,
AilaThreatDetectionFeature,
Expand Down Expand Up @@ -67,4 +68,7 @@ export type AilaInitializationOptions = {
errorReporter?: AilaErrorReportingFeature;
promptBuilder?: AilaPromptBuilder;
plugins: AilaPlugin[];
services?: {
chatCategoriser?: AilaCategorisationFeature;
};
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { RAG } from "@oakai/core/src/rag";
import {
type PrismaClientWithAccelerate,
prisma as globalPrisma,
} from "@oakai/db";

import { AilaServices, Message } from "../../../core";
import { LooseLessonPlan } from "../../../protocol/schema";
import { AilaCategorisationFeature } from "../../types";

export class AilaCategorisation implements AilaCategorisationFeature {
private _aila: AilaServices;
private _prisma: PrismaClientWithAccelerate;
private _chatId: string;
private _userId: string | undefined;
constructor({
aila,
prisma,
chatId,
userId,
}: {
aila: AilaServices;
prisma?: PrismaClientWithAccelerate;
chatId: string;
userId?: string;
}) {
this._aila = aila;
this._prisma = prisma ?? globalPrisma;
this._chatId = chatId;
this._userId = userId;
}
public async categorise(
messages: Message[],
lessonPlan: LooseLessonPlan,
): Promise<LooseLessonPlan | undefined> {
const { title, subject, keyStage, topic } = lessonPlan;
const input = messages.map((i) => i.content).join("\n\n");
const categorisationInput = [title, subject, keyStage, topic, input]
.filter((i) => i)
.join(" ");

const result = await this.fetchCategorisedInput(
categorisationInput,
this._prisma,
);
return result;
}

private async fetchCategorisedInput(
input: string,
prisma: PrismaClientWithAccelerate,
): Promise<LooseLessonPlan | undefined> {
const rag = new RAG(prisma, {
chatId: this._chatId,
userId: this._userId,
});
const parsedCategorisation = await rag.categoriseKeyStageAndSubject(input, {
chatId: this._chatId,
userId: this._userId,
});
const { keyStage, subject, title, topic } = parsedCategorisation;
const plan: LooseLessonPlan = {
keyStage: keyStage ?? undefined,
subject: subject ?? undefined,
title: title ?? undefined,
topic: topic ?? undefined,
};
return plan;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { LooseLessonPlan } from "../../../protocol/schema";
import { AilaCategorisationFeature } from "../../types";

export class MockCategoriser implements AilaCategorisationFeature {
private _mockedLessonPlan: LooseLessonPlan | undefined;
constructor({
mockedLessonPlan,
}: {
mockedLessonPlan: LooseLessonPlan | undefined;
}) {
this._mockedLessonPlan = mockedLessonPlan;
}
public async categorise(): Promise<LooseLessonPlan | undefined> {
return this._mockedLessonPlan;
}
}
1 change: 1 addition & 0 deletions packages/aila/src/features/categorisation/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export { AilaCategorisation } from "./categorisers/AilaCategorisation";
7 changes: 7 additions & 0 deletions packages/aila/src/features/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,10 @@ export interface AilaErrorReportingFeature {
breadcrumbs?: { category: string; message: string },
): T | null;
}

export interface AilaCategorisationFeature {
categorise(
messages: Message[],
lessonPlan: LooseLessonPlan,
): Promise<LooseLessonPlan | undefined>;
}
6 changes: 5 additions & 1 deletion packages/aila/src/utils/rag/fetchRagContent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export async function fetchRagContent({
id,
k = 5,
prisma,
chatId,
userId,
}: {
title: string;
subject?: string;
Expand All @@ -20,10 +22,12 @@ export async function fetchRagContent({
id: string;
k: number;
prisma: PrismaClientWithAccelerate;
chatId: string;
userId?: string;
}) {
let content = "[]";

const rag = new RAG(prisma);
const rag = new RAG(prisma, { chatId, userId });
const ragLessonPlans = await tryWithErrorReporting(
() => {
return title && keyStage && subject
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/models/lessonPlans.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export class LessonPlans {
private _prisma: PrismaClientWithAccelerate;
constructor(private readonly prisma: PrismaClientWithAccelerate) {
this._prisma = prisma;
this._rag = new RAG(this._prisma);
this._rag = new RAG(this._prisma, { chatId: "none" });
}

async embedAllParts(): Promise<void> {
Expand Down
Loading

0 comments on commit 5951882

Please sign in to comment.