From 63975c939639a94bd82ee49c31ef247bb8bbcbba Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 6 Sep 2024 13:25:01 +0200 Subject: [PATCH] feat(memory): add sync mechanism to the TokenMemory --- src/memory/tokenMemory.test.ts | 80 +++++++++++++++++++++++++++++++--- src/memory/tokenMemory.ts | 62 +++++++++++++++++++++----- 2 files changed, 125 insertions(+), 17 deletions(-) diff --git a/src/memory/tokenMemory.test.ts b/src/memory/tokenMemory.test.ts index 96896eeb..6c72790d 100644 --- a/src/memory/tokenMemory.test.ts +++ b/src/memory/tokenMemory.test.ts @@ -23,7 +23,12 @@ import * as R from "remeda"; import { verifyDeserialization } from "@tests/e2e/utils.js"; describe("Token Memory", () => { - const getInstance = () => { + const getInstance = (config: { + llmFactor: number; + localFactor: number; + syncThreshold: number; + maxTokens: number; + }) => { const llm = new BAMChatLLM({ llm: new BAMLLM({ client: new Client(), @@ -35,22 +40,87 @@ describe("Token Memory", () => { }, }); + const estimateLLM = (msg: BaseMessage) => Math.ceil(msg.text.length * config.llmFactor); + const estimateLocal = (msg: BaseMessage) => Math.ceil(msg.text.length * config.localFactor); + vi.spyOn(llm, "tokenize").mockImplementation(async (messages: BaseMessage[]) => ({ - tokensCount: R.sum(messages.map((msg) => [msg.role, msg.text].join("").length)), + tokensCount: R.sum(messages.map(estimateLLM)), })); return new TokenMemory({ llm, - maxTokens: 1000, + maxTokens: config.maxTokens, + syncThreshold: config.syncThreshold, + handlers: { + estimate: estimateLocal, + }, }); }; + it("Auto sync", async () => { + const instance = getInstance({ + llmFactor: 2, + localFactor: 1, + maxTokens: 4, + syncThreshold: 0.5, + }); + await instance.addMany([ + BaseMessage.of({ role: Role.USER, text: "A" }), + BaseMessage.of({ role: Role.USER, text: "B" }), + BaseMessage.of({ role: Role.USER, text: "C" }), + BaseMessage.of({ role: Role.USER, text: "D" }), + ]); + expect(instance.stats()).toMatchObject({ + isDirty: false, + tokensUsed: 4, + messagesCount: 2, + }); + }); + + it("Synchronizes", async () => { + const instance = getInstance({ + llmFactor: 2, + localFactor: 1, + maxTokens: 10, + syncThreshold: 1, + }); + expect(instance.stats()).toMatchObject({ + isDirty: false, + tokensUsed: 0, + messagesCount: 0, + }); + await instance.addMany([ + BaseMessage.of({ role: Role.USER, text: "A" }), + BaseMessage.of({ role: Role.USER, text: "B" }), + BaseMessage.of({ role: Role.USER, text: "C" }), + BaseMessage.of({ role: Role.USER, text: "D" }), + BaseMessage.of({ role: Role.USER, text: "E" }), + BaseMessage.of({ role: Role.USER, text: "F" }), + ]); + expect(instance.stats()).toMatchObject({ + isDirty: true, + tokensUsed: 6, + messagesCount: 6, + }); + await instance.sync(); + expect(instance.stats()).toMatchObject({ + isDirty: false, + tokensUsed: 10, + messagesCount: 5, + }); + }); + it("Serializes", async () => { vi.stubEnv("GENAI_API_KEY", "123"); - const instance = getInstance(); + const instance = getInstance({ + llmFactor: 2, + localFactor: 1, + maxTokens: 10, + syncThreshold: 1, + }); await instance.add( BaseMessage.of({ - text: "I am a Batman!", + text: "Hello!", role: Role.USER, }), ); diff --git a/src/memory/tokenMemory.ts b/src/memory/tokenMemory.ts index e394d44c..f5c5b25e 100644 --- a/src/memory/tokenMemory.ts +++ b/src/memory/tokenMemory.ts @@ -20,26 +20,34 @@ import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js"; import * as R from "remeda"; import { shallowCopy } from "@/serializer/utils.js"; import { removeFromArray } from "@/internals/helpers/array.js"; +import { sum } from "remeda"; export interface Handlers { + estimate: (messages: BaseMessage) => number; removalSelector: (messages: BaseMessage[]) => BaseMessage; } export interface TokenMemoryInput { llm: ChatLLM; maxTokens?: number; + syncThreshold?: number; capacityThreshold?: number; - handlers?: Handlers; + handlers?: Partial; +} + +interface TokenByMessage { + tokensCount: number; + dirty: boolean; } export class TokenMemory extends BaseMemory { public readonly messages: BaseMessage[] = []; protected llm: ChatLLM; - protected threshold = 1; + protected threshold; + protected syncThreshold; protected maxTokens: number | null = null; - protected tokensUsed = 0; - protected tokensByMessage = new WeakMap(); + protected tokensByMessage = new WeakMap(); public readonly handlers: Handlers; constructor(config: TokenMemoryInput) { @@ -47,8 +55,11 @@ export class TokenMemory extends BaseMemory { this.llm = config.llm; this.maxTokens = config.maxTokens ?? null; this.threshold = config.capacityThreshold ?? 0.75; + this.syncThreshold = config.syncThreshold ?? 0.25; this.handlers = { ...config?.handlers, + estimate: + config?.handlers?.estimate || ((msg) => Math.ceil((msg.role.length + msg.text.length) / 4)), removalSelector: config.handlers?.removalSelector || ((messages) => messages[0]), }; if (!R.clamp({ min: 0, max: 1 })(this.threshold)) { @@ -60,13 +71,24 @@ export class TokenMemory extends BaseMemory { this.register(); } + get tokensUsed(): number { + return sum(this.messages.map((msg) => this.tokensByMessage.get(msg)!.tokensCount!)); + } + + get isDirty(): boolean { + return this.messages.some((msg) => this.tokensByMessage.get(msg)?.dirty !== false); + } + async add(message: BaseMessage) { if (this.maxTokens === null) { const meta = await this.llm.meta(); this.maxTokens = Math.ceil((meta.tokenLimit ?? Infinity) * this.threshold); } - const meta = await this.llm.tokenize([message]); + const meta = this.tokensByMessage.has(message) + ? this.tokensByMessage.get(message)! + : { tokensCount: this.handlers.estimate(message), dirty: true }; + if (meta.tokensCount > this.maxTokens) { throw new MemoryFatalError( `Retrieved message (${meta.tokensCount} tokens) cannot fit inside current memory (${this.maxTokens} tokens)`, @@ -80,14 +102,30 @@ export class TokenMemory extends BaseMemory { if (!messageToDelete || !exists) { throw new MemoryFatalError('The "removalSelector" handler must return a valid message!'); } - - const tokensCount = this.tokensByMessage.get(messageToDelete) ?? 0; - this.tokensUsed -= tokensCount; } - this.tokensUsed += meta.tokensCount; - this.tokensByMessage.set(message, meta.tokensCount); + this.tokensByMessage.set(message, meta); this.messages.push(message); + + if (this.isDirty && this.tokensUsed / this.maxTokens >= this.syncThreshold) { + await this.sync(); + } + } + + async sync() { + const messages = await Promise.all( + this.messages.map(async (msg) => { + const cache = this.tokensByMessage.get(msg); + if (cache?.dirty !== false) { + const result = await this.llm.tokenize([msg]); + this.tokensByMessage.set(msg, { tokensCount: result.tokensCount, dirty: false }); + } + return msg; + }), + ); + + this.messages.length = 0; + await this.addMany(messages); } reset() { @@ -95,7 +133,6 @@ export class TokenMemory extends BaseMemory { this.tokensByMessage.delete(msg); } this.messages.length = 0; - this.tokensUsed = 0; } stats() { @@ -103,15 +140,16 @@ export class TokenMemory extends BaseMemory { tokensUsed: this.tokensUsed, maxTokens: this.maxTokens, messagesCount: this.messages.length, + isDirty: this.isDirty, }; } createSnapshot() { return { - tokensUsed: this.tokensUsed, llm: this.llm, maxTokens: this.maxTokens, threshold: this.threshold, + syncThreshold: this.syncThreshold, messages: shallowCopy(this.messages), handlers: this.handlers, tokensByMessage: this.messages