-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: aila categoriser feature with chat ID and user ID #12
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,14 +6,14 @@ import { | |
DEFAULT_TEMPERATURE, | ||
DEFAULT_RAG_LESSON_PLANS, | ||
} from "../constants"; | ||
import { AilaCategorisation } from "../features/categorisation"; | ||
import { | ||
AilaAnalyticsFeature, | ||
AilaErrorReportingFeature, | ||
AilaModerationFeature, | ||
AilaPersistenceFeature, | ||
AilaThreatDetectionFeature, | ||
} from "../features/types"; | ||
import { fetchCategorisedInput } from "../utils/lessonPlan/fetchCategorisedInput"; | ||
import { AilaAuthenticationError, AilaGenerationError } from "./AilaError"; | ||
import { AilaFeatureFactory } from "./AilaFeatureFactory"; | ||
import { | ||
|
@@ -43,8 +43,12 @@ export class Aila implements AilaServices { | |
private _threatDetection?: AilaThreatDetectionFeature; | ||
private _prisma: PrismaClientWithAccelerate; | ||
private _plugins: AilaPlugin[]; | ||
private _userId!: string | undefined; | ||
private _chatId!: string; | ||
|
||
constructor(options: AilaInitializationOptions) { | ||
this._userId = options.chat.userId; | ||
this._chatId = options.chat.id; | ||
this._options = this.initialiseOptions(options.options); | ||
|
||
this._chat = new AilaChat({ | ||
|
@@ -53,9 +57,21 @@ export class Aila implements AilaServices { | |
promptBuilder: options.promptBuilder, | ||
}); | ||
|
||
this._lesson = new AilaLesson({ lessonPlan: options.lessonPlan ?? {} }); | ||
this._prisma = options.prisma ?? globalPrisma; | ||
|
||
this._lesson = new AilaLesson({ | ||
aila: this, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have seen some fiddly bugs arise from passing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not saying don't do this, just highlighting because I think you were away when it was discovered There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay that's interesting! I see. So if you've not yet set some property on the instance before you pass it, then it could cause a problem. |
||
lessonPlan: options.lessonPlan ?? {}, | ||
categoriser: | ||
options.services?.chatCategoriser ?? | ||
new AilaCategorisation({ | ||
aila: this, | ||
prisma: this._prisma, | ||
chatId: this._chatId, | ||
userId: this._userId, | ||
}), | ||
}); | ||
|
||
this._analytics = AilaFeatureFactory.createAnalytics(this, this._options); | ||
this._moderation = AilaFeatureFactory.createModeration(this, this._options); | ||
this._persistence = AilaFeatureFactory.createPersistence( | ||
|
@@ -81,7 +97,7 @@ export class Aila implements AilaServices { | |
// Initialization methods | ||
public async initialise() { | ||
this.checkUserIdPresentIfPersisting(); | ||
await this.setUpInitialLessonPlan(); | ||
await this._lesson.setUpInitialLessonPlan(this._chat.messages); | ||
} | ||
|
||
private initialiseOptions(options?: AilaOptions) { | ||
|
@@ -128,11 +144,11 @@ export class Aila implements AilaServices { | |
} | ||
|
||
public get chatId() { | ||
return this._chat.id; | ||
return this._chatId; | ||
} | ||
|
||
public get userId() { | ||
return this._chat.userId; | ||
return this._userId; | ||
} | ||
|
||
public get messages() { | ||
|
@@ -168,40 +184,6 @@ export class Aila implements AilaServices { | |
} | ||
} | ||
|
||
// Setup methods | ||
|
||
// #TODO this is in the wrong place and should be | ||
// moved to be hook into the initialisation of the lesson | ||
// or chat | ||
public async setUpInitialLessonPlan() { | ||
const shouldRequestInitialState = Boolean( | ||
!this.lesson.plan.subject && | ||
!this.lesson.plan.keyStage && | ||
!this.lesson.plan.title, | ||
); | ||
|
||
if (shouldRequestInitialState) { | ||
const { title, subject, keyStage, topic } = this.lesson.plan; | ||
const input = this.chat.messages.map((i) => i.content).join("\n\n"); | ||
const categorisationInput = [title, subject, keyStage, topic, input] | ||
.filter((i) => i) | ||
.join(" "); | ||
|
||
const result = await fetchCategorisedInput({ | ||
input: categorisationInput, | ||
prisma: this._prisma, | ||
chatMeta: { | ||
userId: this._chat.userId, | ||
chatId: this._chat.id, | ||
}, | ||
}); | ||
|
||
if (result) { | ||
this.lesson.initialise(result); | ||
} | ||
} | ||
} | ||
|
||
// Generation methods | ||
public async generateSync(opts: AilaGenerateLessonPlanOptions) { | ||
const stream = await this.generate(opts); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ export interface AilaLessonService { | |
readonly hasSetInitialState: boolean; | ||
applyPatches(patches: string): void; | ||
initialise(plan: LooseLessonPlan): void; | ||
setUpInitialLessonPlan(messages: Message[]): Promise<void>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moves the initial lesson plan setup to the Aila Lesson which makes more sense - resolves the TODO in the Aila class |
||
} | ||
|
||
export interface AilaChatService { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import { RAG } from "@oakai/core/src/rag"; | ||
import { | ||
type PrismaClientWithAccelerate, | ||
prisma as globalPrisma, | ||
} from "@oakai/db"; | ||
|
||
import { AilaServices, Message } from "../../../core"; | ||
import { LooseLessonPlan } from "../../../protocol/schema"; | ||
import { AilaCategorisationFeature } from "../../types"; | ||
|
||
export class AilaCategorisation implements AilaCategorisationFeature { | ||
private _aila: AilaServices; | ||
private _prisma: PrismaClientWithAccelerate; | ||
private _chatId: string; | ||
private _userId: string | undefined; | ||
constructor({ | ||
aila, | ||
prisma, | ||
chatId, | ||
userId, | ||
}: { | ||
aila: AilaServices; | ||
prisma?: PrismaClientWithAccelerate; | ||
chatId: string; | ||
userId?: string; | ||
}) { | ||
this._aila = aila; | ||
this._prisma = prisma ?? globalPrisma; | ||
this._chatId = chatId; | ||
this._userId = userId; | ||
} | ||
public async categorise( | ||
messages: Message[], | ||
lessonPlan: LooseLessonPlan, | ||
): Promise<LooseLessonPlan | undefined> { | ||
const { title, subject, keyStage, topic } = lessonPlan; | ||
const input = messages.map((i) => i.content).join("\n\n"); | ||
const categorisationInput = [title, subject, keyStage, topic, input] | ||
.filter((i) => i) | ||
.join(" "); | ||
|
||
const result = await this.fetchCategorisedInput( | ||
categorisationInput, | ||
this._prisma, | ||
); | ||
return result; | ||
} | ||
|
||
private async fetchCategorisedInput( | ||
input: string, | ||
prisma: PrismaClientWithAccelerate, | ||
): Promise<LooseLessonPlan | undefined> { | ||
const rag = new RAG(prisma, { | ||
chatId: this._chatId, | ||
userId: this._userId, | ||
}); | ||
const parsedCategorisation = await rag.categoriseKeyStageAndSubject(input, { | ||
chatId: this._chatId, | ||
userId: this._userId, | ||
}); | ||
const { keyStage, subject, title, topic } = parsedCategorisation; | ||
const plan: LooseLessonPlan = { | ||
keyStage: keyStage ?? undefined, | ||
subject: subject ?? undefined, | ||
title: title ?? undefined, | ||
topic: topic ?? undefined, | ||
}; | ||
return plan; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import { LooseLessonPlan } from "../../../protocol/schema"; | ||
import { AilaCategorisationFeature } from "../../types"; | ||
|
||
export class MockCategoriser implements AilaCategorisationFeature { | ||
private _mockedLessonPlan: LooseLessonPlan | undefined; | ||
constructor({ | ||
mockedLessonPlan, | ||
}: { | ||
mockedLessonPlan: LooseLessonPlan | undefined; | ||
}) { | ||
this._mockedLessonPlan = mockedLessonPlan; | ||
} | ||
public async categorise(): Promise<LooseLessonPlan | undefined> { | ||
return this._mockedLessonPlan; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export { AilaCategorisation } from "./categorisers/AilaCategorisation"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just adding one example test here to show how we can pass in the mock categoriser during Aila initialisation