Skip to content

Commit

Permalink
fix: change plan parts embeddings to 256 vectors (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
mantagen authored Sep 25, 2024
1 parent 638c851 commit 65f562c
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 93 deletions.
79 changes: 0 additions & 79 deletions packages/core/src/models/lessonPlans.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ import {
PrismaClientWithAccelerate,
Subject,
} from "@oakai/db";
import { Prisma } from "@prisma/client";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { PrismaVectorStore } from "langchain/vectorstores/prisma";
import yaml from "yaml";

import { LLMResponseJsonSchema } from "../../../aila/src/protocol/jsonPatchProtocol";
Expand Down Expand Up @@ -57,11 +54,6 @@ type LessonPlanWithParts = LessonPlan & {
parts: LessonPlanPart[] | null;
};

/**
 * Optional filter conditions passed to the vector store built in `search`.
 *
 * Both keys are optional; `search` only sets a key when the corresponding
 * record lookup succeeds, and always assigns a value of the form
 * `{ equals: <record id> }`.
 * NOTE(review): the snake_case keys appear to be raw database column names
 * rather than Prisma field names; confirm against the schema.
 */
interface FilterOptions {
// Condition on the key stage; set when the requested key stage is found.
key_stage_id?: object;
// Condition on the subject; set when the requested subject is found.
subject_id?: object;
}

export class LessonPlans {
private _rag: RAG;
private _prisma: PrismaClientWithAccelerate;
Expand Down Expand Up @@ -335,75 +327,4 @@ ${p.content}`,
WHERE id = ${id}`;
return result;
}

/**
 * Finds lesson plans whose embedding is most similar to `query`.
 *
 * The optional `keyStage` and `subject` values are resolved to records via
 * the RAG helper. When a value is omitted, or no record is found for it,
 * that filter is simply not applied (the search is not narrowed).
 *
 * @param query - Free-text query handed to the vector store.
 * @param keyStage - Key stage identifier to filter by, if any.
 * @param subject - Subject identifier to filter by, if any.
 * @param perPage - Maximum number of results requested from the vector store.
 * @returns Lesson plans (with their selected lesson fields) in the same
 *   order the similarity search returned them.
 * @throws Error if a search hit has no matching lesson plan row.
 */
async search(
  query: string,
  keyStage: string | undefined,
  subject: string | undefined,
  perPage: number,
): Promise<LessonPlanWithLesson[]> {
  const filter: FilterOptions = {};

  if (keyStage) {
    const keyStageRecord = await this._rag.fetchKeyStage(keyStage);
    if (keyStageRecord) {
      filter["key_stage_id"] = {
        equals: keyStageRecord.id,
      };
    }
  }
  if (subject) {
    const subjectRecord = await this._rag.fetchSubject(subject);
    if (subjectRecord) {
      filter["subject_id"] = {
        equals: subjectRecord.id,
      };
    }
  }

  const vectorStore = PrismaVectorStore.withModel<LessonPlan>(
    this._prisma,
  ).create(new OpenAIEmbeddings(), {
    prisma: Prisma,
    tableName: "lesson_plans" as "LessonPlan",
    vectorColumnName: "embedding",
    columns: {
      id: PrismaVectorStore.IdColumn,
      content: PrismaVectorStore.ContentColumn,
    },
    // @ts-expect-error TODO Bug in PrismaVectorStore which doesn't allow mapped column names
    filter,
  });
  const result = await vectorStore.similaritySearch(query, perPage);

  // No hits: nothing to hydrate, so skip the database round trip.
  if (result.length === 0) {
    return [];
  }

  const lessonPlans: LessonPlanWithLesson[] =
    await this._prisma.lessonPlan.findMany({
      where: {
        id: { in: result.map((r) => r.metadata.id) },
      },
      include: {
        lesson: {
          select: {
            id: true,
            slug: true,
            title: true,
            subjectId: true,
            keyStageId: true,
            isNewLesson: true,
            newLessonContent: true,
          },
        },
      },
    });

  // Index once by id so each hit is resolved with a single lookup instead of
  // a linear scan of `lessonPlans` per hit.
  const lessonPlansById = new Map(
    lessonPlans.map((lessonPlan) => [lessonPlan.id, lessonPlan] as const),
  );

  // findMany does not guarantee ordering, so rebuild the list in the order
  // the similarity search produced.
  return result.map((entry) => {
    const lessonPlan = lessonPlansById.get(entry.metadata.id);
    if (!lessonPlan) {
      throw new Error("Lesson summary not found");
    }
    return lessonPlan;
  });
}
}
32 changes: 19 additions & 13 deletions packages/core/src/rag/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -733,20 +733,26 @@ Thank you and happy classifying!`;

const vectorStore = PrismaVectorStore.withModel<LessonPlanPart>(
this.prisma,
).create(new OpenAIEmbeddings(), {
prisma: Prisma,
tableName: "lesson_plan_parts" as "LessonPlanPart",
vectorColumnName: "embedding",
verbose: true,
openAIApiKey: process.env.OPENAI_API_KEY,
columns: {
id: PrismaVectorStore.IdColumn,
lesson_plan_id: PrismaVectorStore.IdColumn,
content: PrismaVectorStore.ContentColumn,
).create(
new OpenAIEmbeddings({
modelName: "text-embedding-3-large",
dimensions: 256,
}),
{
prisma: Prisma,
tableName: "lesson_plan_parts" as "LessonPlanPart",
vectorColumnName: "embedding",
verbose: true,
openAIApiKey: process.env.OPENAI_API_KEY,
columns: {
id: PrismaVectorStore.IdColumn,
lesson_plan_id: PrismaVectorStore.IdColumn,
content: PrismaVectorStore.ContentColumn,
},
// @ts-expect-error TODO Bug in PrismaVectorStore which doesn't allow mapped column names
filter,
},
// @ts-expect-error TODO Bug in PrismaVectorStore which doesn't allow mapped column names
filter,
});
);

const similaritySearchTerm = topic ? `${title}. ${topic}` : title;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- This changes the embedding column from a 1536 vector to a 256 vector
-- The change is done by adding a new column, dropping the old one, and
-- renaming the new column into place. Existing embedding values are NOT
-- copied across: after this migration every row has a NULL embedding until
-- it is regenerated at the new dimension.
-- NOTE(review): any index defined on the old embedding column would be
-- dropped along with it; confirm whether one needs to be recreated.
ALTER TABLE lesson_plan_parts ADD COLUMN embedding_temp vector(256);
ALTER TABLE lesson_plan_parts DROP COLUMN embedding;
ALTER TABLE lesson_plan_parts RENAME COLUMN embedding_temp TO embedding;
2 changes: 1 addition & 1 deletion packages/db/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ model LessonPlanPart {
lessonPlan LessonPlan @relation(fields: [lessonPlanId], references: [id], onDelete: Cascade)
subject Subject? @relation(fields: [subjectId], references: [id], onDelete: Cascade)
keyStage KeyStage? @relation(fields: [keyStageId], references: [id], onDelete: Cascade)
embedding Unsupported("vector(1536)")?
embedding Unsupported("vector(256)")?
status LessonPlanPartStatus @default(PENDING)
@@unique([lessonPlanId, key])
Expand Down

0 comments on commit 65f562c

Please sign in to comment.