From 523cbd82dbc109e2b86fd659242f1f6fb548a76b Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 6 Nov 2024 12:06:25 +0000 Subject: [PATCH 01/13] feat: create ingest 'publish' script --- packages/aila/src/constants.ts | 2 + .../aila/src/utils/rag/fetchRagContent.ts | 22 +++++ packages/core/src/models/ragLessonPlans.ts | 40 +++++++++ packages/ingest/package.json | 1 + packages/ingest/src/db-helpers/step.ts | 1 + packages/ingest/src/index.ts | 4 + packages/ingest/src/steps/7-publish.ts | 84 +++++++++++++++++++ pnpm-lock.yaml | 14 ++++ 8 files changed, 168 insertions(+) create mode 100644 packages/core/src/models/ragLessonPlans.ts create mode 100644 packages/ingest/src/steps/7-publish.ts diff --git a/packages/aila/src/constants.ts b/packages/aila/src/constants.ts index fbcc27ab5..57032eb3a 100644 --- a/packages/aila/src/constants.ts +++ b/packages/aila/src/constants.ts @@ -5,6 +5,8 @@ export const DEFAULT_MODERATION_MODEL: OpenAI.Chat.ChatModel = "gpt-4o-2024-08-06"; export const DEFAULT_CATEGORISE_MODEL: OpenAI.Chat.ChatModel = "gpt-4o-2024-08-06"; +export const DEFAULT_EMBEDDING_MODEL: OpenAI.Embeddings.EmbeddingCreateParams["model"] = + "text-embedding-3-large"; export const DEFAULT_TEMPERATURE = 0.7; export const DEFAULT_MODERATION_TEMPERATURE = 0.7; export const DEFAULT_RAG_LESSON_PLANS = 5; diff --git a/packages/aila/src/utils/rag/fetchRagContent.ts b/packages/aila/src/utils/rag/fetchRagContent.ts index 3ffb709dc..286ba1c6d 100644 --- a/packages/aila/src/utils/rag/fetchRagContent.ts +++ b/packages/aila/src/utils/rag/fetchRagContent.ts @@ -1,5 +1,12 @@ +import { createOpenAIClient } from "@oakai/core/src/llm/openai"; +import { RagLessonPlans } from "@oakai/core/src/models/ragLessonPlans"; import { RAG } from "@oakai/core/src/rag"; +<<<<<<< Updated upstream import type { PrismaClientWithAccelerate } from "@oakai/db"; +======= +import { PrismaClientWithAccelerate } from "@oakai/db"; +import OpenAI from "openai"; +>>>>>>> Stashed changes import { tryWithErrorReporting } from "../../helpers/errorReporting"; import type { CompletedLessonPlan } from "../../protocol/schema"; @@ -33,6 +40,21 @@ export async function fetchRagContent({ chatId: string; userId?: string; }): Promise { + try { + // const openAiClient = createOpenAIClient({ + // app: "lesson-assistant", + // chatMeta: { + // userId, + // chatId, + // }, + // }); + const openAiClient = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); + const ragLessonPlans_ = new RagLessonPlans(prisma, openAiClient); + await ragLessonPlans_.getRelevantLessonPlans({ title }); + } catch (cause) { + console.log(cause); + } + const rag = new RAG(prisma, { chatId, userId }); const ragLessonPlans = await tryWithErrorReporting( () => { diff --git a/packages/core/src/models/ragLessonPlans.ts b/packages/core/src/models/ragLessonPlans.ts new file mode 100644 index 000000000..a95f94064 --- /dev/null +++ b/packages/core/src/models/ragLessonPlans.ts @@ -0,0 +1,40 @@ +import { PrismaClientWithAccelerate } from "@oakai/db"; +import OpenAI from "openai"; + +/** + * + */ +export class RagLessonPlans { + constructor( + private readonly prisma: PrismaClientWithAccelerate, + private readonly openai: OpenAI, + ) {} + + async getRelevantLessonPlans({ title }: { title: string }): Promise<{}[]> { + const embedding = await this.openai.embeddings.create({ + model: "text-embedding-3-large", + dimensions: 256, + input: title, + encoding_format: "float", + }); + + // console.log(JSON.stringify(embedding)); + + const queryEmbedding = `[${embedding.data[0]?.embedding.join(",")}]`; + const limit = 5; + // console.log(queryEmbedding); + + const results = await this.prisma.$queryRaw` + SELECT rag_lesson_plan_id + FROM rag.rag_lesson_plan_parts + ORDER BY embedding <-> ${queryEmbedding}::vector + LIMIT ${limit}; + `; + + console.log(results); + const all = await this.prisma.ragLessonPlan.findMany(); + console.log(all); + + return []; + } +} diff --git a/packages/ingest/package.json b/packages/ingest/package.json index 7fae73a65..19ccaf7a0 100644 --- a/packages/ingest/package.json +++ b/packages/ingest/package.json @@ -23,6 +23,7 @@ "@oakai/core": "*", "@oakai/db": "*", "@oakai/logger": "*", + "@paralleldrive/cuid2": "^2.2.2", "csv-parser": "^3.0.0", "graphql-request": "^6.1.0", "webvtt-parser": "^2.2.0", diff --git a/packages/ingest/src/db-helpers/step.ts b/packages/ingest/src/db-helpers/step.ts index 5a7c58221..0d9b3026f 100644 --- a/packages/ingest/src/db-helpers/step.ts +++ b/packages/ingest/src/db-helpers/step.ts @@ -6,6 +6,7 @@ export const STEP = [ "lesson_plan_generation", "chunking", "embedding", + "publishing", ] as const; const STEP_STATUS = ["started", "completed", "failed"] as const; diff --git a/packages/ingest/src/index.ts b/packages/ingest/src/index.ts index 67329af6f..44860a6d1 100644 --- a/packages/ingest/src/index.ts +++ b/packages/ingest/src/index.ts @@ -47,6 +47,10 @@ async function main() { case "embed-sync": await lpPartsEmbedSync({ prisma, log, ingestId }); break; + + case "publish": + // publish({ prisma, log }); + break; default: log.error("Unknown command"); process.exit(1); diff --git a/packages/ingest/src/steps/7-publish.ts b/packages/ingest/src/steps/7-publish.ts new file mode 100644 index 000000000..dcaa5554c --- /dev/null +++ b/packages/ingest/src/steps/7-publish.ts @@ -0,0 +1,84 @@ +import { Prisma, PrismaClientWithAccelerate } from "@oakai/db"; +import { createId } from "@paralleldrive/cuid2"; +import { isTruthy } from "remeda"; + +import { IngestError } from "../IngestError"; +import { getIngestById } from "../db-helpers/getIngestById"; +import { loadLessonsAndUpdateState } from "../db-helpers/loadLessonsAndUpdateState"; +import { Step, getPrevStep } from "../db-helpers/step"; +import { IngestLogger } from "../types"; + +const currentStep: Step = "publishing"; +const prevStep = getPrevStep(currentStep); + +/** + * Publish ingest lesson_plans and lesson_plan_parts to the rag schema + */ +export async function publish({ + prisma, + log, + ingestId, +}: { + prisma: PrismaClientWithAccelerate; + log: IngestLogger; + ingestId: string; +}) { + const ingest = await getIngestById({ prisma, ingestId }); + const lessons = await loadLessonsAndUpdateState({ + prisma, + ingestId, + prevStep, + currentStep, + }); + + const ragLessonPlans = lessons + .map((l) => + l.lessonPlan + ? { + ...l.lessonPlan, + ingestLessonId: l.id, + oakLessonId: l.oakLessonId, + subjectSlug: l.data.subjectSlug, + keyStageSlug: l.data.keyStageSlug, + } + : null, + ) + .filter(isTruthy) + .map((lp) => ({ + ingestLessonId: lp.ingestLessonId, + oakLessonId: lp.oakLessonId, + lessonPlan: lp.data as object, + subjectSlug: lp.subjectSlug, + keyStageSlug: lp.keyStageSlug, + })); + + await prisma.ragLessonPlan.createMany({ + data: ragLessonPlans, + }); + + const persistedRagLessonPlans = await prisma.ragLessonPlan.findMany({ + where: { + ingestLessonId: { + in: ragLessonPlans.map((lp) => lp.ingestLessonId), + }, + }, + }); + // const ingestLessonPlanParts = await prisma.ingestLessonPlanPart.findMany({}); + + // const ragLessonPlanIds = persistedRagLessonPlans.map((lp) => lp.id); + + // const keys = ragLessonPlans.flatMap((lp) => Object.keys(lp.lessonPlan)); + + // await prisma.$queryRaw` + // INSERT INTO rag.rag_lesson_plan_parts (id, rag_lesson_plan_id, key, value_text, value_json, embedding) + // SELECT * + // FROM UNNEST ( + // ${ragLessonPlanParts.map((p) => p.id)}::text[], + // ${ragLessonPlanParts.map((p) => p.ragLessonPlanId)}::text[], + // ${ragLessonPlanParts.map((p) => p.key)}::text[], + // ${ragLessonPlanParts.map((p) => p.valueText)}::text[], + // ${ragLessonPlanParts.map((p) => p.valueJson)}::jsonb[], + // ${ragLessonPlanParts.map((p) => `{${Array.from(p.embedding).join(",")}}`)}::vector(256)[] + // ); + // `; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2535e9118..9efdd498c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -839,6 +839,9 @@ importers: '@oakai/logger': specifier: '*' version: link:../logger + '@paralleldrive/cuid2': + specifier: ^2.2.2 + version: 2.2.2 csv-parser: specifier: ^3.0.0 version: 3.0.0 @@ -5052,6 +5055,11 @@ packages: requiresBuild: true optional: true + /@noble/hashes@1.5.0: + resolution: {integrity: sha512-1j6kQFb7QRru7eKN3ZDvRcP13rugwdxZqCjbiAVZfIJwgj2A65UmT4TgARXGlXgnRkORLTDTrO19ZErt7+QXgA==} + engines: {node: ^14.21.3 || >=16} + dev: false + /@nodelib/fs.scandir@2.1.5: resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==} engines: {node: '>= 8'} @@ -5684,6 +5692,12 @@ packages: '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) dev: false + /@paralleldrive/cuid2@2.2.2: + resolution: {integrity: sha512-ZOBkgDwEdoYVlSeRbYYXs0S9MejQofiVYoTbKzy/6GQa39/q5tQU2IX46+shYnUkpEl3wc+J6wRlar7r2EK2xA==} + dependencies: + '@noble/hashes': 1.5.0 + dev: false + /@peculiar/asn1-schema@2.3.13: resolution: {integrity: sha512-3Xq3a01WkHRZL8X04Zsfg//mGaA21xlL4tlVn4v2xGT0JStiztATRkMwa5b+f/HXmY2smsiLXYK46Gwgzvfg3g==} dependencies: From 1c06d78c03cd70f7e6969f10c97ae059e1acd1fc Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 13 Nov 2024 18:21:19 +0000 Subject: [PATCH 02/13] remove merge markers --- packages/aila/src/utils/rag/fetchRagContent.ts | 4 ---- 1 file changed, 4 deletions(-) diff --git a/packages/aila/src/utils/rag/fetchRagContent.ts b/packages/aila/src/utils/rag/fetchRagContent.ts index 286ba1c6d..537ff9c25 100644 --- a/packages/aila/src/utils/rag/fetchRagContent.ts +++ b/packages/aila/src/utils/rag/fetchRagContent.ts @@ -1,12 +1,8 @@ import { createOpenAIClient } from "@oakai/core/src/llm/openai"; import { RagLessonPlans } from "@oakai/core/src/models/ragLessonPlans"; import { RAG } from "@oakai/core/src/rag"; -<<<<<<< Updated upstream import type { PrismaClientWithAccelerate } from "@oakai/db"; -======= -import { PrismaClientWithAccelerate } from "@oakai/db"; import OpenAI from "openai"; ->>>>>>> Stashed changes import { tryWithErrorReporting } from "../../helpers/errorReporting"; import type { CompletedLessonPlan } from "../../protocol/schema"; From be9d8d5659ea58699103230f4cda76b280746327 Mon Sep 17 00:00:00 2001 From: mantagen Date: Tue, 19 Nov 2024 10:21:22 +0000 Subject: [PATCH 03/13] add tests and key stage helper --- packages/aila/src/features/rag/index.test.ts | 97 +++++++++++++++ .../aila/src/features/rag/rag.fixtures.json | 112 ++++++++++++++++++ packages/core/src/rag/index.ts | 37 ++++-- packages/core/src/utils/shortenKeyStage.ts | 10 ++ 4 files changed, 243 insertions(+), 13 deletions(-) create mode 100644 packages/aila/src/features/rag/index.test.ts create mode 100644 packages/aila/src/features/rag/rag.fixtures.json create mode 100644 packages/core/src/utils/shortenKeyStage.ts diff --git a/packages/aila/src/features/rag/index.test.ts b/packages/aila/src/features/rag/index.test.ts new file mode 100644 index 000000000..8e5a6f6e3 --- /dev/null +++ b/packages/aila/src/features/rag/index.test.ts @@ -0,0 +1,97 @@ +/** + * These are essentially database snapshot tests that are used to test the RAG retrieval. + * Any given RAG data is updated infrequently, so if these tests fail, it's likely + * it's a signal that the RAG logic has changed. + * In this case, the new functionality should be tested and the snapshots updated. + */ +import { prisma } from "@oakai/db"; + +// import OpenAI from "openai"; +import ragFixtures from "./rag.fixtures.json"; + +// const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); +async function pgVectorSearch({ + vector, + limit = 1, +}: { + vector: number[]; + limit?: number; +}): Promise { + const prismaResult = await prisma.$queryRaw` + SELECT + lesson_plan_id, + lesson_plan_parts.key, + lesson_plan_parts."content", + lesson_plans."content" ->> 'keyStage' as key_stage_slug, + lesson_plans."content" ->> 'subject' as subject_slug, + lesson_plan_parts.embedding <-> ${vector}::vector AS score + from lesson_plans join lesson_plan_parts on lesson_plan_parts.lesson_plan_id = lesson_plans.id + ORDER BY lesson_plan_parts.embedding <-> ${vector}::vector + LIMIT ${limit}`; + + return prismaResult; +} + +describe("RAG search", () => { + // it("should find lessons on 'xxx'", async () => { + // const lesson = { + // title: "The End of Roman Britain", + // subject: "history", + // keyStage: "key-stage-3", + // }; + + // const embedding = await openai.embeddings.create({ + // input: lesson.title, + // model: "text-embedding-3-large", + // dimensions: 256, + // }); + + // console.log(JSON.stringify(embedding.data[0]?.embedding)); + // const embeddingVector = embedding.data[0]?.embedding; + // const limit = 5; + + // const prismaResult = await prisma.$queryRaw` + // SELECT + // lesson_plan_id, + // lesson_plan_parts.key, + // lesson_plan_parts."content", + // lesson_plans."content" ->> 'keyStage' as key_stage_slug, + // lesson_plans."content" ->> 'subject' as subject_slug + // from lesson_plans join lesson_plan_parts on lesson_plan_parts.lesson_plan_id = lesson_plans.id + // ORDER BY lesson_plan_parts.embedding <-> ${embeddingVector}::vector + // LIMIT ${limit}`; + // }); + it("should find lessons on 'inventions of the industrial revolution'", async () => { + const vector = + ragFixtures["inventions of the industrial revolution"].embedding; + + const prismaResult = await pgVectorSearch({ vector }); + + expect(prismaResult).toMatchObject([ + { + lesson_plan_id: "cm23dy6hl05bjba3ufdd2scad", + key: "title", + content: "Inventions of the Industrial Revolution", + key_stage_slug: "ks3", + subject_slug: "history", + score: 0.0012731864653905422, + }, + ]); + }); + it("should find lessons on 'the end of roman britain'", async () => { + const vector = ragFixtures["the end of roman britain"].embedding; + + const prismaResult = await pgVectorSearch({ vector }); + + expect(prismaResult).toMatchObject([ + { + lesson_plan_id: "cm23dy6d9059bba3uq0axg5o7", + key: "topic", + content: "Anglo-Saxons and the end of Roman Britain", + key_stage_slug: "key-stage-2", + subject_slug: "history", + score: 0.6714155933066479, + }, + ]); + }); +}); diff --git a/packages/aila/src/features/rag/rag.fixtures.json b/packages/aila/src/features/rag/rag.fixtures.json new file mode 100644 index 000000000..af33497c9 --- /dev/null +++ b/packages/aila/src/features/rag/rag.fixtures.json @@ -0,0 +1,112 @@ +{ + "inventions of the industrial revolution": { + "embedding": [ + 0.024054011, 0.0109390505, -0.043416213, -0.038962398, -0.1653692, + -0.04736005, -0.14959385, -0.099207915, 0.0029982517, 0.082616605, + 0.0067402227, 0.058205605, 0.08098467, 0.041988272, 0.032060683, + 0.057797622, -0.06983313, -0.013480445, -0.038554415, -0.041920274, + -0.009349615, 0.00036017268, -0.026195923, -0.08853236, -0.047734037, + -0.03347162, 0.011109043, -0.021130132, -0.065855294, 0.14700995, + 0.029595783, -0.0437902, -0.0032043683, -0.027878853, 0.08084867, + 0.026127925, 0.048312012, 0.057933617, 0.10471569, -0.09628404, + 0.042362258, -0.026586907, 0.020807143, 0.014746893, -0.03699048, + 0.071057074, -0.13735434, -0.020314164, -0.11511926, 0.117227174, + -0.0049297973, -0.12130701, 0.001131516, 0.054125775, -0.04640809, + -0.058409598, 0.04348421, 0.13823831, 0.010038087, -0.13477045, + 0.0069314647, 0.035698533, -0.16400926, 0.052187853, -0.013276454, + -0.07636086, -0.071125075, -0.015758352, 0.110699445, -0.003935338, + 0.002284281, 0.06639927, 0.025379956, 0.041478295, 0.08322857, + -0.070921086, 0.019855183, 0.04695207, 0.051609877, 0.022864059, + 0.091932215, 0.01893722, 0.00899263, -0.06361138, 0.055213727, + 0.028541826, 0.04181828, -0.020246167, 0.121511005, 0.054635752, + 0.009001129, 0.09302017, 0.0534458, -0.0007793117, 0.010361074, + 0.030700738, 0.07506891, -0.059939533, -0.013344451, 0.035902523, + -0.023782022, 0.0315507, -0.013573942, -0.077652805, 0.058001615, + 0.012256496, -0.019532196, 0.08057669, -0.10131583, 0.09839195, + 0.023527032, -0.025396956, 0.04086632, -0.017373286, -0.060177524, + -0.038282424, 0.020620152, -0.105055675, 0.010157082, 0.01893722, + 0.06639927, 0.0066424767, 0.07262101, -0.013896928, -0.103015766, + -0.010641562, 0.0631354, -0.11348733, -0.03811243, -0.039336383, + -0.03211168, 0.022898057, 0.01602184, 0.051643874, 0.07683684, + -0.07377697, 0.089144334, -0.0021153504, -0.063577384, 0.0631354, + -0.12096702, -0.0018295498, -0.016327828, -0.06850718, -0.011007047, + 0.0061834957, -0.15530561, -0.005864759, -0.06031352, 0.059327558, + -0.110223465, -0.014228415, -0.083772555, 0.038214426, -0.01118554, + 0.014959385, 0.0153758675, 0.037330464, -0.016089838, -0.08098467, + 0.051167894, 0.037500456, -0.06765722, 0.06939115, 0.116547205, + -0.038248427, 0.07214503, 0.04253225, -0.029527785, 0.0340496, 0.07568089, + 0.011797515, 0.0050147935, -0.16754511, 0.08248061, -0.059667546, + -0.0048660496, -0.0135229435, 0.012477486, -0.06446135, -0.066161275, + 0.031074721, 0.033794608, 0.034508582, 0.067895204, 0.0047385553, + -0.0063067405, 0.036718488, 0.04277024, -0.07874076, -0.023068052, + -0.10233579, -0.014984883, 0.1650972, -0.037670452, -0.021504115, + -0.003478482, -0.027249878, -0.097575985, -0.0046620583, -0.076292865, + 0.10716359, 0.06541331, -0.052833825, 0.050929904, 0.010922051, + -0.024105009, 0.084520526, 0.03818043, -0.060041532, 0.018206252, + -0.024342999, -0.088668354, -0.063645385, -0.048006024, 0.040526334, + 0.11974307, 0.019804185, -0.08975631, -0.02798085, 0.0085293995, + 0.027606864, 0.028677821, 0.0046620583, -0.051405884, -0.027079886, + -0.10845554, 0.06361138, 0.049773954, 0.005053042, 0.035664532, + -0.15666555, -0.1264068, 0.052425843, -0.078400776, 0.055927698, + -0.043892194, 0.0052400343, -0.014313411, -0.022660067, 0.039778363, + -0.0021079134, -0.010871053, -0.08547249, 0.011933509, -0.015290871, + -0.009035128, 0.0947201, 0.05470375, -0.027997848, 0.058851577, + -0.0028473828, 0.010488569, -0.08554048, -0.07214503, 0.04103631 + ] + }, + "the end of roman britain": { + "embedding": [ + -0.06042437, -0.061291862, -0.015873492, 0.050047792, -0.06189244, + -0.021437138, -0.0015108178, 0.05979043, -0.082412034, 0.036034413, + -0.060891483, 0.08087724, -0.087950654, -0.04317456, 0.026208362, + -0.09776002, 0.0021218178, -0.000719437, 0.02659206, -0.08161127, + -0.043141197, 0.020069165, -0.03653489, -0.089618914, -0.09896117, + 0.023038667, -0.03146338, -0.06496204, -0.029328007, 0.09615849, + 0.021303678, -0.02472361, 0.17229787, 0.054351903, -0.016107049, + 0.13799845, 0.1247191, 0.10263134, -0.036868542, -0.04751204, 0.026942395, + -0.0071568345, -0.015523157, 0.04494292, 0.020669738, -0.017933793, + 0.018767923, -0.02742619, -0.059456777, -0.01778365, -0.04284091, + 0.0012793468, -0.032848034, 0.04784569, -0.048512995, -0.027759843, + 0.06906595, 0.036568254, -0.10630151, 0.008316275, -0.047311846, + 0.074871495, 0.11584396, 0.02499053, -0.091821015, -0.0139466515, + 0.08227857, -0.014146843, -0.07146825, -0.065228954, 0.033515338, + 0.019118257, 0.0041831615, -0.06185907, -0.12852274, -0.02298862, + 0.05848919, 0.03993814, -0.13312712, -0.008758364, -0.022421412, + -0.0076948483, 0.03011209, -0.018033888, 0.036267966, -0.007699019, + -0.054952476, -0.069132686, 0.04681137, -0.04861309, 0.05959024, + -0.0468781, -0.019051526, -0.20286039, 0.015398038, -0.07140152, + 0.035233647, 0.016332263, -0.07233574, -0.06479521, -0.042807546, + -0.045076378, 0.024456687, 0.05361787, 0.048846647, -0.05008116, + 0.00113963, -0.013963334, 0.022171173, 0.21180226, 0.12024816, 0.06696395, + 0.007653142, 0.038002957, -0.08354645, 0.0017495875, 0.028977672, + -0.057821885, -0.036634985, 0.04244053, -0.024973849, -0.043341387, + 0.10863708, -0.0014180208, -0.044842824, 0.072869584, -0.03610114, + 0.039204106, -0.05395152, 0.005613694, -0.10016232, 0.021303678, + 0.05144913, 0.19164968, 0.056287084, 0.0042749154, -0.039404295, + 0.019034844, 0.09242159, -0.024690244, -0.0076906774, 0.06289339, + 0.0012845601, 0.032464337, 0.003634721, 0.03570076, 0.043374754, + -0.028744116, -0.009675907, -0.05888957, 0.0026024852, 0.083880104, + -0.0044459123, -0.035233647, -0.037168827, -0.0336488, 0.00944235, + -0.033415243, -0.038303245, 0.04063881, -0.063160315, 0.04554349, + -0.017057955, 0.009008603, -0.015681641, -0.07440439, 0.12445218, + -0.03820315, -0.026425235, 0.008883484, 0.11397551, -0.06516223, + -0.054185078, -0.044208884, 0.07867513, 0.033765577, -0.09569138, + -0.0445759, -0.022171173, 0.09422331, 0.020002434, -0.017950475, + -0.06215936, -0.05568651, 0.045043014, -0.008337128, -0.016198803, + 0.0074279266, -0.0558867, -0.08574855, -0.07680668, -0.12191642, + 0.12084874, 0.04514311, -0.076339565, -0.014588932, 0.073203236, + 0.08181146, -0.0468781, -0.027125904, 0.13933305, 0.054285172, + -0.04327466, -0.08488106, 0.08367991, -0.03750248, 0.051849514, + 0.106701896, 0.15334643, -0.015131116, -0.0036242944, 0.04327466, + -0.036768444, -0.0026921541, -0.087750465, 0.07240248, -0.034833264, + 0.012178296, -0.015556523, 0.021820838, -0.03306491, -0.044409074, + -0.083146065, 0.04187332, 0.014680686, -0.052883834, 0.0075447047, + 0.08621567, 0.07827475, 0.053217486, -0.049547315, -0.00712764, + -0.11771241, -0.0070400564, 0.07146825, -0.0028402123, -0.106902085, + -0.0017214356, 0.07527188, 0.043341387, -0.025924757, 0.04631089, + -0.03171362, -0.10756939, 0.05425181, 0.0051424108, -0.07580572, + 0.037402384, -0.05888957, -0.003622209, 0.1297906, 0.09408985, + 0.0053551137, 0.10550075, -0.080743775, 0.005496916 + ] + } +} diff --git a/packages/core/src/rag/index.ts b/packages/core/src/rag/index.ts index ed115b234..dc1ba9ed6 100644 --- a/packages/core/src/rag/index.ts +++ b/packages/core/src/rag/index.ts @@ -21,6 +21,7 @@ import { DEFAULT_CATEGORISE_MODEL } from "../../../aila/src/constants"; import type { OpenAICompletionWithLoggingOptions } from "../../../aila/src/lib/openai/OpenAICompletionWithLogging"; import { OpenAICompletionWithLogging } from "../../../aila/src/lib/openai/OpenAICompletionWithLogging"; import type { JsonValue } from "../models/prompts"; +import { shortenKeyStage } from "../utils/shortenKeyStage"; import { slugify } from "../utils/slugify"; import { keyStages, subjects } from "../utils/subjects"; import { CategoriseKeyStageAndSubjectResponse } from "./categorisation"; @@ -552,19 +553,19 @@ Thank you and happy classifying!`; if (!keyStage) { return null; } - let cachedKeyStage: KeyStage | null; - try { - cachedKeyStage = await kv.get(`keyStage:${keyStage}`); - if (cachedKeyStage) { - return cachedKeyStage; - } - } catch (e) { - log.error( - "Error parsing cached keyStage. Continuing without cached value", - e, - ); - await kv.del(`keyStage:${keyStage}`); - } + // let cachedKeyStage: KeyStage | null; + // try { + // cachedKeyStage = await kv.get(`keyStage:${keyStage}`); + // if (cachedKeyStage) { + // return cachedKeyStage; + // } + // } catch (e) { + // log.error( + // "Error parsing cached keyStage. Continuing without cached value", + // e, + // ); + // await kv.del(`keyStage:${keyStage}`); + // } let foundKeyStage: KeyStage | null = null; foundKeyStage = await this.prisma.keyStage.findFirst({ @@ -575,6 +576,12 @@ Thank you and happy classifying!`; { slug: slugify(keyStage) }, { title: { equals: keyStage.toLowerCase(), mode: "insensitive" } }, { slug: { equals: keyStage.toLowerCase(), mode: "insensitive" } }, + { + slug: { + equals: shortenKeyStage(slugify(keyStage)), + mode: "insensitive", + }, + }, ], }, cacheStrategy: { ttl: 60 * 5, swr: 60 * 2 }, @@ -692,6 +699,8 @@ Thank you and happy classifying!`; }; } + log.info("Filter:", filter); + const vectorStore = PrismaVectorStore.withModel( this.prisma, ).create( @@ -723,6 +732,8 @@ Thank you and happy classifying!`; similaritySearchTerm, k * 5, // search for more records than we need ); + + log.info("Initial search result", result); } catch (e) { if (e instanceof TypeError && e.message.includes("join([])")) { log.warn("Caught TypeError with join([]), returning empty array"); diff --git a/packages/core/src/utils/shortenKeyStage.ts b/packages/core/src/utils/shortenKeyStage.ts new file mode 100644 index 000000000..87d177746 --- /dev/null +++ b/packages/core/src/utils/shortenKeyStage.ts @@ -0,0 +1,10 @@ +export const shortenKeyStage = (keyStage: string) => { + const keyStageMap: Record = { + "key-stage-1": "KS1", + "key-stage-2": "KS2", + "key-stage-3": "KS3", + "key-stage-4": "KS4", + }; + + return keyStageMap[keyStage] || keyStage; +}; From ad665b78f3e6ec94832fab1de9866a064aed57e0 Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 20 Nov 2024 13:44:08 +0000 Subject: [PATCH 04/13] publish script --- packages/db/prisma/schema.prisma | 5 +- packages/ingest/src/index.ts | 3 +- packages/ingest/src/steps/7-publish.ts | 167 ++++++++++++++++++------- 3 files changed, 124 insertions(+), 51 deletions(-) diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index 7a9cc441a..cbf51ac7a 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -1118,14 +1118,15 @@ model IngestError { model RagLessonPlan { id String @id @default(cuid()) - oakLessonId Int @map("oak_lesson_id") + oakLessonId Int? @map("oak_lesson_id") + oakLessonSlug String @map("oak_lesson_slug") ingestLessonId String? @map("ingest_lesson_id") lessonPlan Json @map("lesson_plan") @db.JsonB subjectSlug String @map("subject_slug") keyStageSlug String @map("key_stage_slug") createdAt DateTime @default(now()) @map("created_at") updatedAt DateTime @updatedAt @map("updated_at") - ragLessonPlanPart RagLessonPlanPart[] + ragLessonPlanParts RagLessonPlanPart[] @@map("rag_lesson_plans") @@schema("rag") diff --git a/packages/ingest/src/index.ts b/packages/ingest/src/index.ts index 44860a6d1..34d1d84d4 100644 --- a/packages/ingest/src/index.ts +++ b/packages/ingest/src/index.ts @@ -13,6 +13,7 @@ import { lpBatchSync } from "./steps/3-lp-batch-sync"; import { lpChunking } from "./steps/4-lp-chunking"; import { lpPartsEmbedStart } from "./steps/5-lp-parts-embed-start"; import { lpPartsEmbedSync } from "./steps/6-lp-parts-embed-sync"; +import { publishToRag } from "./steps/7-publish"; const command = process.argv[2]; @@ -49,7 +50,7 @@ async function main() { break; case "publish": - // publish({ prisma, log }); + await publishToRag({ prisma, log, ingestId }); break; default: log.error("Unknown command"); diff --git a/packages/ingest/src/steps/7-publish.ts b/packages/ingest/src/steps/7-publish.ts index dcaa5554c..574f9e12e 100644 --- a/packages/ingest/src/steps/7-publish.ts +++ b/packages/ingest/src/steps/7-publish.ts @@ -1,12 +1,17 @@ -import { Prisma, PrismaClientWithAccelerate } from "@oakai/db"; -import { createId } from "@paralleldrive/cuid2"; +import { + LessonPlanSchema, + type LooseLessonPlan, +} from "@oakai/aila/src/protocol/schema"; +import type { PrismaClientWithAccelerate } from "@oakai/db"; +import type { Prisma } from "@oakai/db"; import { isTruthy } from "remeda"; import { IngestError } from "../IngestError"; -import { getIngestById } from "../db-helpers/getIngestById"; import { loadLessonsAndUpdateState } from "../db-helpers/loadLessonsAndUpdateState"; -import { Step, getPrevStep } from "../db-helpers/step"; -import { IngestLogger } from "../types"; +import type { Step } from "../db-helpers/step"; +import { getPrevStep } from "../db-helpers/step"; +import type { IngestLogger } from "../types"; +import { chunkAndPromiseAll } from "../utils/chunkAndPromiseAll"; const currentStep: Step = "publishing"; const prevStep = getPrevStep(currentStep); @@ -14,7 +19,7 @@ const prevStep = getPrevStep(currentStep); /** * Publish ingest lesson_plans and lesson_plan_parts to the rag schema */ -export async function publish({ +export async function publishToRag({ prisma, log, ingestId, @@ -23,7 +28,8 @@ export async function publish({ log: IngestLogger; ingestId: string; }) { - const ingest = await getIngestById({ prisma, ingestId }); + log.info("Publishing lesson plans and parts to RAG schema"); + // const ingest = await getIngestById({ prisma, ingestId }); const lessons = await loadLessonsAndUpdateState({ prisma, ingestId, @@ -31,54 +37,119 @@ export async function publish({ currentStep, }); - const ragLessonPlans = lessons - .map((l) => - l.lessonPlan - ? { - ...l.lessonPlan, - ingestLessonId: l.id, - oakLessonId: l.oakLessonId, - subjectSlug: l.data.subjectSlug, - keyStageSlug: l.data.keyStageSlug, - } - : null, - ) - .filter(isTruthy) - .map((lp) => ({ - ingestLessonId: lp.ingestLessonId, - oakLessonId: lp.oakLessonId, - lessonPlan: lp.data as object, - subjectSlug: lp.subjectSlug, - keyStageSlug: lp.keyStageSlug, - })); - - await prisma.ragLessonPlan.createMany({ + log.info(`Loaded ${lessons.length} lessons`); + + const ragLessonPlans: { + oakLessonId?: number; + oakLessonSlug: string; + ingestLessonId?: string; + subjectSlug: string; + keyStageSlug: string; + lessonPlan: LooseLessonPlan; + }[] = []; + + for (const lesson of lessons) { + if (!lesson.lessonPlan) { + throw new IngestError("Lessin is missing lesson plan", { + ingestId, + lessonId: lesson.id, + }); + } + + const lessonPlan = LessonPlanSchema.parse(lesson.lessonPlan); + ragLessonPlans.push({ + oakLessonId: lesson.oakLessonId, + oakLessonSlug: lesson.data.lessonSlug, + ingestLessonId: lesson.id, + subjectSlug: lesson.data.subjectSlug, + keyStageSlug: lesson.data.keyStageSlug, + lessonPlan, + }); + } + + /** + * Add lesson plans to RAG schema + */ + await chunkAndPromiseAll({ data: ragLessonPlans, + chunkSize: 100, + fn: async (data) => { + await prisma.ragLessonPlan.createMany({ + data, + }); + }, }); + log.info(`Written ${ragLessonPlans.length} lesson plans`); + + /** + * Fetch persisted lesson plans (with ids) + */ const persistedRagLessonPlans = await prisma.ragLessonPlan.findMany({ where: { ingestLessonId: { - in: ragLessonPlans.map((lp) => lp.ingestLessonId), + in: ragLessonPlans.map((lp) => lp.ingestLessonId).filter(isTruthy), }, }, + select: { + id: true, + ingestLessonId: true, + }, }); - // const ingestLessonPlanParts = await prisma.ingestLessonPlanPart.findMany({}); - - // const ragLessonPlanIds = persistedRagLessonPlans.map((lp) => lp.id); - - // const keys = ragLessonPlans.flatMap((lp) => Object.keys(lp.lessonPlan)); - - // await prisma.$queryRaw` - // INSERT INTO rag.rag_lesson_plan_parts (id, rag_lesson_plan_id, key, value_text, value_json, embedding) - // SELECT * - // FROM UNNEST ( - // ${ragLessonPlanParts.map((p) => p.id)}::text[], - // ${ragLessonPlanParts.map((p) => p.ragLessonPlanId)}::text[], - // ${ragLessonPlanParts.map((p) => p.key)}::text[], - // ${ragLessonPlanParts.map((p) => p.valueText)}::text[], - // ${ragLessonPlanParts.map((p) => p.valueJson)}::jsonb[], - // ${ragLessonPlanParts.map((p) => `{${Array.from(p.embedding).join(",")}}`)}::vector(256)[] - // ); - // `; + + const ragLessonPlanParts: { + ragLessonPlanId: string; + key: string; + valueText: string; + valueJson: Prisma.JsonValue; + embedding: number[]; + }[] = []; + + for (const ragLessonPlan of persistedRagLessonPlans) { + const ragLessonPlanId = ragLessonPlan.id; + const ingestLessonId = ragLessonPlan.ingestLessonId; + const lesson = lessons.find((l) => l.id === ingestLessonId); + if (!lesson) { + throw new IngestError("Lesson not found", { + ingestId, + lessonId: ingestLessonId ?? "NO_ID_PROVIDED", + }); + } + + for (const part of lesson.lessonPlanParts) { + ragLessonPlanParts.push({ + ragLessonPlanId, + key: part.key, + valueText: part.valueText, + valueJson: part.valueJson, + embedding: [], + }); + } + } + + log.info(`Writing ${ragLessonPlanParts.length} lesson plan parts`); + + /** + * Add lesson plan parts to RAG schema + */ + await chunkAndPromiseAll({ + data: ragLessonPlanParts, + chunkSize: 100, + fn: async (data) => { + // Need to use $queryRaw because Prisma doesn't support the vector type + await prisma.$queryRaw` + INSERT INTO rag.rag_lesson_plan_parts (rag_lesson_plan_id, key, value_text, value_json, embedding) + SELECT * + FROM UNNEST ( + ${data.map((p) => p.ragLessonPlanId)}::text[], + ${data.map((p) => p.key)}::text[], + ${data.map((p) => p.valueText)}::text[], + ${data.map((p) => p.valueJson)}::jsonb[], + ${data.map((p) => `{${Array.from(p.embedding).join(",")}}`)}::vector(256)[] + ); + `; + }, + }); + + log.info("Published lesson plans and parts to RAG schema"); } From 63a3ba01818313b35cad352a1a3c6aa5fbb365f2 Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 20 Nov 2024 17:21:25 +0000 Subject: [PATCH 05/13] fix publish step and retrieval --- packages/core/src/models/ragLessonPlans.ts | 38 ++++++++++---- packages/ingest/src/steps/7-publish.ts | 61 ++++++++++++++++------ 2 files changed, 73 insertions(+), 26 deletions(-) diff --git a/packages/core/src/models/ragLessonPlans.ts b/packages/core/src/models/ragLessonPlans.ts index a95f94064..51baa0cd1 100644 --- a/packages/core/src/models/ragLessonPlans.ts +++ b/packages/core/src/models/ragLessonPlans.ts @@ -1,6 +1,8 @@ -import { PrismaClientWithAccelerate } from "@oakai/db"; -import OpenAI from "openai"; +import type { PrismaClientWithAccelerate } from "@oakai/db"; +import { aiLogger } from "@oakai/logger"; +import type OpenAI from "openai"; +const log = aiLogger("rag"); /** * */ @@ -18,22 +20,36 @@ export class RagLessonPlans { encoding_format: "float", }); - // console.log(JSON.stringify(embedding)); - const queryEmbedding = `[${embedding.data[0]?.embedding.join(",")}]`; - const limit = 5; - // console.log(queryEmbedding); + const limit = 50; - const results = await this.prisma.$queryRaw` - SELECT rag_lesson_plan_id + const startAt = new Date(); + log.info(`Fetching relevant lesson plans for ${title}`); + let results: { rag_lesson_plan_id: string }[] = []; + results = await this.prisma.$queryRaw` + SELECT rag_lesson_plan_id, key, value_text, embedding <-> ${queryEmbedding}::vector FROM rag.rag_lesson_plan_parts ORDER BY embedding <-> ${queryEmbedding}::vector LIMIT ${limit}; `; - console.log(results); - const all = await this.prisma.ragLessonPlan.findMany(); - console.log(all); + const endAt = new Date(); + log.info( + `Fetched ${results.length} lesson plans in ${endAt.getTime() - startAt.getTime()}ms`, + ); + + log.info(JSON.stringify(results, null, 2)); + const all = await this.prisma.ragLessonPlan.findMany({ + where: { + id: { + in: results.map((r) => r.rag_lesson_plan_id), + }, + }, + select: { + oakLessonSlug: true, + }, + }); + log.info(all); return []; } diff --git a/packages/ingest/src/steps/7-publish.ts b/packages/ingest/src/steps/7-publish.ts index 574f9e12e..708ee6c77 100644 --- a/packages/ingest/src/steps/7-publish.ts +++ b/packages/ingest/src/steps/7-publish.ts @@ -4,7 +4,10 @@ import { } from "@oakai/aila/src/protocol/schema"; import type { PrismaClientWithAccelerate } from "@oakai/db"; import type { Prisma } from "@oakai/db"; +import { createId } from "@paralleldrive/cuid2"; import { isTruthy } from "remeda"; +import invariant from "tiny-invariant"; +import { z } from "zod"; import { IngestError } from "../IngestError"; import { loadLessonsAndUpdateState } from "../db-helpers/loadLessonsAndUpdateState"; @@ -56,7 +59,8 @@ export async function publishToRag({ }); } - const lessonPlan = LessonPlanSchema.parse(lesson.lessonPlan); + const lessonPlan = LessonPlanSchema.parse(lesson.lessonPlan.data); + ragLessonPlans.push({ oakLessonId: lesson.oakLessonId, oakLessonSlug: lesson.data.lessonSlug, @@ -72,7 +76,7 @@ export async function publishToRag({ */ await chunkAndPromiseAll({ data: ragLessonPlans, - chunkSize: 100, + chunkSize: 500, fn: async (data) => { await prisma.ragLessonPlan.createMany({ data, @@ -109,6 +113,27 @@ export async function publishToRag({ const ragLessonPlanId = ragLessonPlan.id; const ingestLessonId = ragLessonPlan.ingestLessonId; const lesson = lessons.find((l) => l.id === ingestLessonId); + + const lessonPlanParts = await prisma.$queryRaw` + SELECT key, value_text, value_json, embedding::text + FROM ingest.ingest_lesson_plan_part + WHERE lesson_id = ${ingestLessonId} + `; + const parsedLessonPlanParts = z + .array( + z.object({ + key: z.string(), + value_text: z.string(), + value_json: z.union([ + z.string(), + z.array(z.union([z.string(), z.object({}).passthrough()])), + z.object({}).passthrough(), + ]), + embedding: z.string(), + }), + ) + .parse(lessonPlanParts); + if (!lesson) { throw new IngestError("Lesson not found", { ingestId, @@ -116,13 +141,13 @@ export async function publishToRag({ }); } - for (const part of lesson.lessonPlanParts) { + for (const part of parsedLessonPlanParts) { ragLessonPlanParts.push({ ragLessonPlanId, key: part.key, - valueText: part.valueText, - valueJson: part.valueJson, - embedding: [], + valueText: part.value_text, + valueJson: part.value_json, + embedding: part.embedding.slice(1, -1).split(",").map(Number), }); } } @@ -134,20 +159,26 @@ export async function publishToRag({ */ await chunkAndPromiseAll({ data: ragLessonPlanParts, - chunkSize: 100, + chunkSize: 500, fn: async (data) => { + const now = new Date().toISOString(); // Need to use $queryRaw because Prisma doesn't support the vector type await prisma.$queryRaw` - INSERT INTO rag.rag_lesson_plan_parts (rag_lesson_plan_id, key, value_text, value_json, embedding) - SELECT * - FROM UNNEST ( - ${data.map((p) => p.ragLessonPlanId)}::text[], - ${data.map((p) => p.key)}::text[], - ${data.map((p) => p.valueText)}::text[], - ${data.map((p) => p.valueJson)}::jsonb[], - ${data.map((p) => `{${Array.from(p.embedding).join(",")}}`)}::vector(256)[] + INSERT INTO rag.rag_lesson_plan_parts (id, rag_lesson_plan_id, key, value_text, value_json, created_at, updated_at, embedding) + SELECT * + FROM UNNEST ( + ARRAY[${data.map(() => createId())}]::text[], + ARRAY[${data.map((p) => p.ragLessonPlanId)}]::text[], + ARRAY[${data.map((p) => p.key)}]::text[], + ARRAY[${data.map((p) => p.valueText)}]::text[], + ARRAY[${data.map((p) => JSON.stringify(p.valueJson))}]::jsonb[], + ARRAY[${data.map(() => now)}]::timestamp[], + ARRAY[${data.map(() => now)}]::timestamp[], + ARRAY[${data.map((p) => `[${p.embedding.join(",")}]`)}]::vector(256)[] ); `; + + log.info(prisma.$queryRawUnsafe.toString()); }, }); From 05ba3c579917b3b3ab2f4c08789f33a88302a42c Mon Sep 17 00:00:00 2001 From: mantagen Date: Mon, 25 Nov 2024 18:14:27 +0000 Subject: [PATCH 06/13] rag lesson plan indexes --- .../aila/src/utils/rag/fetchRagContent.ts | 25 +++++--- packages/aila/src/utils/rag/parseKeyStage.ts | 20 ++++++ packages/aila/src/utils/rag/parseSubjects.ts | 17 ++++++ packages/core/src/models/ragLessonPlans.ts | 61 +++++++++++++------ .../parts/interactingWithTheUser.ts | 2 + packages/db/package.json | 2 + packages/db/prisma/additions/README.md | 3 + .../rag_lesson_plans_ivfflat_index.sql | 4 ++ .../rag_lesson_plans_unique_slug_index.sql | 3 + packages/db/prisma/schema.prisma | 22 ++++--- packages/ingest/src/_data/.gitignore | 3 +- packages/ingest/src/config/ingestConfig.ts | 2 + .../parseBatchLessonPlan.ts | 3 + .../generate-lesson-plans/parseKeyStage.ts | 20 ++++++ 14 files changed, 148 insertions(+), 39 deletions(-) create mode 100644 packages/aila/src/utils/rag/parseKeyStage.ts create mode 100644 packages/aila/src/utils/rag/parseSubjects.ts create mode 100644 packages/db/prisma/additions/README.md create mode 100644 packages/db/prisma/additions/rag_lesson_plans_ivfflat_index.sql create mode 100644 packages/db/prisma/additions/rag_lesson_plans_unique_slug_index.sql create mode 100644 packages/ingest/src/generate-lesson-plans/parseKeyStage.ts diff --git a/packages/aila/src/utils/rag/fetchRagContent.ts b/packages/aila/src/utils/rag/fetchRagContent.ts index 537ff9c25..b0599db8f 100644 --- a/packages/aila/src/utils/rag/fetchRagContent.ts +++ b/packages/aila/src/utils/rag/fetchRagContent.ts @@ -1,4 +1,3 @@ -import { createOpenAIClient } from "@oakai/core/src/llm/openai"; import { RagLessonPlans } from "@oakai/core/src/models/ragLessonPlans"; import { RAG } from "@oakai/core/src/rag"; import type { PrismaClientWithAccelerate } from "@oakai/db"; @@ -7,6 +6,8 @@ import OpenAI from "openai"; import { tryWithErrorReporting } from "../../helpers/errorReporting"; import type { CompletedLessonPlan } from "../../protocol/schema"; import { minifyLessonPlanForRelevantLessons } from "../lessonPlan/minifyLessonPlanForRelevantLessons"; +import { parseKeyStage } from "./parseKeyStage"; +import { parseSubjects } from "./parseSubjects"; export type RagLessonPlan = Omit< CompletedLessonPlan, @@ -37,18 +38,22 @@ export async function fetchRagContent({ userId?: string; }): Promise { try { - // const openAiClient = createOpenAIClient({ - // app: "lesson-assistant", - // chatMeta: { - // userId, - // chatId, - // }, - // }); const openAiClient = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); const ragLessonPlans_ = new RagLessonPlans(prisma, openAiClient); - await ragLessonPlans_.getRelevantLessonPlans({ title }); + const keyStageSlugs = keyStage ? [parseKeyStage(keyStage)] : null; + const subjectSlugs = subject ? parseSubjects(subject) : null; + const results = await ragLessonPlans_.getRelevantLessonPlans({ + title, + subjectSlugs, + keyStageSlugs, + }); + + return results.map((result) => ({ + id: result.rag_lesson_plan_id, + ...result.lesson_plan, + })); } catch (cause) { - console.log(cause); + throw new Error("Failed to fetch RAG content", { cause }); } const rag = new RAG(prisma, { chatId, userId }); diff --git a/packages/aila/src/utils/rag/parseKeyStage.ts b/packages/aila/src/utils/rag/parseKeyStage.ts new file mode 100644 index 000000000..9daa83aaa --- /dev/null +++ b/packages/aila/src/utils/rag/parseKeyStage.ts @@ -0,0 +1,20 @@ +const keyStageMap: Record = { + 1: "ks1", + 2: "ks2", + 3: "ks3", + 4: "ks4", + 5: "ks5", + keystage1: "ks1", + keystage2: "ks2", + keystage3: "ks3", + keystage4: "ks4", + keystage5: "ks5", + eyfs: "early-years-foundation-stage", +}; + +export function parseKeyStage(maybeKeyStage: string): string { + maybeKeyStage = maybeKeyStage.toLowerCase().replace(/[^a-z0-9]/g, ""); + const keyStageSlug = keyStageMap[maybeKeyStage]; + + return keyStageSlug ?? maybeKeyStage; +} diff --git a/packages/aila/src/utils/rag/parseSubjects.ts b/packages/aila/src/utils/rag/parseSubjects.ts new file mode 100644 index 000000000..dbece7719 --- /dev/null +++ b/packages/aila/src/utils/rag/parseSubjects.ts @@ -0,0 +1,17 @@ +const subjectMap: Record = { + science: ["biology", "chemistry", "physics", "science", "combined-science"], + biology: ["biology", "science", "combined-science"], + chemistry: ["chemistry", "science", "combined-science"], + physics: ["physics", "science", "combined-science"], + "combined-science": [ + "combined-science", + "science", + "biology", + "chemistry", + "physics", + ], +}; + +export function parseSubjects(subject: string): string[] { + return subjectMap[subject] ?? [subject]; +} diff --git a/packages/core/src/models/ragLessonPlans.ts b/packages/core/src/models/ragLessonPlans.ts index 51baa0cd1..2b08b8a97 100644 --- a/packages/core/src/models/ragLessonPlans.ts +++ b/packages/core/src/models/ragLessonPlans.ts @@ -1,6 +1,17 @@ import type { PrismaClientWithAccelerate } from "@oakai/db"; import { aiLogger } from "@oakai/logger"; import type OpenAI from "openai"; +import { uniqBy } from "remeda"; + +import type { CompletedLessonPlan } from "../../../aila/src/protocol/schema"; + +type RagLessonPlanResult = { + rag_lesson_plan_id: string; + lesson_plan: CompletedLessonPlan; + key: string; + value_text: string; + distance: number; +}; const log = aiLogger("rag"); /** @@ -12,7 +23,21 @@ export class RagLessonPlans { private readonly openai: OpenAI, ) {} - async getRelevantLessonPlans({ title }: { title: string }): Promise<{}[]> { + async getRelevantLessonPlans({ + title, + keyStageSlugs, + subjectSlugs, + }: { + title: string; + keyStageSlugs: string[] | null; + subjectSlugs: string[] | null; + }): Promise { + if (!keyStageSlugs?.length) { + throw new Error("No key stages provided"); + } + if (!subjectSlugs?.length) { + throw new Error("No subjects provided"); + } const embedding = await this.openai.embeddings.create({ model: "text-embedding-3-large", dimensions: 256, @@ -24,33 +49,31 @@ export class RagLessonPlans { const limit = 50; const startAt = new Date(); - log.info(`Fetching relevant lesson plans for ${title}`); - let results: { rag_lesson_plan_id: string }[] = []; - results = await this.prisma.$queryRaw` - SELECT rag_lesson_plan_id, key, value_text, embedding <-> ${queryEmbedding}::vector - FROM rag.rag_lesson_plan_parts + log.info( + `Fetching relevant lesson plans for ${title}, in ${keyStageSlugs} and ${subjectSlugs}`, + ); + + const results = await this.prisma.$queryRaw` + SELECT rag_lesson_plan_id, lesson_plan, key, value_text, embedding <-> ${queryEmbedding}::vector as distance + FROM rag.rag_lesson_plan_parts JOIN rag.rag_lesson_plans ON rag_lesson_plan_id = rag_lesson_plans.id + WHERE rag_lesson_plans.is_published = true + AND key_stage_slug IN (${keyStageSlugs.join(",")}) + AND subject_slug IN (${subjectSlugs.join(",")}) ORDER BY embedding <-> ${queryEmbedding}::vector LIMIT ${limit}; `; + log.info(results.map((r) => r.lesson_plan.title).join(",\n")); + const endAt = new Date(); log.info( `Fetched ${results.length} lesson plans in ${endAt.getTime() - startAt.getTime()}ms`, ); - log.info(JSON.stringify(results, null, 2)); - const all = await this.prisma.ragLessonPlan.findMany({ - where: { - id: { - in: results.map((r) => r.rag_lesson_plan_id), - }, - }, - select: { - oakLessonSlug: true, - }, - }); - log.info(all); + const uniqueLessonPlans = uniqBy(results, (r) => r.rag_lesson_plan_id); + + log.info(`Unique lesson plans: ${uniqueLessonPlans.length}`); - return []; + return uniqueLessonPlans; } } diff --git a/packages/core/src/prompts/lesson-assistant/parts/interactingWithTheUser.ts b/packages/core/src/prompts/lesson-assistant/parts/interactingWithTheUser.ts index 798efa717..f592bf2d7 100644 --- a/packages/core/src/prompts/lesson-assistant/parts/interactingWithTheUser.ts +++ b/packages/core/src/prompts/lesson-assistant/parts/interactingWithTheUser.ts @@ -49,6 +49,8 @@ These Oak lessons might be relevant: 1. Introduction to the Periodic Table 2. Chemical Reactions and Equations 3. The Structure of the Atom +4. The Mole Concept +5. Acids, Bases and Salts \n To base your lesson on one of these existing Oak lessons, type the lesson number. Tap **Continue** to start from scratch. END OF EXAMPLE RESPONSE`, diff --git a/packages/db/package.json b/packages/db/package.json index 961782624..6ff431295 100644 --- a/packages/db/package.json +++ b/packages/db/package.json @@ -23,6 +23,8 @@ "db-migrate": "pnpm with-env prisma migrate dev", "db-migrate:dev": "pnpm with-env prisma migrate dev", "db-migrate:status": "pnpm with-env prisma migrate status", + "db-migrate:status:stg": "DB_ENV=stg doppler run --config stg -- prisma migrate status", + "db-migrate:status:prd": "DB_ENV=prd doppler run --config prd -- prisma migrate status", "db-migrate-resolve-applied:prd": "doppler run --config prd -- prisma migrate resolve --applied", "db-migrate-resolve-applied:stg": "doppler run --config stg -- prisma migrate resolve --applied", "db-migrate-resolve-rolled-back:prd": "doppler run --config prd -- prisma migrate resolve --rolled-back", diff --git a/packages/db/prisma/additions/README.md b/packages/db/prisma/additions/README.md new file mode 100644 index 000000000..b6b8fdb4f --- /dev/null +++ b/packages/db/prisma/additions/README.md @@ -0,0 +1,3 @@ +# Schema additions + +This directory houses SQL additions to the Prisma schema where the Prisma schema is not expressive enough to capture the desired schema. diff --git a/packages/db/prisma/additions/rag_lesson_plans_ivfflat_index.sql b/packages/db/prisma/additions/rag_lesson_plans_ivfflat_index.sql new file mode 100644 index 000000000..08e104a18 --- /dev/null +++ b/packages/db/prisma/additions/rag_lesson_plans_ivfflat_index.sql @@ -0,0 +1,4 @@ +CREATE INDEX IF NOT EXISTS idx_rag_lesson_plan_parts_embedding_ann +ON rag.rag_lesson_plan_parts +USING ivfflat (embedding vector_cosine_ops) +WITH (lists = 100); \ No newline at end of file diff --git a/packages/db/prisma/additions/rag_lesson_plans_unique_slug_index.sql b/packages/db/prisma/additions/rag_lesson_plans_unique_slug_index.sql new file mode 100644 index 000000000..994d8d499 --- /dev/null +++ b/packages/db/prisma/additions/rag_lesson_plans_unique_slug_index.sql @@ -0,0 +1,3 @@ +CREATE UNIQUE INDEX IF NOT EXISTS idx_rag_lesson_plans_unique_published_oak_lesson_slug +ON rag.rag_lesson_plans (oak_lesson_slug) +WHERE is_published = TRUE; \ No newline at end of file diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index cbf51ac7a..aa9b59a59 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -1117,17 +1117,21 @@ model IngestError { } model RagLessonPlan { - id String @id @default(cuid()) - oakLessonId Int? @map("oak_lesson_id") - oakLessonSlug String @map("oak_lesson_slug") - ingestLessonId String? @map("ingest_lesson_id") - lessonPlan Json @map("lesson_plan") @db.JsonB - subjectSlug String @map("subject_slug") - keyStageSlug String @map("key_stage_slug") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") + id String @id @default(cuid()) + oakLessonId Int? @map("oak_lesson_id") + oakLessonSlug String @map("oak_lesson_slug") + ingestLessonId String? @map("ingest_lesson_id") + lessonPlan Json @map("lesson_plan") @db.JsonB + subjectSlug String @map("subject_slug") + keyStageSlug String @map("key_stage_slug") + isPublished Boolean @default(false) @map("is_published") + createdAt DateTime @default(now()) @map("created_at") + updatedAt DateTime @updatedAt @map("updated_at") ragLessonPlanParts RagLessonPlanPart[] + // The following index is not supported by prisma so is applied manually in ./additions/rag_lesson_plans_unique_slug_index.sql + // @@index([oakLessonSlug], name: "unique_published_oak_lesson_slug", dbIndex: false) @db.PartialIndex("is_published = TRUE") + @@index([isPublished, keyStageSlug, subjectSlug], name: "idx_rag_lesson_plans_published_key_stage_subject") @@map("rag_lesson_plans") @@schema("rag") } diff --git a/packages/ingest/src/_data/.gitignore b/packages/ingest/src/_data/.gitignore index 103f2953d..5f18321f9 100644 --- a/packages/ingest/src/_data/.gitignore +++ b/packages/ingest/src/_data/.gitignore @@ -1,3 +1,4 @@ *.jsonl *.json -*.csv \ No newline at end of file +*.csv +* \ No newline at end of file diff --git a/packages/ingest/src/config/ingestConfig.ts b/packages/ingest/src/config/ingestConfig.ts index 267d27ab8..f3ca17c23 100644 --- a/packages/ingest/src/config/ingestConfig.ts +++ b/packages/ingest/src/config/ingestConfig.ts @@ -20,5 +20,7 @@ export const IngestConfigSchema = z.object({ filePath: z.string(), }), ]), + title: z.string().optional(), + description: z.string().optional(), }); export type IngestConfig = z.infer; diff --git a/packages/ingest/src/generate-lesson-plans/parseBatchLessonPlan.ts b/packages/ingest/src/generate-lesson-plans/parseBatchLessonPlan.ts index 852733602..b53539cc6 100644 --- a/packages/ingest/src/generate-lesson-plans/parseBatchLessonPlan.ts +++ b/packages/ingest/src/generate-lesson-plans/parseBatchLessonPlan.ts @@ -2,6 +2,7 @@ import { CompletedLessonPlanSchema } from "@oakai/aila/src/protocol/schema"; import { IngestError } from "../IngestError"; import { CompletionBatchResponseSchema } from "../zod-schema/zodSchema"; +import { parseKeyStage } from "./parseKeyStage"; export function parseBatchLessonPlan(line: unknown) { let result; @@ -37,6 +38,8 @@ export function parseBatchLessonPlan(line: unknown) { lessonPlan = CompletedLessonPlanSchema.parse( JSON.parse(maybeLessonPlanString), ); + + lessonPlan.keyStage = parseKeyStage(lessonPlan.keyStage); } catch (cause) { throw new IngestError("Failed to parse lesson plan", { cause, diff --git a/packages/ingest/src/generate-lesson-plans/parseKeyStage.ts b/packages/ingest/src/generate-lesson-plans/parseKeyStage.ts new file mode 100644 index 000000000..9daa83aaa --- /dev/null +++ b/packages/ingest/src/generate-lesson-plans/parseKeyStage.ts @@ -0,0 +1,20 @@ +const keyStageMap: Record = { + 1: "ks1", + 2: "ks2", + 3: "ks3", + 4: "ks4", + 5: "ks5", + keystage1: "ks1", + keystage2: "ks2", + keystage3: "ks3", + keystage4: "ks4", + keystage5: "ks5", + eyfs: "early-years-foundation-stage", +}; + +export function parseKeyStage(maybeKeyStage: string): string { + maybeKeyStage = maybeKeyStage.toLowerCase().replace(/[^a-z0-9]/g, ""); + const keyStageSlug = keyStageMap[maybeKeyStage]; + + return keyStageSlug ?? maybeKeyStage; +} From f0d846996984de94c64e2feb8e9eeb711efcf306 Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 11 Dec 2024 15:13:24 +0000 Subject: [PATCH 07/13] move rag into own package --- .../rag.fixtures.json => rag/fixtures.json} | 0 .../{aila/src/features => }/rag/index.test.ts | 2 +- packages/rag/index.ts | 19 +++++++++ packages/rag/jest.config.mjs | 40 +++++++++++++++++++ packages/rag/package.json | 38 ++++++++++++++++++ 5 files changed, 98 insertions(+), 1 deletion(-) rename packages/{aila/src/features/rag/rag.fixtures.json => rag/fixtures.json} (100%) rename packages/{aila/src/features => }/rag/index.test.ts (98%) create mode 100644 packages/rag/index.ts create mode 100644 packages/rag/jest.config.mjs create mode 100644 packages/rag/package.json diff --git a/packages/aila/src/features/rag/rag.fixtures.json b/packages/rag/fixtures.json similarity index 100% rename from packages/aila/src/features/rag/rag.fixtures.json rename to packages/rag/fixtures.json diff --git a/packages/aila/src/features/rag/index.test.ts b/packages/rag/index.test.ts similarity index 98% rename from packages/aila/src/features/rag/index.test.ts rename to packages/rag/index.test.ts index 8e5a6f6e3..a9277882b 100644 --- a/packages/aila/src/features/rag/index.test.ts +++ b/packages/rag/index.test.ts @@ -7,7 +7,7 @@ import { prisma } from "@oakai/db"; // import OpenAI from "openai"; -import ragFixtures from "./rag.fixtures.json"; +import ragFixtures from "./fixtures.json"; // const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); async function pgVectorSearch({ diff --git a/packages/rag/index.ts b/packages/rag/index.ts new file mode 100644 index 000000000..36c234e33 --- /dev/null +++ b/packages/rag/index.ts @@ -0,0 +1,19 @@ +import { prisma } from "@oakai/db"; +import { aiLogger } from "@oakai/logger"; + +type RagLessonPlan = { + id: string; + oakLessonId: number | null; + oakLessonSlug: string; + ingestLessonId: string | null; + lessonPlan: unknown; + subjectSlug: string; + keyStageSlug: string; + isPublished: boolean; + createdAt: Date; + updatedAt: Date; +}; + +export async function getRagLessonPlans(): Promise { + return await prisma.ragLessonPlan.findMany(); +} diff --git a/packages/rag/jest.config.mjs b/packages/rag/jest.config.mjs new file mode 100644 index 000000000..125c98f56 --- /dev/null +++ b/packages/rag/jest.config.mjs @@ -0,0 +1,40 @@ +import { readFile } from "fs/promises"; +import { pathsToModuleNameMapper } from "ts-jest"; + +const tsconfig = JSON.parse( + await readFile(new URL("./tsconfig.test.json", import.meta.url)), +); + +/** @type {import('ts-jest').JestConfigWithTsJest} */ +const config = { + transform: { + "^.+\\.tsx?$": [ + "ts-jest", + { + tsconfig: "tsconfig.test.json", + useESM: true, + isolatedModules: true, + }, + ], + }, + preset: "ts-jest/presets/default-esm", + moduleNameMapper: { + ...pathsToModuleNameMapper(tsconfig.compilerOptions.paths, { + prefix: "/", + }), + "^(\\.{1,2}/.*)\\.js$": "$1", + }, + extensionsToTreatAsEsm: [".ts"], + testEnvironment: "setup-polly-jest/jest-environment-node", + testMatch: ["**/*.test.ts"], + moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"], + rootDir: ".", + resetMocks: true, + collectCoverage: + process.env.CI === "true" || process.env.COLLECT_TEST_COVERAGE === "true", + coverageReporters: ["lcov", "text"], + collectCoverageFrom: ["src/**/*.{ts,tsx,js,jsx}"], + coverageDirectory: "coverage", +}; + +export default config; diff --git a/packages/rag/package.json b/packages/rag/package.json new file mode 100644 index 000000000..ccf733028 --- /dev/null +++ b/packages/rag/package.json @@ -0,0 +1,38 @@ +{ + "name": "@oakai/rag", + "version": "1.0.0", + "description": "", + "keywords": [], + "license": "ISC", + "author": "", + "main": "./index.ts", + "eslintConfig": { + "extends": "@oakai/eslint-config", + "parserOptions": { + "project": "./tsconfig.json" + }, + "rules": {} + }, + "scripts": { + "lint": "eslint .", + "type-check": "tsc --noEmit", + "with-env": "dotenv -e ../../.env --", + "test": "pnpm with-env jest --colors --config jest.config.mjs" + }, + "prettier": "@oakai/prettier-config", + "dependencies": { + "@oakai/aila": "*", + "@oakai/core": "*", + "@oakai/db": "*", + "@oakai/logger": "*", + "zod": "3.23.8" + }, + "devDependencies": { + "@oakai/eslint-config": "*", + "@oakai/prettier-config": "*", + "@types/jest": "^29.5.14", + "jest": "^29.7.0", + "ts-jest": "^29.2.5" + }, + "type": "module" +} From ce1cd638b2ad28d21504dbaf8389445de42f272b Mon Sep 17 00:00:00 2001 From: mantagen Date: Thu, 12 Dec 2024 10:38:10 +0000 Subject: [PATCH 08/13] split out functions into own files and add tests --- packages/aila/package.json | 1 + .../aila/src/utils/rag/fetchRagContent.ts | 26 +--- packages/core/src/models/ragLessonPlans.ts | 79 ---------- packages/rag/index.ts | 61 ++++++-- packages/rag/lib/embedding.test.ts | 60 ++++++++ packages/rag/lib/embedding.ts | 24 +++ packages/rag/lib/rerank.test.ts | 56 +++++++ packages/rag/lib/rerank.ts | 37 +++++ packages/rag/lib/search.test.ts | 138 ++++++++++++++++++ packages/rag/lib/search.ts | 67 +++++++++ packages/rag/package.json | 12 +- packages/rag/tsconfig.test.json | 21 +++ packages/rag/types.ts | 14 ++ pnpm-lock.yaml | 131 ++++++++++++++--- 14 files changed, 581 insertions(+), 146 deletions(-) delete mode 100644 packages/core/src/models/ragLessonPlans.ts create mode 100644 packages/rag/lib/embedding.test.ts create mode 100644 packages/rag/lib/embedding.ts create mode 100644 packages/rag/lib/rerank.test.ts create mode 100644 packages/rag/lib/rerank.ts create mode 100644 packages/rag/lib/search.test.ts create mode 100644 packages/rag/lib/search.ts create mode 100644 packages/rag/tsconfig.test.json create mode 100644 packages/rag/types.ts diff --git a/packages/aila/package.json b/packages/aila/package.json index f6ecfd2e6..70aa210ee 100644 --- a/packages/aila/package.json +++ b/packages/aila/package.json @@ -22,6 +22,7 @@ "@oakai/db": "*", "@oakai/exports": "*", "@oakai/logger": "*", + "@oakai/rag": "*", "@sentry/nextjs": "^8.35.0", "@vercel/kv": "^0.2.2", "ai": "^3.3.26", diff --git a/packages/aila/src/utils/rag/fetchRagContent.ts b/packages/aila/src/utils/rag/fetchRagContent.ts index b0599db8f..bd199f76d 100644 --- a/packages/aila/src/utils/rag/fetchRagContent.ts +++ b/packages/aila/src/utils/rag/fetchRagContent.ts @@ -1,6 +1,6 @@ -import { RagLessonPlans } from "@oakai/core/src/models/ragLessonPlans"; import { RAG } from "@oakai/core/src/rag"; import type { PrismaClientWithAccelerate } from "@oakai/db"; +import { getRelevantLessonPlans } from "@oakai/rag"; import OpenAI from "openai"; import { tryWithErrorReporting } from "../../helpers/errorReporting"; @@ -38,11 +38,9 @@ export async function fetchRagContent({ userId?: string; }): Promise { try { - const openAiClient = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); - const ragLessonPlans_ = new RagLessonPlans(prisma, openAiClient); const keyStageSlugs = keyStage ? [parseKeyStage(keyStage)] : null; const subjectSlugs = subject ? parseSubjects(subject) : null; - const results = await ragLessonPlans_.getRelevantLessonPlans({ + const results = await getRelevantLessonPlans({ title, subjectSlugs, keyStageSlugs, @@ -55,24 +53,4 @@ export async function fetchRagContent({ } catch (cause) { throw new Error("Failed to fetch RAG content", { cause }); } - - const rag = new RAG(prisma, { chatId, userId }); - const ragLessonPlans = await tryWithErrorReporting( - () => { - return title && keyStage && subject - ? rag.fetchLessonPlans({ - chatId: id, - title, - keyStage, - subject, - topic, - k, - }) - : []; - }, - "Failed to fetch RAG content", - "info", - ); - - return ragLessonPlans?.map(minifyLessonPlanForRelevantLessons) ?? []; } diff --git a/packages/core/src/models/ragLessonPlans.ts b/packages/core/src/models/ragLessonPlans.ts deleted file mode 100644 index 2b08b8a97..000000000 --- a/packages/core/src/models/ragLessonPlans.ts +++ /dev/null @@ -1,79 +0,0 @@ -import type { PrismaClientWithAccelerate } from "@oakai/db"; -import { aiLogger } from "@oakai/logger"; -import type OpenAI from "openai"; -import { uniqBy } from "remeda"; - -import type { CompletedLessonPlan } from "../../../aila/src/protocol/schema"; - -type RagLessonPlanResult = { - rag_lesson_plan_id: string; - lesson_plan: CompletedLessonPlan; - key: string; - value_text: string; - distance: number; -}; - -const log = aiLogger("rag"); -/** - * - */ -export class RagLessonPlans { - constructor( - private readonly prisma: PrismaClientWithAccelerate, - private readonly openai: OpenAI, - ) {} - - async getRelevantLessonPlans({ - title, - keyStageSlugs, - subjectSlugs, - }: { - title: string; - keyStageSlugs: string[] | null; - subjectSlugs: string[] | null; - }): Promise { - if (!keyStageSlugs?.length) { - throw new Error("No key stages provided"); - } - if (!subjectSlugs?.length) { - throw new Error("No subjects provided"); - } - const embedding = await this.openai.embeddings.create({ - model: "text-embedding-3-large", - dimensions: 256, - input: title, - encoding_format: "float", - }); - - const queryEmbedding = `[${embedding.data[0]?.embedding.join(",")}]`; - const limit = 50; - - const startAt = new Date(); - log.info( - `Fetching relevant lesson plans for ${title}, in ${keyStageSlugs} and ${subjectSlugs}`, - ); - - const results = await this.prisma.$queryRaw` - SELECT rag_lesson_plan_id, lesson_plan, key, value_text, embedding <-> ${queryEmbedding}::vector as distance - FROM rag.rag_lesson_plan_parts JOIN rag.rag_lesson_plans ON rag_lesson_plan_id = rag_lesson_plans.id - WHERE rag_lesson_plans.is_published = true - AND key_stage_slug IN (${keyStageSlugs.join(",")}) - AND subject_slug IN (${subjectSlugs.join(",")}) - ORDER BY embedding <-> ${queryEmbedding}::vector - LIMIT ${limit}; - `; - - log.info(results.map((r) => r.lesson_plan.title).join(",\n")); - - const endAt = new Date(); - log.info( - `Fetched ${results.length} lesson plans in ${endAt.getTime() - startAt.getTime()}ms`, - ); - - const uniqueLessonPlans = uniqBy(results, (r) => r.rag_lesson_plan_id); - - log.info(`Unique lesson plans: ${uniqueLessonPlans.length}`); - - return uniqueLessonPlans; - } -} diff --git a/packages/rag/index.ts b/packages/rag/index.ts index 36c234e33..421f0accc 100644 --- a/packages/rag/index.ts +++ b/packages/rag/index.ts @@ -1,19 +1,52 @@ import { prisma } from "@oakai/db"; import { aiLogger } from "@oakai/logger"; +import { CohereClient } from "cohere-ai"; +import OpenAI from "openai"; -type RagLessonPlan = { - id: string; - oakLessonId: number | null; - oakLessonSlug: string; - ingestLessonId: string | null; - lessonPlan: unknown; - subjectSlug: string; - keyStageSlug: string; - isPublished: boolean; - createdAt: Date; - updatedAt: Date; -}; +import { getEmbedding } from "./lib/embedding"; +import { rerankResults } from "./lib/rerank"; +import { vectorSearch } from "./lib/search"; +import type { RagLessonPlanResult } from "./types"; -export async function getRagLessonPlans(): Promise { - return await prisma.ragLessonPlan.findMany(); +const log = aiLogger("rag"); + +const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); +const cohereClient = new CohereClient({ + token: process.env.COHERE_API_KEY, +}); + +export async function getRelevantLessonPlans({ + title, + keyStageSlugs, + subjectSlugs, +}: { + title: string; + keyStageSlugs: string[] | null; + subjectSlugs: string[] | null; +}): Promise { + log.info(`Getting embedding for title: ${title}`); + const queryVector = await getEmbedding({ text: title, openai }); + log.info("Got embedding"); + + log.info(`Searching vector database for lesson plans`); + const vectorSearchResults = await vectorSearch({ + prisma, + log, + queryVector, + filters: { + keyStageSlugs, + subjectSlugs, + }, + }); + log.info(`Got ${vectorSearchResults.length} search results`); + + log.info(`Reranking lesson plans`); + const rerankedResults = await rerankResults({ + cohereClient, + query: title, + results: vectorSearchResults, + }); + log.info(`Reranked ${rerankedResults.length} lesson plans`); + + return rerankedResults; } diff --git a/packages/rag/lib/embedding.test.ts b/packages/rag/lib/embedding.test.ts new file mode 100644 index 000000000..b4dd9ce30 --- /dev/null +++ b/packages/rag/lib/embedding.test.ts @@ -0,0 +1,60 @@ +import type { OpenAI } from "openai"; + +import { getEmbedding } from "./embedding"; + +// Mocked OpenAI client +const mockOpenAI = { + embeddings: { + create: jest.fn(), + }, +} as unknown as OpenAI; + +describe("getEmbedding", () => { + const mockText = "This is a test text."; + const mockEmbedding = Array(256).fill(0.1); // Example embedding + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it("returns the embedding for valid input", async () => { + (mockOpenAI.embeddings.create as jest.Mock).mockResolvedValueOnce({ + data: [ + { + embedding: mockEmbedding, + }, + ], + }); + + const result = await getEmbedding({ text: mockText, openai: mockOpenAI }); + + expect(mockOpenAI.embeddings.create).toHaveBeenCalledWith({ + model: "text-embedding-3-large", + dimensions: 256, + input: mockText, + encoding_format: "float", + }); + + expect(result).toEqual(mockEmbedding); + }); + + it("throws an error if no embedding is returned", async () => { + (mockOpenAI.embeddings.create as jest.Mock).mockResolvedValueOnce({ + data: [], + }); + + await expect( + getEmbedding({ text: mockText, openai: mockOpenAI }), + ).rejects.toThrow("Failed to get embedding"); + }); + + it("handles API errors gracefully", async () => { + (mockOpenAI.embeddings.create as jest.Mock).mockRejectedValueOnce( + new Error("API Error"), + ); + + await expect( + getEmbedding({ text: mockText, openai: mockOpenAI }), + ).rejects.toThrow("API Error"); + }); +}); diff --git a/packages/rag/lib/embedding.ts b/packages/rag/lib/embedding.ts new file mode 100644 index 000000000..57669d756 --- /dev/null +++ b/packages/rag/lib/embedding.ts @@ -0,0 +1,24 @@ +import type { OpenAI } from "openai"; + +export async function getEmbedding({ + text, + openai, +}: { + text: string; + openai: OpenAI; +}): Promise { + const response = await openai.embeddings.create({ + model: "text-embedding-3-large", + dimensions: 256, + input: text, + encoding_format: "float", + }); + + const embedding = response.data[0]?.embedding; + + if (!embedding) { + throw new Error("Failed to get embedding"); + } + + return embedding; +} diff --git a/packages/rag/lib/rerank.test.ts b/packages/rag/lib/rerank.test.ts new file mode 100644 index 000000000..6e3f3d101 --- /dev/null +++ b/packages/rag/lib/rerank.test.ts @@ -0,0 +1,56 @@ +import type { CohereClient } from "cohere-ai"; + +import type { RagLessonPlanResult } from "../types"; +import { rerankResults } from "./rerank"; + +describe("rerankResults", () => { + it("should rerank results", async () => { + const cohereClient = { + rerank: jest.fn().mockResolvedValue({ + results: [ + { + index: 0, + relevanceScore: 0.4, + }, + { + index: 1, + relevanceScore: 0.9, + }, + { + index: 2, + }, + ], + }), + }; + const query = "query"; + const results = [ + { + lesson_plan: 0, + }, + { + lesson_plan: 1, + }, + { + lesson_plan: 2, + }, + ] as unknown as RagLessonPlanResult[]; + + const rerankedResults = await rerankResults({ + cohereClient: cohereClient as unknown as CohereClient, + query, + results, + }); + + expect(rerankedResults).toEqual([ + { + lesson_plan: 1, + }, + { + lesson_plan: 2, + }, + { + lesson_plan: 0, + }, + ]); + }); +}); diff --git a/packages/rag/lib/rerank.ts b/packages/rag/lib/rerank.ts new file mode 100644 index 000000000..b5b1504a6 --- /dev/null +++ b/packages/rag/lib/rerank.ts @@ -0,0 +1,37 @@ +import { CohereClient } from "cohere-ai"; +import type { RerankRequest, RerankResponse } from "cohere-ai/api"; + +import type { RagLessonPlanResult } from "../types"; + +export async function rerankResults({ + cohereClient, + query, + results, +}: { + cohereClient: CohereClient; + query: string; + results: RagLessonPlanResult[]; +}): Promise { + const topN = 5; + + const rerankRequest: RerankRequest = { + documents: results.map((result) => JSON.stringify(result.lesson_plan)), + returnDocuments: false, + query, + topN, + }; + + const rerank: RerankResponse = await cohereClient.rerank(rerankRequest); + + const mostRelevantHydrated = rerank.results + .sort((a, b) => b.relevanceScore - a.relevanceScore) + .map((r) => { + const result = results[r.index]; + if (!result) { + throw new Error(`Lesson plan not found at index ${r.index}`); + } + return result; + }); + + return mostRelevantHydrated; +} diff --git a/packages/rag/lib/search.test.ts b/packages/rag/lib/search.test.ts new file mode 100644 index 000000000..46abc442c --- /dev/null +++ b/packages/rag/lib/search.test.ts @@ -0,0 +1,138 @@ +import type { PrismaClientWithAccelerate } from "@oakai/db"; + +import { RagLogger } from "../types"; +import { vectorSearch } from "./search"; + +// Mocked Prisma client +const mockPrisma = { + $queryRaw: jest.fn(), +} as unknown as PrismaClientWithAccelerate; + +// Mock logger +const mockLog: RagLogger = { + info: jest.fn(), + error: jest.fn(), +}; + +describe("vectorSearch", () => { + const mockQueryVector = [0.1, 0.2, 0.3]; + const mockFilters = { + keyStageSlugs: ["ks1", "ks2"], + subjectSlugs: ["math", "science"], + }; + + const mockResults = [ + { + rag_lesson_plan_id: "plan1", + lesson_plan: { + title: "Lesson Plan 1", + subject: "Math", + }, + key: "key1", + value_text: "value1", + distance: 0.5, + }, + { + rag_lesson_plan_id: "plan2", + lesson_plan: { + title: "Lesson Plan 2", + subject: "Science", + }, + key: "key2", + value_text: "value2", + distance: 0.8, + }, + ]; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it("throws an error if no key stages are provided", async () => { + await expect( + vectorSearch({ + prisma: mockPrisma, + log: mockLog, + queryVector: mockQueryVector, + filters: { + keyStageSlugs: [], + subjectSlugs: mockFilters.subjectSlugs, + }, + }), + ).rejects.toThrow("No key stages provided"); + }); + + it("throws an error if no subjects are provided", async () => { + await expect( + vectorSearch({ + prisma: mockPrisma, + log: mockLog, + queryVector: mockQueryVector, + filters: { + keyStageSlugs: mockFilters.keyStageSlugs, + subjectSlugs: [], + }, + }), + ).rejects.toThrow("No subjects provided"); + }); + + it("fetches and returns unique lesson plans", async () => { + (mockPrisma.$queryRaw as jest.Mock).mockResolvedValueOnce(mockResults); + + const result = await vectorSearch({ + prisma: mockPrisma, + log: mockLog, + queryVector: mockQueryVector, + filters: mockFilters, + }); + + expect(mockPrisma.$queryRaw).toHaveBeenCalledWith( + expect.stringContaining("SELECT"), + ); + + expect(mockLog.info).toHaveBeenCalledWith( + expect.stringContaining("Lesson Plan 1,\nLesson Plan 2"), + ); + + expect(result).toEqual([ + { + rag_lesson_plan_id: "plan1", + lesson_plan: { + title: "Lesson Plan 1", + subject: "Math", + }, + key: "key1", + value_text: "value1", + distance: 0.5, + }, + { + rag_lesson_plan_id: "plan2", + lesson_plan: { + title: "Lesson Plan 2", + subject: "Science", + }, + key: "key2", + value_text: "value2", + distance: 0.8, + }, + ]); + }); + + it("logs the processing time and unique lesson plans count", async () => { + (mockPrisma.$queryRaw as jest.Mock).mockResolvedValueOnce(mockResults); + + await vectorSearch({ + prisma: mockPrisma, + log: mockLog, + queryVector: mockQueryVector, + filters: mockFilters, + }); + + expect(mockLog.info).toHaveBeenCalledWith( + expect.stringContaining("Fetched 2 lesson plans"), + ); + expect(mockLog.info).toHaveBeenCalledWith( + expect.stringContaining("Unique lesson plans: 2"), + ); + }); +}); diff --git a/packages/rag/lib/search.ts b/packages/rag/lib/search.ts new file mode 100644 index 000000000..c563a08e9 --- /dev/null +++ b/packages/rag/lib/search.ts @@ -0,0 +1,67 @@ +import type { PrismaClientWithAccelerate } from "@oakai/db"; +import { uniqBy } from "remeda"; +import { z } from "zod"; + +import { CompletedLessonPlanSchema } from "../../aila/src/protocol/schema"; +import type { RagLessonPlanResult, RagLogger } from "../types"; + +const databaseResponseSchema = z.array( + z.object({ + rag_lesson_plan_id: z.string(), + lesson_plan: CompletedLessonPlanSchema, + key: z.string(), + value_text: z.string(), + distance: z.number(), + }), +); + +export async function vectorSearch({ + prisma, + log, + queryVector, + filters, +}: { + prisma: PrismaClientWithAccelerate; + log: RagLogger; + queryVector: number[]; + filters: { + keyStageSlugs: string[] | null; + subjectSlugs: string[] | null; + }; +}): Promise { + const { keyStageSlugs, subjectSlugs } = filters; + if (!keyStageSlugs?.length) { + throw new Error("No key stages provided"); + } + if (!subjectSlugs?.length) { + throw new Error("No subjects provided"); + } + + const queryEmbedding = `[${queryVector.join(",")}]`; + const limit = 50; + const startAt = new Date(); + const response = await prisma.$queryRaw` + SELECT rag_lesson_plan_id, lesson_plan, key, value_text, embedding <-> ${queryEmbedding}::vector as distance + FROM rag.rag_lesson_plan_parts JOIN rag.rag_lesson_plans ON rag_lesson_plan_id = rag_lesson_plans.id + WHERE rag_lesson_plans.is_published = true + AND key_stage_slug IN (${keyStageSlugs.join(",")}) + AND subject_slug IN (${subjectSlugs.join(",")}) + ORDER BY embedding <-> ${queryEmbedding}::vector + LIMIT ${limit}; + `; + + const results = databaseResponseSchema.parse(response); + + log.info(results.map((r) => r.lesson_plan.title).join(",\n")); + + const endAt = new Date(); + log.info( + `Fetched ${results.length} lesson plans in ${endAt.getTime() - startAt.getTime()}ms`, + ); + + const uniqueLessonPlans = uniqBy(results, (r) => r.rag_lesson_plan_id); + + log.info(`Unique lesson plans: ${uniqueLessonPlans.length}`); + + return uniqueLessonPlans; +} diff --git a/packages/rag/package.json b/packages/rag/package.json index ccf733028..760cd994c 100644 --- a/packages/rag/package.json +++ b/packages/rag/package.json @@ -6,13 +6,7 @@ "license": "ISC", "author": "", "main": "./index.ts", - "eslintConfig": { - "extends": "@oakai/eslint-config", - "parserOptions": { - "project": "./tsconfig.json" - }, - "rules": {} - }, + "types": "./index.ts", "scripts": { "lint": "eslint .", "type-check": "tsc --noEmit", @@ -28,11 +22,9 @@ "zod": "3.23.8" }, "devDependencies": { - "@oakai/eslint-config": "*", "@oakai/prettier-config": "*", "@types/jest": "^29.5.14", "jest": "^29.7.0", "ts-jest": "^29.2.5" - }, - "type": "module" + } } diff --git a/packages/rag/tsconfig.test.json b/packages/rag/tsconfig.test.json new file mode 100644 index 000000000..fa491f6fe --- /dev/null +++ b/packages/rag/tsconfig.test.json @@ -0,0 +1,21 @@ +{ + "compilerOptions": { + "target": "esnext", + "module": "esnext", + "lib": ["esnext"], + "outDir": "./dist-tests", + "rootDir": "./src", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "incremental": true, + "noEmit": true, + "types": ["jest", "node"] + }, + "include": ["src/**/*.ts", "tests/**/*.ts", "**/*.test.ts", "src/**/*.d.ts"], + "exclude": ["node_modules"] +} diff --git a/packages/rag/types.ts b/packages/rag/types.ts new file mode 100644 index 000000000..8c8af0642 --- /dev/null +++ b/packages/rag/types.ts @@ -0,0 +1,14 @@ +import type { CompletedLessonPlan } from "../aila/src/protocol/schema"; + +export type RagLessonPlanResult = { + rag_lesson_plan_id: string; + lesson_plan: CompletedLessonPlan; + key: string; + value_text: string; + distance: number; +}; + +export type RagLogger = { + info: (...args: unknown[]) => void; + error: (...args: unknown[]) => void; +}; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9d6d737c7..75f7c021d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -495,6 +495,9 @@ importers: '@oakai/logger': specifier: '*' version: link:../logger + '@oakai/rag': + specifier: '*' + version: link:../rag '@sentry/nextjs': specifier: ^8.35.0 version: 8.35.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.26.0)(@opentelemetry/instrumentation@0.53.0)(@opentelemetry/sdk-trace-base@1.26.0)(next@14.2.18)(react@18.2.0)(webpack@5.93.0) @@ -909,6 +912,37 @@ importers: specifier: ^3.1.1 version: 3.2.5 + packages/rag: + dependencies: + '@oakai/aila': + specifier: '*' + version: link:../aila + '@oakai/core': + specifier: '*' + version: link:../core + '@oakai/db': + specifier: '*' + version: link:../db + '@oakai/logger': + specifier: '*' + version: link:../logger + zod: + specifier: 3.23.8 + version: 3.23.8 + devDependencies: + '@oakai/prettier-config': + specifier: '*' + version: link:../prettier-config + '@types/jest': + specifier: ^29.5.14 + version: 29.5.14 + jest: + specifier: ^29.7.0 + version: 29.7.0(@types/node@18.18.5)(ts-node@10.9.2) + ts-jest: + specifier: ^29.2.5 + version: 29.2.5(@babel/core@7.24.5)(jest@29.7.0)(typescript@5.3.3) + packages: /@aashutoshrathi/word-wrap@1.2.6: @@ -1695,16 +1729,6 @@ packages: '@babel/helper-plugin-utils': 7.24.7 dev: true - /@babel/plugin-syntax-jsx@7.24.1(@babel/core@7.24.5): - resolution: {integrity: sha512-2eCtxZXf+kbkMIsXS4poTvT4Yu5rXiRa+9xGVT56raghjmBTKMpFNc9R4IDiB4emao9eO22Ox7CxuJG7BgExqA==} - engines: {node: '>=6.9.0'} - peerDependencies: - '@babel/core': ^7.0.0-0 - dependencies: - '@babel/core': 7.24.5 - '@babel/helper-plugin-utils': 7.24.7 - dev: true - /@babel/plugin-syntax-jsx@7.24.7(@babel/core@7.24.5): resolution: {integrity: sha512-6ddciUPe/mpMnOKv/U+RSd2vvVy+Yw/JfBB0ZHYjEZt9NLHmCUylNYlsbqCCS1Bffjlb0fCwC9Vqz+sBz6PsiQ==} engines: {node: '>=6.9.0'} @@ -7815,7 +7839,7 @@ packages: read-pkg: 8.1.0 registry-auth-token: 5.0.2 semantic-release: 21.1.2(typescript@5.3.3) - semver: 7.6.2 + semver: 7.6.3 tempy: 3.1.0 dev: false @@ -9253,6 +9277,13 @@ packages: expect: 29.7.0 pretty-format: 29.7.0 + /@types/jest@29.5.14: + resolution: {integrity: sha512-ZN+4sdnLUbo8EVvVc2ao0GFW6oVrQRPn4K2lglySj7APvSrgzxHiNNK99us4WDMi57xxA2yggblIAMNhXOotLQ==} + dependencies: + expect: 29.7.0 + pretty-format: 29.7.0 + dev: true + /@types/js-yaml@4.0.9: resolution: {integrity: sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==} dev: true @@ -10531,7 +10562,6 @@ packages: /async@3.2.6: resolution: {integrity: sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==} - dev: false /asynckit@0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} @@ -10661,7 +10691,7 @@ packages: resolution: {integrity: sha512-Y1IQok9821cC9onCx5otgFfRm7Lm+I+wwxOx738M/WLPZ9Q42m4IG5W0FNX8WLL2gYMZo3JkuXIH2DOpWM+qwA==} engines: {node: '>=8'} dependencies: - '@babel/helper-plugin-utils': 7.24.7 + '@babel/helper-plugin-utils': 7.24.8 '@istanbuljs/load-nyc-config': 1.1.0 '@istanbuljs/schema': 0.1.3 istanbul-lib-instrument: 5.2.1 @@ -12878,6 +12908,14 @@ packages: resolution: {integrity: sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==} dev: true + /ejs@3.1.10: + resolution: {integrity: sha512-UeJmFfOrAQS8OJWPZ4qtgHyWExa088/MtK5UEyoJGFH67cDEXkZSviOiKRCZ4Xij0zxI3JECgYs3oKx+AizQBA==} + engines: {node: '>=0.10.0'} + hasBin: true + dependencies: + jake: 10.9.2 + dev: true + /electron-to-chromium@1.4.676: resolution: {integrity: sha512-uHt4FB8SeYdhcOsj2ix/C39S7sPSNFJpzShjxGOm1KdF4MHyGqGi389+T5cErsodsijojXilYaHIKKqJfqh7uQ==} @@ -13937,6 +13975,12 @@ packages: resolution: {integrity: sha512-P9bmyZ3h/PRG+Nzga+rbdI4OEpNDzAVyy74uVO9ATgzLK6VtAsYybF/+TOCvrc0MO793d6+42lLyZTw7/ArVzA==} dev: false + /filelist@1.0.4: + resolution: {integrity: sha512-w1cEuf3S+DrLCQL7ET6kz+gmlJdbq9J7yXCSjK/OZCPA+qEN1WyF4ZAf0YYJa4/shHJra2t/d/r8SV4Ji+x+8Q==} + dependencies: + minimatch: 5.1.6 + dev: true + /filesize@10.1.4: resolution: {integrity: sha512-ryBwPIIeErmxgPnm6cbESAzXjuEFubs+yKYLBZvg3CaiNcmkJChoOGcBSrZ6IwkMwPABwPpVXE6IlNdGJJrvEg==} engines: {node: '>= 10.4.0'} @@ -15875,7 +15919,7 @@ packages: engines: {node: '>=8'} dependencies: '@babel/core': 7.24.5 - '@babel/parser': 7.24.5 + '@babel/parser': 7.25.3 '@istanbuljs/schema': 0.1.3 istanbul-lib-coverage: 3.2.2 semver: 6.3.1 @@ -15946,6 +15990,17 @@ packages: '@pkgjs/parseargs': 0.11.0 dev: false + /jake@10.9.2: + resolution: {integrity: sha512-2P4SQ0HrLQ+fw6llpLnOaGAvN2Zu6778SJMrCUwns4fOoG9ayrTiZk3VV8sCPkVZF8ab0zksVpS8FDY5pRCNBA==} + engines: {node: '>=10'} + hasBin: true + dependencies: + async: 3.2.6 + chalk: 4.1.2 + filelist: 1.0.4 + minimatch: 3.1.2 + dev: true + /java-properties@1.0.2: resolution: {integrity: sha512-qjdpeo2yKlYTH7nFdK0vbZWuTCesk4o63v5iVOlhMQPfuIZQfW/HI35SjfhA+4qpg36rnFSvUK5b1m+ckIblQQ==} engines: {node: '>= 0.6.0'} @@ -16291,10 +16346,10 @@ packages: engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} dependencies: '@babel/core': 7.24.5 - '@babel/generator': 7.24.5 - '@babel/plugin-syntax-jsx': 7.24.1(@babel/core@7.24.5) + '@babel/generator': 7.25.0 + '@babel/plugin-syntax-jsx': 7.24.7(@babel/core@7.24.5) '@babel/plugin-syntax-typescript': 7.24.7(@babel/core@7.24.5) - '@babel/types': 7.24.5 + '@babel/types': 7.25.2 '@jest/expect-utils': 29.7.0 '@jest/transform': 29.7.0 '@jest/types': 29.6.3 @@ -19155,7 +19210,7 @@ packages: got: 12.6.1 registry-auth-token: 5.0.2 registry-url: 6.0.1 - semver: 7.6.2 + semver: 7.6.3 /pako@1.0.11: resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==} @@ -21063,7 +21118,7 @@ packages: resolution: {integrity: sha512-0Ju4+6A8iOnpL/Thra7dZsSlOHYAHIeMxfhWQRI1/VLcT3WDBZKKtQt/QkBOsiIN9ZpuvHE6cGZ0x4glCMmfiA==} engines: {node: '>=12'} dependencies: - semver: 7.6.2 + semver: 7.6.3 /semver-regex@4.0.5: resolution: {integrity: sha512-hunMQrEy1T6Jr2uEVjrAIqjwWcQTgOAcIM52C8MY1EZSD3DDNft04XzvYKPqjED65bNVVko0YI38nYeEHCX3yw==} @@ -22472,6 +22527,44 @@ packages: yargs-parser: 21.1.1 dev: true + /ts-jest@29.2.5(@babel/core@7.24.5)(jest@29.7.0)(typescript@5.3.3): + resolution: {integrity: sha512-KD8zB2aAZrcKIdGk4OwpJggeLcH1FgrICqDSROWqlnJXGCXK4Mn6FcdK2B6670Xr73lHMG1kHw8R87A0ecZ+vA==} + engines: {node: ^14.15.0 || ^16.10.0 || ^18.0.0 || >=20.0.0} + hasBin: true + peerDependencies: + '@babel/core': '>=7.0.0-beta.0 <8' + '@jest/transform': ^29.0.0 + '@jest/types': ^29.0.0 + babel-jest: ^29.0.0 + esbuild: '*' + jest: ^29.0.0 + typescript: '>=4.3 <6' + peerDependenciesMeta: + '@babel/core': + optional: true + '@jest/transform': + optional: true + '@jest/types': + optional: true + babel-jest: + optional: true + esbuild: + optional: true + dependencies: + '@babel/core': 7.24.5 + bs-logger: 0.2.6 + ejs: 3.1.10 + fast-json-stable-stringify: 2.1.0 + jest: 29.7.0(@types/node@18.18.5)(ts-node@10.9.2) + jest-util: 29.7.0 + json5: 2.2.3 + lodash.memoize: 4.1.2 + make-error: 1.3.6 + semver: 7.6.3 + typescript: 5.3.3 + yargs-parser: 21.1.1 + dev: true + /ts-log@2.2.5: resolution: {integrity: sha512-PGcnJoTBnVGy6yYNFxWVNkdcAuAMstvutN9MgDJIV6L0oG8fB+ZNNy1T+wJzah8RPGor1mZuPQkVfXNDpy9eHA==} dev: true From fe64dbd7190efe7d82dc31223e4cce7e24068b4a Mon Sep 17 00:00:00 2001 From: mantagen Date: Thu, 12 Dec 2024 12:43:09 +0000 Subject: [PATCH 09/13] fix search tests and add faker zod lib --- packages/aila/src/constants.ts | 2 - packages/rag/{index.test.ts => index.tst.ts} | 0 packages/rag/jest.config.js | 24 ++++++++++ packages/rag/jest.config.mjs | 40 ----------------- packages/rag/lib/embedding.ts | 5 ++- packages/rag/lib/rerank.test.ts | 3 +- packages/rag/lib/rerank.ts | 2 +- packages/rag/lib/search.test.ts | 33 +++++--------- packages/rag/package.json | 6 ++- packages/rag/tsconfig.json | 4 ++ packages/rag/tsconfig.test.json | 2 +- packages/rag/types.ts | 4 +- pnpm-lock.yaml | 46 ++++++++++++++++++-- 13 files changed, 96 insertions(+), 75 deletions(-) rename packages/rag/{index.test.ts => index.tst.ts} (100%) create mode 100644 packages/rag/jest.config.js delete mode 100644 packages/rag/jest.config.mjs create mode 100644 packages/rag/tsconfig.json diff --git a/packages/aila/src/constants.ts b/packages/aila/src/constants.ts index 57032eb3a..fbcc27ab5 100644 --- a/packages/aila/src/constants.ts +++ b/packages/aila/src/constants.ts @@ -5,8 +5,6 @@ export const DEFAULT_MODERATION_MODEL: OpenAI.Chat.ChatModel = "gpt-4o-2024-08-06"; export const DEFAULT_CATEGORISE_MODEL: OpenAI.Chat.ChatModel = "gpt-4o-2024-08-06"; -export const DEFAULT_EMBEDDING_MODEL: OpenAI.Embeddings.EmbeddingCreateParams["model"] = - "text-embedding-3-large"; export const DEFAULT_TEMPERATURE = 0.7; export const DEFAULT_MODERATION_TEMPERATURE = 0.7; export const DEFAULT_RAG_LESSON_PLANS = 5; diff --git a/packages/rag/index.test.ts b/packages/rag/index.tst.ts similarity index 100% rename from packages/rag/index.test.ts rename to packages/rag/index.tst.ts diff --git a/packages/rag/jest.config.js b/packages/rag/jest.config.js new file mode 100644 index 000000000..859dd1798 --- /dev/null +++ b/packages/rag/jest.config.js @@ -0,0 +1,24 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +const config = { + transform: { + "^.+\\.tsx?$": [ + "ts-jest", + { + tsconfig: "tsconfig.test.json", + useESM: true, + }, + ], + }, + preset: "ts-jest/presets/default-esm", + moduleNameMapper: { + "^(\\.{1,2}/.*)\\.js$": "$1", + }, + extensionsToTreatAsEsm: [".ts"], + testEnvironment: "setup-polly-jest/jest-environment-node", + testMatch: ["**/*.test.ts"], + moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"], + rootDir: ".", + resetMocks: true, +}; + +module.exports = config; diff --git a/packages/rag/jest.config.mjs b/packages/rag/jest.config.mjs deleted file mode 100644 index 125c98f56..000000000 --- a/packages/rag/jest.config.mjs +++ /dev/null @@ -1,40 +0,0 @@ -import { readFile } from "fs/promises"; -import { pathsToModuleNameMapper } from "ts-jest"; - -const tsconfig = JSON.parse( - await readFile(new URL("./tsconfig.test.json", import.meta.url)), -); - -/** @type {import('ts-jest').JestConfigWithTsJest} */ -const config = { - transform: { - "^.+\\.tsx?$": [ - "ts-jest", - { - tsconfig: "tsconfig.test.json", - useESM: true, - isolatedModules: true, - }, - ], - }, - preset: "ts-jest/presets/default-esm", - moduleNameMapper: { - ...pathsToModuleNameMapper(tsconfig.compilerOptions.paths, { - prefix: "/", - }), - "^(\\.{1,2}/.*)\\.js$": "$1", - }, - extensionsToTreatAsEsm: [".ts"], - testEnvironment: "setup-polly-jest/jest-environment-node", - testMatch: ["**/*.test.ts"], - moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"], - rootDir: ".", - resetMocks: true, - collectCoverage: - process.env.CI === "true" || process.env.COLLECT_TEST_COVERAGE === "true", - coverageReporters: ["lcov", "text"], - collectCoverageFrom: ["src/**/*.{ts,tsx,js,jsx}"], - coverageDirectory: "coverage", -}; - -export default config; diff --git a/packages/rag/lib/embedding.ts b/packages/rag/lib/embedding.ts index 57669d756..1be00fbb8 100644 --- a/packages/rag/lib/embedding.ts +++ b/packages/rag/lib/embedding.ts @@ -1,5 +1,8 @@ import type { OpenAI } from "openai"; +const DEFAULT_EMBEDDING_MODEL: OpenAI.Embeddings.EmbeddingCreateParams["model"] = + "text-embedding-3-large"; + export async function getEmbedding({ text, openai, @@ -8,7 +11,7 @@ export async function getEmbedding({ openai: OpenAI; }): Promise { const response = await openai.embeddings.create({ - model: "text-embedding-3-large", + model: DEFAULT_EMBEDDING_MODEL, dimensions: 256, input: text, encoding_format: "float", diff --git a/packages/rag/lib/rerank.test.ts b/packages/rag/lib/rerank.test.ts index 6e3f3d101..426d088e4 100644 --- a/packages/rag/lib/rerank.test.ts +++ b/packages/rag/lib/rerank.test.ts @@ -4,7 +4,7 @@ import type { RagLessonPlanResult } from "../types"; import { rerankResults } from "./rerank"; describe("rerankResults", () => { - it("should rerank results", async () => { + it("should rerank results based on cohere's relevanceScore", async () => { const cohereClient = { rerank: jest.fn().mockResolvedValue({ results: [ @@ -18,6 +18,7 @@ describe("rerankResults", () => { }, { index: 2, + relevanceScore: 0.5, }, ], }), diff --git a/packages/rag/lib/rerank.ts b/packages/rag/lib/rerank.ts index b5b1504a6..00b10c50b 100644 --- a/packages/rag/lib/rerank.ts +++ b/packages/rag/lib/rerank.ts @@ -1,4 +1,4 @@ -import { CohereClient } from "cohere-ai"; +import type { CohereClient } from "cohere-ai"; import type { RerankRequest, RerankResponse } from "cohere-ai/api"; import type { RagLessonPlanResult } from "../types"; diff --git a/packages/rag/lib/search.test.ts b/packages/rag/lib/search.test.ts index 46abc442c..faf6a3b46 100644 --- a/packages/rag/lib/search.test.ts +++ b/packages/rag/lib/search.test.ts @@ -1,8 +1,13 @@ +import { generateMock } from "@anatine/zod-mock"; +import { CompletedLessonPlanSchema } from "@oakai/aila/src/protocol/schema"; import type { PrismaClientWithAccelerate } from "@oakai/db"; -import { RagLogger } from "../types"; +import type { RagLogger } from "../types"; import { vectorSearch } from "./search"; +const mockLessonPlan1 = generateMock(CompletedLessonPlanSchema); +const mockLessonPlan2 = generateMock(CompletedLessonPlanSchema); + // Mocked Prisma client const mockPrisma = { $queryRaw: jest.fn(), @@ -24,20 +29,14 @@ describe("vectorSearch", () => { const mockResults = [ { rag_lesson_plan_id: "plan1", - lesson_plan: { - title: "Lesson Plan 1", - subject: "Math", - }, + lesson_plan: mockLessonPlan1, key: "key1", value_text: "value1", distance: 0.5, }, { rag_lesson_plan_id: "plan2", - lesson_plan: { - title: "Lesson Plan 2", - subject: "Science", - }, + lesson_plan: mockLessonPlan2, key: "key2", value_text: "value2", distance: 0.8, @@ -86,31 +85,21 @@ describe("vectorSearch", () => { filters: mockFilters, }); - expect(mockPrisma.$queryRaw).toHaveBeenCalledWith( + expect((mockPrisma.$queryRaw as jest.Mock).mock.calls[0][0][0]).toEqual( expect.stringContaining("SELECT"), ); - expect(mockLog.info).toHaveBeenCalledWith( - expect.stringContaining("Lesson Plan 1,\nLesson Plan 2"), - ); - expect(result).toEqual([ { rag_lesson_plan_id: "plan1", - lesson_plan: { - title: "Lesson Plan 1", - subject: "Math", - }, + lesson_plan: mockLessonPlan1, key: "key1", value_text: "value1", distance: 0.5, }, { rag_lesson_plan_id: "plan2", - lesson_plan: { - title: "Lesson Plan 2", - subject: "Science", - }, + lesson_plan: mockLessonPlan2, key: "key2", value_text: "value2", distance: 0.8, diff --git a/packages/rag/package.json b/packages/rag/package.json index 760cd994c..890d36b97 100644 --- a/packages/rag/package.json +++ b/packages/rag/package.json @@ -11,18 +11,20 @@ "lint": "eslint .", "type-check": "tsc --noEmit", "with-env": "dotenv -e ../../.env --", - "test": "pnpm with-env jest --colors --config jest.config.mjs" + "test": "pnpm with-env jest --colors --config jest.config.js" }, "prettier": "@oakai/prettier-config", "dependencies": { + "@anatine/zod-mock": "^3.13.4", + "@faker-js/faker": "^9.3.0", "@oakai/aila": "*", "@oakai/core": "*", "@oakai/db": "*", "@oakai/logger": "*", + "@oakai/prettier-config": "*", "zod": "3.23.8" }, "devDependencies": { - "@oakai/prettier-config": "*", "@types/jest": "^29.5.14", "jest": "^29.7.0", "ts-jest": "^29.2.5" diff --git a/packages/rag/tsconfig.json b/packages/rag/tsconfig.json new file mode 100644 index 000000000..f9c9d7fa2 --- /dev/null +++ b/packages/rag/tsconfig.json @@ -0,0 +1,4 @@ +{ + "extends": "../../tsconfig.json", + "include": ["*.ts", "**/*.ts"] +} diff --git a/packages/rag/tsconfig.test.json b/packages/rag/tsconfig.test.json index fa491f6fe..69fb61e62 100644 --- a/packages/rag/tsconfig.test.json +++ b/packages/rag/tsconfig.test.json @@ -16,6 +16,6 @@ "noEmit": true, "types": ["jest", "node"] }, - "include": ["src/**/*.ts", "tests/**/*.ts", "**/*.test.ts", "src/**/*.d.ts"], + "include": ["**/*.test.ts"], "exclude": ["node_modules"] } diff --git a/packages/rag/types.ts b/packages/rag/types.ts index 8c8af0642..51fb8ae7a 100644 --- a/packages/rag/types.ts +++ b/packages/rag/types.ts @@ -9,6 +9,6 @@ export type RagLessonPlanResult = { }; export type RagLogger = { - info: (...args: unknown[]) => void; - error: (...args: unknown[]) => void; + info: (...args: string[]) => void; + error: (...args: string[]) => void; }; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 75f7c021d..2ba7cd098 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -914,6 +914,12 @@ importers: packages/rag: dependencies: + '@anatine/zod-mock': + specifier: ^3.13.4 + version: 3.13.4(@faker-js/faker@9.3.0)(zod@3.23.8) + '@faker-js/faker': + specifier: ^9.3.0 + version: 9.3.0 '@oakai/aila': specifier: '*' version: link:../aila @@ -926,13 +932,13 @@ importers: '@oakai/logger': specifier: '*' version: link:../logger + '@oakai/prettier-config': + specifier: '*' + version: link:../prettier-config zod: specifier: 3.23.8 version: 3.23.8 devDependencies: - '@oakai/prettier-config': - specifier: '*' - version: link:../prettier-config '@types/jest': specifier: ^29.5.14 version: 29.5.14 @@ -1103,6 +1109,17 @@ packages: '@jridgewell/gen-mapping': 0.3.5 '@jridgewell/trace-mapping': 0.3.25 + /@anatine/zod-mock@3.13.4(@faker-js/faker@9.3.0)(zod@3.23.8): + resolution: {integrity: sha512-yO/KeuyYsEDCTcQ+7CiRuY3dnafMHIZUMok6Ci7aERRCTQ+/XmsiPk/RnMx5wlLmWBTmX9kw+PavbMsjM+sAJA==} + peerDependencies: + '@faker-js/faker': ^7.0.0 || ^8.0.0 + zod: ^3.21.4 + dependencies: + '@faker-js/faker': 9.3.0 + randexp: 0.5.3 + zod: 3.23.8 + dev: false + /@anthropic-ai/sdk@0.6.8: resolution: {integrity: sha512-z4gDFrBf+W2wOVvwA3CA+5bfKOxQhPeXQo7+ITWj3r3XPulIMEasVT0KrD41G+anr5Yc3d2PKvXKB6b1LSon5w==} dependencies: @@ -3384,6 +3401,11 @@ packages: marked: 12.0.2 dev: true + /@faker-js/faker@9.3.0: + resolution: {integrity: sha512-r0tJ3ZOkMd9xsu3VRfqlFR6cz0V/jFYRswAIpC+m/DIfAUXq7g8N7wTAlhSANySXYGKzGryfDXwtwsY8TxEIDw==} + engines: {node: '>=18.0.0', npm: '>=9.0.0'} + dev: false + /@fastify/busboy@2.1.1: resolution: {integrity: sha512-vBZP4NlzfOlerQTnba4aqZoMhE/a9HY7HRqoOPaETQcSQuWEIyZMHGfVu6w9wGtGK5fED5qRs2DteVCjOH60sA==} engines: {node: '>=14'} @@ -12876,6 +12898,11 @@ packages: resolution: {integrity: sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==} engines: {node: '>=12'} + /drange@1.1.1: + resolution: {integrity: sha512-pYxfDYpued//QpnLIm4Avk7rsNtAtQkUES2cwAYSvD/wd2pKD71gN2Ebj3e7klzXwjocvE8c5vx/1fxwpqmSxA==} + engines: {node: '>=4'} + dev: false + /dset@3.1.4: resolution: {integrity: sha512-2QF/g9/zTaPDc3BjNcVTGoBbXBgYfMTTceLaYcFJ/W9kggFUkhxD/hMEeuLKbugyef9SqAx8cpgwlIP/jinUTA==} engines: {node: '>=4'} @@ -20202,6 +20229,14 @@ packages: resolution: {integrity: sha512-tEF5I22zJnuclswcZMc8bDIrwRHRzf+NqVEmqg50ShAZMP7MWeR/RGDthfM/p+BlqvF2fXAzpn8i+SJcYD3alw==} dev: false + /randexp@0.5.3: + resolution: {integrity: sha512-U+5l2KrcMNOUPYvazA3h5ekF80FHTUG+87SEAmHZmolh1M+i/WyTCxVzmi+tidIa1tM4BSe8g2Y/D3loWDjj+w==} + engines: {node: '>=4'} + dependencies: + drange: 1.1.1 + ret: 0.2.2 + dev: false + /randombytes@2.1.0: resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==} dependencies: @@ -20878,6 +20913,11 @@ packages: onetime: 5.1.2 signal-exit: 3.0.7 + /ret@0.2.2: + resolution: {integrity: sha512-M0b3YWQs7R3Z917WRQy1HHA7Ba7D8hvZg6UE5mLykJxQVE2ju0IXbGlaHPPlkY+WN7wFP+wUMXmBFA0aV6vYGQ==} + engines: {node: '>=4'} + dev: false + /retry-request@7.0.2: resolution: {integrity: sha512-dUOvLMJ0/JJYEn8NrpOaGNE7X3vpI5XlZS/u0ANjqtcZVKnIxP7IgCFwrKTxENw29emmwug53awKtaMm4i9g5w==} engines: {node: '>=14'} From 1fd47b2e5bbecab8cb27b49c3b6a76838ee3202a Mon Sep 17 00:00:00 2001 From: mantagen Date: Thu, 12 Dec 2024 16:04:08 +0000 Subject: [PATCH 10/13] new rag behind feature flag --- .../builders/AilaLessonPromptBuilder.ts | 94 ++++++++++++++----- packages/aila/src/features/rag/AilaRag.ts | 6 +- packages/aila/src/protocol/schema.ts | 3 +- .../aila/src/utils/rag/fetchRagContent.ts | 37 ++++---- packages/rag/lib/rerank.ts | 2 +- packages/rag/lib/search.test.ts | 2 +- packages/rag/lib/search.ts | 23 +++-- packages/rag/package.json | 1 - packages/rag/types.ts | 10 +- pnpm-lock.yaml | 3 - 10 files changed, 114 insertions(+), 67 deletions(-) diff --git a/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts b/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts index a6142cd71..e86d8fdd4 100644 --- a/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts +++ b/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts @@ -2,6 +2,7 @@ import type { TemplateProps } from "@oakai/core/src/prompts/lesson-assistant"; import { template } from "@oakai/core/src/prompts/lesson-assistant"; import { prisma as globalPrisma } from "@oakai/db/client"; import { aiLogger } from "@oakai/logger"; +import { getRelevantLessonPlans } from "@oakai/rag"; import { DEFAULT_RAG_LESSON_PLANS } from "../../../constants"; import { tryWithErrorReporting } from "../../../helpers/errorReporting"; @@ -12,6 +13,8 @@ import { compressedLessonPlanForRag } from "../../../utils/lessonPlan/compressed import { fetchLessonPlan } from "../../../utils/lessonPlan/fetchLessonPlan"; import type { RagLessonPlan } from "../../../utils/rag/fetchRagContent"; import { fetchRagContent } from "../../../utils/rag/fetchRagContent"; +import { parseKeyStage } from "../../../utils/rag/parseKeyStage"; +import { parseSubjects } from "../../../utils/rag/parseSubjects"; import type { AilaServices } from "../../AilaServices"; import { AilaPromptBuilder } from "../AilaPromptBuilder"; @@ -64,36 +67,75 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder { const { title, subject, keyStage, topic } = this._aila?.lessonPlan ?? {}; - let relevantLessonPlans: RagLessonPlan[] = []; - await tryWithErrorReporting(async () => { - relevantLessonPlans = await fetchRagContent({ - title: title ?? "unknown", - subject, - topic, - keyStage, - id: chatId, - k: - this._aila?.options.numberOfLessonPlansInRag ?? - DEFAULT_RAG_LESSON_PLANS, - prisma: globalPrisma, - chatId, - userId, + const NEW_RAG_ENABLED = true; + + if (NEW_RAG_ENABLED) { + if (!title || !subject || !keyStage) { + log.error( + "Missing title, subject or keyStage, returning empty content", + ); + return { + ragLessonPlans: [], + stringifiedRelevantLessonPlans: noRelevantLessonPlans, + }; + } + + const keyStageSlugs = keyStage ? [parseKeyStage(keyStage)] : null; + const subjectSlugs = subject ? parseSubjects(subject) : null; + + const relevantLessonPlans = await getRelevantLessonPlans({ + title, + keyStageSlugs, + subjectSlugs, }); - }, "Did not fetch RAG content. Continuing"); + const stringifiedRelevantLessonPlans = JSON.stringify( + relevantLessonPlans, + null, + 2, + ); - log.info("Fetched relevant lesson plans", relevantLessonPlans.length); - const stringifiedRelevantLessonPlans = JSON.stringify( - relevantLessonPlans, - null, - 2, - ); + return { + ragLessonPlans: relevantLessonPlans.map((l) => ({ + ...l.lessonPlan, + id: l.ragLessonPlanId, + })), + stringifiedRelevantLessonPlans, + }; + } else { + let relevantLessonPlans: RagLessonPlan[] = []; + await tryWithErrorReporting(async () => { + relevantLessonPlans = await fetchRagContent({ + title: title ?? "unknown", + subject, + topic, + keyStage, + id: chatId, + k: + this._aila?.options.numberOfLessonPlansInRag ?? + DEFAULT_RAG_LESSON_PLANS, + prisma: globalPrisma, + chatId, + userId, + }); + }, "Did not fetch RAG content. Continuing"); - log.info("Got RAG content, length:", stringifiedRelevantLessonPlans.length); + log.info("Fetched relevant lesson plans", relevantLessonPlans.length); + const stringifiedRelevantLessonPlans = JSON.stringify( + relevantLessonPlans, + null, + 2, + ); - return { - ragLessonPlans: relevantLessonPlans, - stringifiedRelevantLessonPlans, - }; + log.info( + "Got RAG content, length:", + stringifiedRelevantLessonPlans.length, + ); + + return { + ragLessonPlans: relevantLessonPlans, + stringifiedRelevantLessonPlans, + }; + } } private systemPrompt( diff --git a/packages/aila/src/features/rag/AilaRag.ts b/packages/aila/src/features/rag/AilaRag.ts index f9b58fb71..5850dd8c6 100644 --- a/packages/aila/src/features/rag/AilaRag.ts +++ b/packages/aila/src/features/rag/AilaRag.ts @@ -12,9 +12,9 @@ import { minifyLessonPlanForRelevantLessons } from "../../utils/lessonPlan/minif const log = aiLogger("aila:rag"); export class AilaRag implements AilaRagFeature { - private _aila: AilaServices; - private _rag: RAG; - private _prisma: PrismaClientWithAccelerate; + private readonly _aila: AilaServices; + private readonly _rag: RAG; + private readonly _prisma: PrismaClientWithAccelerate; constructor({ aila, diff --git a/packages/aila/src/protocol/schema.ts b/packages/aila/src/protocol/schema.ts index 3677b03f4..21a18a6fb 100644 --- a/packages/aila/src/protocol/schema.ts +++ b/packages/aila/src/protocol/schema.ts @@ -588,8 +588,7 @@ export const LessonPlanJsonSchema = zodToJsonSchema( ); const AilaRagRelevantLessonSchema = z.object({ - // @todo add this after next ingest - // oakLessonId: z.number(), + oakLessonId: z.number().nullish(), lessonPlanId: z.string(), title: z.string(), }); diff --git a/packages/aila/src/utils/rag/fetchRagContent.ts b/packages/aila/src/utils/rag/fetchRagContent.ts index bd199f76d..3ffb709dc 100644 --- a/packages/aila/src/utils/rag/fetchRagContent.ts +++ b/packages/aila/src/utils/rag/fetchRagContent.ts @@ -1,13 +1,9 @@ import { RAG } from "@oakai/core/src/rag"; import type { PrismaClientWithAccelerate } from "@oakai/db"; -import { getRelevantLessonPlans } from "@oakai/rag"; -import OpenAI from "openai"; import { tryWithErrorReporting } from "../../helpers/errorReporting"; import type { CompletedLessonPlan } from "../../protocol/schema"; import { minifyLessonPlanForRelevantLessons } from "../lessonPlan/minifyLessonPlanForRelevantLessons"; -import { parseKeyStage } from "./parseKeyStage"; -import { parseSubjects } from "./parseSubjects"; export type RagLessonPlan = Omit< CompletedLessonPlan, @@ -37,20 +33,23 @@ export async function fetchRagContent({ chatId: string; userId?: string; }): Promise { - try { - const keyStageSlugs = keyStage ? [parseKeyStage(keyStage)] : null; - const subjectSlugs = subject ? parseSubjects(subject) : null; - const results = await getRelevantLessonPlans({ - title, - subjectSlugs, - keyStageSlugs, - }); + const rag = new RAG(prisma, { chatId, userId }); + const ragLessonPlans = await tryWithErrorReporting( + () => { + return title && keyStage && subject + ? rag.fetchLessonPlans({ + chatId: id, + title, + keyStage, + subject, + topic, + k, + }) + : []; + }, + "Failed to fetch RAG content", + "info", + ); - return results.map((result) => ({ - id: result.rag_lesson_plan_id, - ...result.lesson_plan, - })); - } catch (cause) { - throw new Error("Failed to fetch RAG content", { cause }); - } + return ragLessonPlans?.map(minifyLessonPlanForRelevantLessons) ?? []; } diff --git a/packages/rag/lib/rerank.ts b/packages/rag/lib/rerank.ts index 00b10c50b..6782be719 100644 --- a/packages/rag/lib/rerank.ts +++ b/packages/rag/lib/rerank.ts @@ -15,7 +15,7 @@ export async function rerankResults({ const topN = 5; const rerankRequest: RerankRequest = { - documents: results.map((result) => JSON.stringify(result.lesson_plan)), + documents: results.map((result) => JSON.stringify(result.lessonPlan)), returnDocuments: false, query, topN, diff --git a/packages/rag/lib/search.test.ts b/packages/rag/lib/search.test.ts index faf6a3b46..30f4a317b 100644 --- a/packages/rag/lib/search.test.ts +++ b/packages/rag/lib/search.test.ts @@ -1,7 +1,7 @@ import { generateMock } from "@anatine/zod-mock"; -import { CompletedLessonPlanSchema } from "@oakai/aila/src/protocol/schema"; import type { PrismaClientWithAccelerate } from "@oakai/db"; +import { CompletedLessonPlanSchema } from "../../aila/src/protocol/schema"; import type { RagLogger } from "../types"; import { vectorSearch } from "./search"; diff --git a/packages/rag/lib/search.ts b/packages/rag/lib/search.ts index c563a08e9..02bcd5cd2 100644 --- a/packages/rag/lib/search.ts +++ b/packages/rag/lib/search.ts @@ -7,10 +7,12 @@ import type { RagLessonPlanResult, RagLogger } from "../types"; const databaseResponseSchema = z.array( z.object({ - rag_lesson_plan_id: z.string(), - lesson_plan: CompletedLessonPlanSchema, - key: z.string(), - value_text: z.string(), + ragLessonPlanId: z.string(), + oakLessonId: z.number().nullable(), + oakLessonSlug: z.string(), + lessonPlan: CompletedLessonPlanSchema, + matchedKey: z.string(), + matchedValue: z.string(), distance: z.number(), }), ); @@ -41,7 +43,14 @@ export async function vectorSearch({ const limit = 50; const startAt = new Date(); const response = await prisma.$queryRaw` - SELECT rag_lesson_plan_id, lesson_plan, key, value_text, embedding <-> ${queryEmbedding}::vector as distance + SELECT + rag_lesson_plan_id as ragLessonPlanId, + oak_lesson_id as oakLessonId, + oak_lesson_slug as oakLessonSlug, + lesson_plan as lessonPlan, + key as matchedKey, + value_text as matchedValue, + embedding <-> ${queryEmbedding}::vector as distance FROM rag.rag_lesson_plan_parts JOIN rag.rag_lesson_plans ON rag_lesson_plan_id = rag_lesson_plans.id WHERE rag_lesson_plans.is_published = true AND key_stage_slug IN (${keyStageSlugs.join(",")}) @@ -52,14 +61,14 @@ export async function vectorSearch({ const results = databaseResponseSchema.parse(response); - log.info(results.map((r) => r.lesson_plan.title).join(",\n")); + log.info(results.map((r) => r.lessonPlan.title).join(",\n")); const endAt = new Date(); log.info( `Fetched ${results.length} lesson plans in ${endAt.getTime() - startAt.getTime()}ms`, ); - const uniqueLessonPlans = uniqBy(results, (r) => r.rag_lesson_plan_id); + const uniqueLessonPlans = uniqBy(results, (r) => r.ragLessonPlanId); log.info(`Unique lesson plans: ${uniqueLessonPlans.length}`); diff --git a/packages/rag/package.json b/packages/rag/package.json index 890d36b97..2f35803eb 100644 --- a/packages/rag/package.json +++ b/packages/rag/package.json @@ -17,7 +17,6 @@ "dependencies": { "@anatine/zod-mock": "^3.13.4", "@faker-js/faker": "^9.3.0", - "@oakai/aila": "*", "@oakai/core": "*", "@oakai/db": "*", "@oakai/logger": "*", diff --git a/packages/rag/types.ts b/packages/rag/types.ts index 51fb8ae7a..d9846aed6 100644 --- a/packages/rag/types.ts +++ b/packages/rag/types.ts @@ -1,10 +1,12 @@ import type { CompletedLessonPlan } from "../aila/src/protocol/schema"; export type RagLessonPlanResult = { - rag_lesson_plan_id: string; - lesson_plan: CompletedLessonPlan; - key: string; - value_text: string; + ragLessonPlanId: string; + oakLessonId: number | null; + oakLessonSlug: string; + lessonPlan: CompletedLessonPlan; + matchedKey: string; + matchedValue: string; distance: number; }; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2ba7cd098..c7ef31143 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -920,9 +920,6 @@ importers: '@faker-js/faker': specifier: ^9.3.0 version: 9.3.0 - '@oakai/aila': - specifier: '*' - version: link:../aila '@oakai/core': specifier: '*' version: link:../core From dba54e4041f281a01d4aca766980c86d7eb0f70a Mon Sep 17 00:00:00 2001 From: mantagen Date: Thu, 12 Dec 2024 17:15:12 +0000 Subject: [PATCH 11/13] camel case sql query response --- packages/rag/index.ts | 4 ++++ packages/rag/lib/search.ts | 12 ++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/rag/index.ts b/packages/rag/index.ts index 421f0accc..3d3ef664f 100644 --- a/packages/rag/index.ts +++ b/packages/rag/index.ts @@ -40,6 +40,10 @@ export async function getRelevantLessonPlans({ }); log.info(`Got ${vectorSearchResults.length} search results`); + if (vectorSearchResults.length === 0) { + return []; + } + log.info(`Reranking lesson plans`); const rerankedResults = await rerankResults({ cohereClient, diff --git a/packages/rag/lib/search.ts b/packages/rag/lib/search.ts index 02bcd5cd2..f7699a37b 100644 --- a/packages/rag/lib/search.ts +++ b/packages/rag/lib/search.ts @@ -44,12 +44,12 @@ export async function vectorSearch({ const startAt = new Date(); const response = await prisma.$queryRaw` SELECT - rag_lesson_plan_id as ragLessonPlanId, - oak_lesson_id as oakLessonId, - oak_lesson_slug as oakLessonSlug, - lesson_plan as lessonPlan, - key as matchedKey, - value_text as matchedValue, + rag_lesson_plan_id as "ragLessonPlanId", + oak_lesson_id as "oakLessonId", + oak_lesson_slug as "oakLessonSlug", + lesson_plan as "lessonPlan", + key as "matchedKey", + value_text as "matchedValue", embedding <-> ${queryEmbedding}::vector as distance FROM rag.rag_lesson_plan_parts JOIN rag.rag_lesson_plans ON rag_lesson_plan_id = rag_lesson_plans.id WHERE rag_lesson_plans.is_published = true From 7485394fcf3429b2212f2cc1917689175a68f706 Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 18 Dec 2024 16:24:04 +0000 Subject: [PATCH 12/13] fix publishing step --- packages/ingest/src/index.ts | 8 +- packages/ingest/src/steps/7-publish.ts | 116 ++++++++++++++++++++----- 2 files changed, 102 insertions(+), 22 deletions(-) diff --git a/packages/ingest/src/index.ts b/packages/ingest/src/index.ts index 34d1d84d4..8c79e2135 100644 --- a/packages/ingest/src/index.ts +++ b/packages/ingest/src/index.ts @@ -5,6 +5,7 @@ import { prisma } from "@oakai/db"; import { aiLogger } from "@oakai/logger"; +import { IngestError } from "./IngestError"; import { getLatestIngestId } from "./db-helpers/getLatestIngestId"; import { ingestStart } from "./steps/0-start"; import { captions } from "./steps/1-captions"; @@ -59,6 +60,11 @@ async function main() { } main().catch((error) => { - log.error("Error running command", error); + log.info(error.toString()); + if (error instanceof IngestError) { + log.info("Ingest ID " + error.ingestId); + log.info("Lesson ID " + error.lessonId); + } + log.error("Error running command, see above"); process.exit(1); }); diff --git a/packages/ingest/src/steps/7-publish.ts b/packages/ingest/src/steps/7-publish.ts index 708ee6c77..8f1aca46e 100644 --- a/packages/ingest/src/steps/7-publish.ts +++ b/packages/ingest/src/steps/7-publish.ts @@ -2,11 +2,9 @@ import { LessonPlanSchema, type LooseLessonPlan, } from "@oakai/aila/src/protocol/schema"; -import type { PrismaClientWithAccelerate } from "@oakai/db"; -import type { Prisma } from "@oakai/db"; +import type { Prisma, PrismaClientWithAccelerate } from "@oakai/db"; import { createId } from "@paralleldrive/cuid2"; import { isTruthy } from "remeda"; -import invariant from "tiny-invariant"; import { z } from "zod"; import { IngestError } from "../IngestError"; @@ -32,7 +30,7 @@ export async function publishToRag({ ingestId: string; }) { log.info("Publishing lesson plans and parts to RAG schema"); - // const ingest = await getIngestById({ prisma, ingestId }); + const lessons = await loadLessonsAndUpdateState({ prisma, ingestId, @@ -42,6 +40,9 @@ export async function publishToRag({ log.info(`Loaded ${lessons.length} lessons`); + /** + * Build list of lesson plans to publish + */ const ragLessonPlans: { oakLessonId?: number; oakLessonSlug: string; @@ -53,7 +54,7 @@ export async function publishToRag({ for (const lesson of lessons) { if (!lesson.lessonPlan) { - throw new IngestError("Lessin is missing lesson plan", { + throw new IngestError("Lesson is missing lesson plan", { ingestId, lessonId: lesson.id, }); @@ -71,6 +72,8 @@ export async function publishToRag({ }); } + log.info("About to chunk"); + /** * Add lesson plans to RAG schema */ @@ -78,9 +81,16 @@ export async function publishToRag({ data: ragLessonPlans, chunkSize: 500, fn: async (data) => { - await prisma.ragLessonPlan.createMany({ - data, - }); + try { + log.info(`Writing ${data.length} lesson plans`); + await prisma.ragLessonPlan.createMany({ + data, + }); + log.info(`Written ${data.length} lesson plans`); + } catch (error) { + log.error(error); + throw error; + } }, }); @@ -114,6 +124,9 @@ export async function publishToRag({ const ingestLessonId = ragLessonPlan.ingestLessonId; const lesson = lessons.find((l) => l.id === ingestLessonId); + /** + * @todo this takes ages one by one. group these queries + */ const lessonPlanParts = await prisma.$queryRaw` SELECT key, value_text, value_json, embedding::text FROM ingest.ingest_lesson_plan_part @@ -164,23 +177,84 @@ export async function publishToRag({ const now = new Date().toISOString(); // Need to use $queryRaw because Prisma doesn't support the vector type await prisma.$queryRaw` - INSERT INTO rag.rag_lesson_plan_parts (id, rag_lesson_plan_id, key, value_text, value_json, created_at, updated_at, embedding) - SELECT * - FROM UNNEST ( - ARRAY[${data.map(() => createId())}]::text[], - ARRAY[${data.map((p) => p.ragLessonPlanId)}]::text[], - ARRAY[${data.map((p) => p.key)}]::text[], - ARRAY[${data.map((p) => p.valueText)}]::text[], - ARRAY[${data.map((p) => JSON.stringify(p.valueJson))}]::jsonb[], - ARRAY[${data.map(() => now)}]::timestamp[], - ARRAY[${data.map(() => now)}]::timestamp[], - ARRAY[${data.map((p) => `[${p.embedding.join(",")}]`)}]::vector(256)[] - ); - `; + INSERT INTO rag.rag_lesson_plan_parts (id, rag_lesson_plan_id, key, value_text, value_json, created_at, updated_at, embedding) + SELECT * + FROM UNNEST ( + ARRAY[${data.map(() => createId())}]::text[], + ARRAY[${data.map((p) => p.ragLessonPlanId)}]::text[], + ARRAY[${data.map((p) => p.key)}]::text[], + ARRAY[${data.map((p) => p.valueText)}]::text[], + ARRAY[${data.map((p) => JSON.stringify(p.valueJson))}]::jsonb[], + ARRAY[${data.map(() => now)}]::timestamp[], + ARRAY[${data.map(() => now)}]::timestamp[], + ARRAY[${data.map((p) => { + return "[" + p.embedding.join(",") + "]"; + })}]::vector[] + ); + `; log.info(prisma.$queryRawUnsafe.toString()); }, }); + /** + * In a transaction, update old versions of lesson plans to be unpublished + * and the new versions to be published + */ + await prisma.$transaction([ + prisma.ragLessonPlan.updateMany({ + where: { + oakLessonSlug: { + in: lessons.map((l) => l.data.lessonSlug), + }, + id: { + notIn: lessons.map((l) => l.id), + }, + }, + data: { + isPublished: false, + }, + }), + prisma.ragLessonPlan.updateMany({ + where: { + id: { + in: lessons.map((l) => l.id), + }, + }, + data: { + isPublished: true, + }, + }), + ]); + + // /** + // * Perform check to ensure no duplicate lesson plans are created + // */ + // const duplicateResults = await chunkAndPromiseAll({ + // data: lessons, + // chunkSize: 500, + // fn: async (data) => + // prisma.ragLessonPlan.findMany({ + // where: { + // oakLessonSlug: { + // in: data.map((l) => l.data.lessonSlug), + // }, + // }, + // select: { + // ingestLessonId: true, + // oakLessonId: true, + // oakLessonSlug: true, + // }, + // }), + // }); + // const duplicateLessonPlans = duplicateResults.flat(); + + // if (duplicateLessonPlans.length > 0) { + // log.error(`Duplicate lesson plans found: ${duplicateLessonPlans.length}`); + // throw new IngestError("Duplicate lesson plans found", { + // ingestId, + // }); + // } + log.info("Published lesson plans and parts to RAG schema"); } From 282664c31a0740f22c988899f9416e61798de46f Mon Sep 17 00:00:00 2001 From: mantagen Date: Wed, 18 Dec 2024 16:56:54 +0000 Subject: [PATCH 13/13] use posthog feature flag --- .../core/prompt/builders/AilaLessonPromptBuilder.ts | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts b/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts index 476a3451c..08ba52882 100644 --- a/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts +++ b/packages/aila/src/core/prompt/builders/AilaLessonPromptBuilder.ts @@ -1,3 +1,4 @@ +import { posthogAiBetaServerClient } from "@oakai/core/src/analytics/posthogAiBetaServerClient"; import type { TemplateProps } from "@oakai/core/src/prompts/lesson-assistant"; import { template } from "@oakai/core/src/prompts/lesson-assistant"; import { prisma as globalPrisma } from "@oakai/db/client"; @@ -58,6 +59,10 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder { }> { const noRelevantLessonPlans = "None"; const { chatId, userId } = this._aila; + if (!userId) { + throw new Error("User ID is required to fetch relevant lesson plans"); + } + if (!this._aila?.options.useRag || !chatId) { return { ragLessonPlans: [], @@ -69,7 +74,12 @@ export class AilaLessonPromptBuilder extends AilaPromptBuilder { const NEW_RAG_ENABLED = true; - if (NEW_RAG_ENABLED) { + const newRagEnabled = await posthogAiBetaServerClient.isFeatureEnabled( + "rag-schema-2024-12", + userId, + ); + + if (newRagEnabled) { if (!title || !subject || !keyStage) { log.error( "Missing title, subject or keyStage, returning empty content",