From 4e274ce96df2581c63f531721699977afe87abde Mon Sep 17 00:00:00 2001 From: Gilad S Date: Sun, 24 Sep 2023 01:00:34 +0300 Subject: [PATCH] feat: load conversation history into a `LlamaChatSession` (#51) --- README.md | 28 ++++++++ src/ChatPromptWrapper.ts | 9 +++ src/chatWrappers/ChatMLPromptWrapper.ts | 4 ++ src/chatWrappers/GeneralChatPromptWrapper.ts | 6 +- src/chatWrappers/LlamaChatPromptWrapper.ts | 4 ++ ...erateContextTextFromConversationHistory.ts | 71 +++++++++++++++++++ src/index.ts | 3 +- src/llamaEvaluator/LlamaChatSession.ts | 42 +++++++++-- src/llamaEvaluator/LlamaContext.ts | 18 +++-- src/types.ts | 5 ++ 10 files changed, 177 insertions(+), 13 deletions(-) create mode 100644 src/chatWrappers/generateContextTextFromConversationHistory.ts diff --git a/README.md b/README.md index 43978416..72f90e7c 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,34 @@ const a2 = await session.prompt(q2); console.log("AI: " + a2); ``` +##### Load existing conversation history +```typescript +import {fileURLToPath} from "url"; +import path from "path"; +import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp"; + +const __dirname = path.dirname(fileURLToPath(import.meta.url)); + +const model = new LlamaModel({ + modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf") +}) +const context = new LlamaContext({model}); +const session = new LlamaChatSession({ + context, + conversationHistory: [{ + prompt: `Remember the number 6 as "The number"`, + response: "OK. I'll remember it" + }] +}); + + +const q2 = 'What is "The number"?'; +console.log("User: " + q2); + +const a2 = await session.prompt(q2); +console.log("AI: " + a2); +``` + #### Raw ```typescript import {fileURLToPath} from "url"; diff --git a/src/ChatPromptWrapper.ts b/src/ChatPromptWrapper.ts index 96fdb934..dd0a6737 100644 --- a/src/ChatPromptWrapper.ts +++ b/src/ChatPromptWrapper.ts @@ -14,4 +14,13 @@ export abstract class ChatPromptWrapper { public getStopStrings(): string[] { return []; } + + public getDefaultStopString(): string { + const stopString = this.getStopStrings()[0]; + + if (stopString == null || stopString.length === 0) + throw new Error(`Prompt wrapper "${this.wrapperName}" has no stop strings`); + + return stopString; + } } diff --git a/src/chatWrappers/ChatMLPromptWrapper.ts b/src/chatWrappers/ChatMLPromptWrapper.ts index 143f1715..afbf2f03 100644 --- a/src/chatWrappers/ChatMLPromptWrapper.ts +++ b/src/chatWrappers/ChatMLPromptWrapper.ts @@ -21,4 +21,8 @@ export class ChatMLPromptWrapper extends ChatPromptWrapper { public override getStopStrings(): string[] { return ["<|im_end|>"]; } + + public override getDefaultStopString(): string { + return "<|im_end|>"; + } } diff --git a/src/chatWrappers/GeneralChatPromptWrapper.ts b/src/chatWrappers/GeneralChatPromptWrapper.ts index e43dde11..c1186514 100644 --- a/src/chatWrappers/GeneralChatPromptWrapper.ts +++ b/src/chatWrappers/GeneralChatPromptWrapper.ts @@ -32,11 +32,15 @@ export class GeneralChatPromptWrapper extends ChatPromptWrapper { ]; } + public override getDefaultStopString(): string { + return `\n\n### ${this._instructionName}`; + } + private _getPromptPrefix(lastStopString: string | null, lastStopStringSuffix: string | null) { return getTextCompletion( lastStopString === "" ? lastStopStringSuffix - : (lastStopString + (lastStopStringSuffix ?? "")), + : ((lastStopString ?? "") + (lastStopStringSuffix ?? "")), [ `\n\n### ${this._instructionName}:\n\n`, `### ${this._instructionName}:\n\n` diff --git a/src/chatWrappers/LlamaChatPromptWrapper.ts b/src/chatWrappers/LlamaChatPromptWrapper.ts index 9ab0f00c..c0f8dc17 100644 --- a/src/chatWrappers/LlamaChatPromptWrapper.ts +++ b/src/chatWrappers/LlamaChatPromptWrapper.ts @@ -21,4 +21,8 @@ export class LlamaChatPromptWrapper extends ChatPromptWrapper { public override getStopStrings(): string[] { return [""]; } + + public override getDefaultStopString(): string { + return ""; + } } diff --git a/src/chatWrappers/generateContextTextFromConversationHistory.ts b/src/chatWrappers/generateContextTextFromConversationHistory.ts new file mode 100644 index 00000000..fe2b3a22 --- /dev/null +++ b/src/chatWrappers/generateContextTextFromConversationHistory.ts @@ -0,0 +1,71 @@ +import {ChatPromptWrapper} from "../ChatPromptWrapper.js"; +import {defaultChatSystemPrompt} from "../config.js"; +import {ConversationInteraction} from "../types.js"; + + +/** + * Generate context text to load into a model context from a conversation history. + * @param {ChatPromptWrapper} chatPromptWrapper + * @param {ConversationInteraction[]} conversationHistory + * @param {object} [options] + * @param {string} [options.systemPrompt] + * @param {number} [options.currentPromptIndex] + * @param {string | null} [options.lastStopString] + * @param {string | null} [options.lastStopStringSuffix] + * @returns {{text: string, stopString: (string | null), stopStringSuffix: (string | null)}} + */ +export function generateContextTextFromConversationHistory( + chatPromptWrapper: ChatPromptWrapper, + conversationHistory: readonly ConversationInteraction[], + { + systemPrompt = defaultChatSystemPrompt, currentPromptIndex = 0, lastStopString = null, lastStopStringSuffix = null + }: { + systemPrompt?: string, currentPromptIndex?: number, lastStopString?: string | null, lastStopStringSuffix?: string | null + } = {} +): { + text: string; + stopString: string | null; + stopStringSuffix: string | null; +} { + let res = ""; + + for (let i = 0; i < conversationHistory.length; i++) { + const interaction = conversationHistory[i]; + const wrappedPrompt = chatPromptWrapper.wrapPrompt(interaction.prompt, { + systemPrompt, + promptIndex: currentPromptIndex, + lastStopString, + lastStopStringSuffix + }); + const stopStrings = chatPromptWrapper.getStopStrings(); + const defaultStopString = chatPromptWrapper.getDefaultStopString(); + const stopStringsToCheckInResponse = new Set([...stopStrings, defaultStopString]); + + currentPromptIndex++; + lastStopString = null; + lastStopStringSuffix = null; + + res += wrappedPrompt; + + for (const stopString of stopStringsToCheckInResponse) { + if (interaction.response.includes(stopString)) { + console.error( + `Stop string "${stopString}" was found in model response of conversation interaction index ${i}`, + {interaction, stopString} + ); + throw new Error("A stop string cannot be in a conversation history interaction model response"); + } + } + + res += interaction.response; + res += defaultStopString; + lastStopString = defaultStopString; + lastStopStringSuffix = ""; + } + + return { + text: res, + stopString: lastStopString, + stopStringSuffix: lastStopStringSuffix + }; +} diff --git a/src/index.ts b/src/index.ts index adcb19da..ededae02 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,7 +10,7 @@ import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper. import {ChatMLPromptWrapper} from "./chatWrappers/ChatMLPromptWrapper.js"; import {getChatWrapperByBos} from "./chatWrappers/createChatWrapperByBos.js"; -import {type Token} from "./types.js"; +import {type ConversationInteraction, type Token} from "./types.js"; export { @@ -22,6 +22,7 @@ export { type LlamaContextOptions, LlamaChatSession, type LlamaChatSessionOptions, + type ConversationInteraction, AbortError, ChatPromptWrapper, EmptyChatPromptWrapper, diff --git a/src/llamaEvaluator/LlamaChatSession.ts b/src/llamaEvaluator/LlamaChatSession.ts index 2708eae1..d47b2d70 100644 --- a/src/llamaEvaluator/LlamaChatSession.ts +++ b/src/llamaEvaluator/LlamaChatSession.ts @@ -4,7 +4,8 @@ import {ChatPromptWrapper} from "../ChatPromptWrapper.js"; import {AbortError} from "../AbortError.js"; import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper.js"; import {getChatWrapperByBos} from "../chatWrappers/createChatWrapperByBos.js"; -import {Token} from "../types.js"; +import {ConversationInteraction, Token} from "../types.js"; +import {generateContextTextFromConversationHistory} from "../chatWrappers/generateContextTextFromConversationHistory.js"; import {LlamaModel} from "./LlamaModel.js"; import {LlamaContext} from "./LlamaContext.js"; @@ -15,7 +16,10 @@ export type LlamaChatSessionOptions = { context: LlamaContext, printLLamaSystemInfo?: boolean, promptWrapper?: ChatPromptWrapper | "auto", - systemPrompt?: string + systemPrompt?: string, + + /** Conversation history to load into the context to continue an existing conversation */ + conversationHistory?: readonly ConversationInteraction[] }; export class LlamaChatSession { @@ -26,17 +30,22 @@ export class LlamaChatSession { private _initialized: boolean = false; private _lastStopString: string | null = null; private _lastStopStringSuffix: string | null = null; + private _conversationHistoryToLoad: readonly ConversationInteraction[] | null = null; private readonly _ctx: LlamaContext; public constructor({ context, printLLamaSystemInfo = false, promptWrapper = new GeneralChatPromptWrapper(), - systemPrompt = defaultChatSystemPrompt + systemPrompt = defaultChatSystemPrompt, + conversationHistory }: LlamaChatSessionOptions) { this._ctx = context; this._printLLamaSystemInfo = printLLamaSystemInfo; this._systemPrompt = systemPrompt; + this._conversationHistoryToLoad = (conversationHistory != null && conversationHistory.length > 0) + ? conversationHistory + : null; if (promptWrapper === "auto") { const chatWrapper = getChatWrapperByBos(context.getBosString()); @@ -76,7 +85,32 @@ export class LlamaChatSession { await this.init(); return await withLock(this, "prompt", async () => { - const promptText = this._promptWrapper.wrapPrompt(prompt, { + let promptText = ""; + + if (this._promptIndex == 0 && this._conversationHistoryToLoad != null) { + const {text, stopString, stopStringSuffix} = + generateContextTextFromConversationHistory(this._promptWrapper, this._conversationHistoryToLoad, { + systemPrompt: this._systemPrompt, + currentPromptIndex: this._promptIndex, + lastStopString: this._lastStopString, + lastStopStringSuffix: this._promptIndex == 0 + ? ( + this._ctx.prependBos + ? this._ctx.getBosString() + : null + ) + : this._lastStopStringSuffix + }); + + promptText += text; + this._lastStopString = stopString; + this._lastStopStringSuffix = stopStringSuffix; + this._promptIndex += this._conversationHistoryToLoad.length; + + this._conversationHistoryToLoad = null; + } + + promptText += this._promptWrapper.wrapPrompt(prompt, { systemPrompt: this._systemPrompt, promptIndex: this._promptIndex, lastStopString: this._lastStopString, diff --git a/src/llamaEvaluator/LlamaContext.ts b/src/llamaEvaluator/LlamaContext.ts index 9717f140..ef2678e9 100644 --- a/src/llamaEvaluator/LlamaContext.ts +++ b/src/llamaEvaluator/LlamaContext.ts @@ -13,13 +13,19 @@ export type LlamaContextOptions = { export class LlamaContext { private readonly _ctx: LLAMAContext; - private _prependBos: boolean; + private readonly _prependBos: boolean; + private _prependTokens: Token[]; public constructor({model, grammar, prependBos = true}: LlamaContextOptions) { this._ctx = new LLAMAContext(model._model, removeNullFields({ grammar: grammar?._grammar })); this._prependBos = prependBos; + this._prependTokens = []; + + if (prependBos) { + this._prependTokens.unshift(this._ctx.tokenBos()); + } } public encode(text: string): Uint32Array { @@ -115,19 +121,18 @@ export class LlamaContext { return this._ctx.getTokenString(nlToken); } - public getContextSize() { + public getContextSize(): number { return this._ctx.getContextSize(); } public async *evaluate(tokens: Uint32Array): AsyncGenerator { let evalTokens = tokens; - if (this._prependBos) { - const tokenArray: Token[] = Array.from(tokens); - tokenArray.unshift(this._ctx.tokenBos()); + if (this._prependTokens.length > 0) { + const tokenArray: Token[] = this._prependTokens.concat(Array.from(tokens)); evalTokens = Uint32Array.from(tokenArray); - this._prependBos = false; + this._prependTokens = []; } // eslint-disable-next-line no-constant-condition @@ -145,5 +150,4 @@ export class LlamaContext { evalTokens = Uint32Array.from([nextToken]); } } - } diff --git a/src/types.ts b/src/types.ts index 3090be56..837e5837 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1 +1,6 @@ export type Token = number; + +export type ConversationInteraction = { + prompt: string, + response: string +};