feat: load conversation history into a LlamaChatSession (#51)
giladgd authored Sep 23, 2023
1 parent 9c8c42b commit 4e274ce
Showing 10 changed files with 177 additions and 13 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -104,6 +104,34 @@ const a2 = await session.prompt(q2);
console.log("AI: " + a2);
```

##### Load existing conversation history
```typescript
import {fileURLToPath} from "url";
import path from "path";
import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";

const __dirname = path.dirname(fileURLToPath(import.meta.url));

const model = new LlamaModel({
    modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf")
});
const context = new LlamaContext({model});
const session = new LlamaChatSession({
    context,
    conversationHistory: [{
        prompt: `Remember the number 6 as "The number"`,
        response: "OK. I'll remember it"
    }]
});

const q2 = 'What is "The number"?';
console.log("User: " + q2);

const a2 = await session.prompt(q2);
console.log("AI: " + a2);
```
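Note that the loaded history is wrapped and evaluated together with the first new prompt, so the first `session.prompt()` call carries the one-time cost of processing the whole history.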

#### Raw
```typescript
import {fileURLToPath} from "url";
// ...
```
9 changes: 9 additions & 0 deletions src/ChatPromptWrapper.ts
@@ -14,4 +14,13 @@ export abstract class ChatPromptWrapper {
    public getStopStrings(): string[] {
        return [];
    }

    public getDefaultStopString(): string {
        const stopString = this.getStopStrings()[0];

        if (stopString == null || stopString.length === 0)
            throw new Error(`Prompt wrapper "${this.wrapperName}" has no stop strings`);

        return stopString;
    }
}
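The base implementation throws rather than returning `undefined` when a wrapper defines no stop strings. A minimal sketch of that failure mode, assuming `EmptyChatPromptWrapper` (exported from the package index below) keeps the base class's empty `getStopStrings()`:

```typescript
import {EmptyChatPromptWrapper} from "node-llama-cpp";

const wrapper = new EmptyChatPromptWrapper();

try {
    // getStopStrings() returns [], so there is no first stop string to fall back on
    wrapper.getDefaultStopString();
} catch (err) {
    console.error(err); // Error: Prompt wrapper "..." has no stop strings
}
```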
4 changes: 4 additions & 0 deletions src/chatWrappers/ChatMLPromptWrapper.ts
@@ -21,4 +21,8 @@ export class ChatMLPromptWrapper extends ChatPromptWrapper {
    public override getStopStrings(): string[] {
        return ["<|im_end|>"];
    }

    public override getDefaultStopString(): string {
        return "<|im_end|>";
    }
}
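For reference, `<|im_end|>` marks the end of a turn in the ChatML format this wrapper targets, which is why it doubles as the default stop string. A rough sketch of the format (illustrative only, not this wrapper's exact output):

```typescript
// each ChatML turn is terminated by <|im_end|>, so generation stops there
const chatMlSketch = [
    "<|im_start|>system\nYou are a helpful assistant<|im_end|>",
    "<|im_start|>user\nHi there<|im_end|>",
    "<|im_start|>assistant\nHello!<|im_end|>"
].join("\n");

console.log(chatMlSketch);
```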
6 changes: 5 additions & 1 deletion src/chatWrappers/GeneralChatPromptWrapper.ts
@@ -32,11 +32,15 @@ export class GeneralChatPromptWrapper extends ChatPromptWrapper {
        ];
    }

+
+    public override getDefaultStopString(): string {
+        return `\n\n### ${this._instructionName}`;
+    }

    private _getPromptPrefix(lastStopString: string | null, lastStopStringSuffix: string | null) {
        return getTextCompletion(
            lastStopString === "<end>"
                ? lastStopStringSuffix
-                : (lastStopString + (lastStopStringSuffix ?? "")),
+                : ((lastStopString ?? "") + (lastStopStringSuffix ?? "")),
            [
                `\n\n### ${this._instructionName}:\n\n`,
                `### ${this._instructionName}:\n\n`
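The one-line change in `_getPromptPrefix` fixes a subtle coercion bug: JavaScript string concatenation stringifies `null`, so a `null` `lastStopString` could leak a literal `"null"` into the prompt prefix lookup. A quick demonstration:

```typescript
function joinStopString(lastStopString: string | null, lastStopStringSuffix: string | null): string[] {
    return [
        // before the fix: null is coerced to the string "null"
        lastStopString + (lastStopStringSuffix ?? ""),
        // after the fix: null collapses to an empty string
        (lastStopString ?? "") + (lastStopStringSuffix ?? "")
    ];
}

console.log(joinStopString(null, "### Human"));
// => [ "null### Human", "### Human" ]
```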
4 changes: 4 additions & 0 deletions src/chatWrappers/LlamaChatPromptWrapper.ts
@@ -21,4 +21,8 @@ export class LlamaChatPromptWrapper extends ChatPromptWrapper {
    public override getStopStrings(): string[] {
        return ["</s>"];
    }

    public override getDefaultStopString(): string {
        return "</s>";
    }
}
71 changes: 71 additions & 0 deletions src/chatWrappers/generateContextTextFromConversationHistory.ts
@@ -0,0 +1,71 @@
import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
import {defaultChatSystemPrompt} from "../config.js";
import {ConversationInteraction} from "../types.js";


/**
 * Generate context text to load into a model context from a conversation history.
 * @param {ChatPromptWrapper} chatPromptWrapper
 * @param {ConversationInteraction[]} conversationHistory
 * @param {object} [options]
 * @param {string} [options.systemPrompt]
 * @param {number} [options.currentPromptIndex]
 * @param {string | null} [options.lastStopString]
 * @param {string | null} [options.lastStopStringSuffix]
 * @returns {{text: string, stopString: (string | null), stopStringSuffix: (string | null)}}
 */
export function generateContextTextFromConversationHistory(
    chatPromptWrapper: ChatPromptWrapper,
    conversationHistory: readonly ConversationInteraction[],
    {
        systemPrompt = defaultChatSystemPrompt, currentPromptIndex = 0, lastStopString = null, lastStopStringSuffix = null
    }: {
        systemPrompt?: string, currentPromptIndex?: number, lastStopString?: string | null, lastStopStringSuffix?: string | null
    } = {}
): {
    text: string;
    stopString: string | null;
    stopStringSuffix: string | null;
} {
    let res = "";

    for (let i = 0; i < conversationHistory.length; i++) {
        const interaction = conversationHistory[i];
        const wrappedPrompt = chatPromptWrapper.wrapPrompt(interaction.prompt, {
            systemPrompt,
            promptIndex: currentPromptIndex,
            lastStopString,
            lastStopStringSuffix
        });
        const stopStrings = chatPromptWrapper.getStopStrings();
        const defaultStopString = chatPromptWrapper.getDefaultStopString();
        const stopStringsToCheckInResponse = new Set([...stopStrings, defaultStopString]);

        currentPromptIndex++;
        lastStopString = null;
        lastStopStringSuffix = null;

        res += wrappedPrompt;

        for (const stopString of stopStringsToCheckInResponse) {
            if (interaction.response.includes(stopString)) {
                console.error(
                    `Stop string "${stopString}" was found in the model response of conversation interaction index ${i}`,
                    {interaction, stopString}
                );
                throw new Error("A stop string cannot appear in a conversation history response");
            }
        }

        res += interaction.response;
        res += defaultStopString;
        lastStopString = defaultStopString;
        lastStopStringSuffix = "";
    }

    return {
        text: res,
        stopString: lastStopString,
        stopStringSuffix: lastStopStringSuffix
    };
}
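A minimal usage sketch of the new helper; it is not re-exported from the package index, so the relative imports below assume module-internal code within `src/chatWrappers/`:

```typescript
import {GeneralChatPromptWrapper} from "./GeneralChatPromptWrapper.js";
import {generateContextTextFromConversationHistory} from "./generateContextTextFromConversationHistory.js";

const wrapper = new GeneralChatPromptWrapper();

const {text, stopString, stopStringSuffix} = generateContextTextFromConversationHistory(wrapper, [
    {prompt: "Hi there", response: "Hello! How can I help you?"},
    {prompt: `Remember the number 6 as "The number"`, response: "OK. I'll remember it"}
]);

// `text` holds the fully wrapped history, ready to be evaluated into a context;
// `stopString` and `stopStringSuffix` seed the wrapPrompt() call for the next user prompt
console.log({text, stopString, stopStringSuffix});
```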
3 changes: 2 additions & 1 deletion src/index.ts
@@ -10,7 +10,7 @@ import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.js";
import {ChatMLPromptWrapper} from "./chatWrappers/ChatMLPromptWrapper.js";
import {getChatWrapperByBos} from "./chatWrappers/createChatWrapperByBos.js";

-import {type Token} from "./types.js";
+import {type ConversationInteraction, type Token} from "./types.js";


export {
@@ -22,6 +22,7 @@ export {
    type LlamaContextOptions,
    LlamaChatSession,
    type LlamaChatSessionOptions,
+    type ConversationInteraction,
    AbortError,
    ChatPromptWrapper,
    EmptyChatPromptWrapper,
42 changes: 38 additions & 4 deletions src/llamaEvaluator/LlamaChatSession.ts
@@ -4,7 +4,8 @@ import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
import {AbortError} from "../AbortError.js";
import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper.js";
import {getChatWrapperByBos} from "../chatWrappers/createChatWrapperByBos.js";
-import {Token} from "../types.js";
+import {ConversationInteraction, Token} from "../types.js";
+import {generateContextTextFromConversationHistory} from "../chatWrappers/generateContextTextFromConversationHistory.js";
import {LlamaModel} from "./LlamaModel.js";
import {LlamaContext} from "./LlamaContext.js";

@@ -15,7 +16,10 @@ export type LlamaChatSessionOptions = {
    context: LlamaContext,
    printLLamaSystemInfo?: boolean,
    promptWrapper?: ChatPromptWrapper | "auto",
-    systemPrompt?: string
+    systemPrompt?: string,
+
+    /** Conversation history to load into the context to continue an existing conversation */
+    conversationHistory?: readonly ConversationInteraction[]
};

export class LlamaChatSession {
@@ -26,17 +30,22 @@
    private _initialized: boolean = false;
    private _lastStopString: string | null = null;
    private _lastStopStringSuffix: string | null = null;
+    private _conversationHistoryToLoad: readonly ConversationInteraction[] | null = null;
    private readonly _ctx: LlamaContext;

    public constructor({
        context,
        printLLamaSystemInfo = false,
        promptWrapper = new GeneralChatPromptWrapper(),
-        systemPrompt = defaultChatSystemPrompt
+        systemPrompt = defaultChatSystemPrompt,
+        conversationHistory
    }: LlamaChatSessionOptions) {
        this._ctx = context;
        this._printLLamaSystemInfo = printLLamaSystemInfo;
        this._systemPrompt = systemPrompt;
+        this._conversationHistoryToLoad = (conversationHistory != null && conversationHistory.length > 0)
+            ? conversationHistory
+            : null;

        if (promptWrapper === "auto") {
            const chatWrapper = getChatWrapperByBos(context.getBosString());
@@ -76,7 +85,32 @@
        await this.init();

        return await withLock(this, "prompt", async () => {
-            const promptText = this._promptWrapper.wrapPrompt(prompt, {
+            let promptText = "";
+
+            if (this._promptIndex == 0 && this._conversationHistoryToLoad != null) {
+                const {text, stopString, stopStringSuffix} =
+                    generateContextTextFromConversationHistory(this._promptWrapper, this._conversationHistoryToLoad, {
+                        systemPrompt: this._systemPrompt,
+                        currentPromptIndex: this._promptIndex,
+                        lastStopString: this._lastStopString,
+                        lastStopStringSuffix: this._promptIndex == 0
+                            ? (
+                                this._ctx.prependBos
+                                    ? this._ctx.getBosString()
+                                    : null
+                            )
+                            : this._lastStopStringSuffix
+                    });
+
+                promptText += text;
+                this._lastStopString = stopString;
+                this._lastStopStringSuffix = stopStringSuffix;
+                this._promptIndex += this._conversationHistoryToLoad.length;
+
+                this._conversationHistoryToLoad = null;
+            }
+
+            promptText += this._promptWrapper.wrapPrompt(prompt, {
                systemPrompt: this._systemPrompt,
                promptIndex: this._promptIndex,
                lastStopString: this._lastStopString,
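Taken together, this allows persisting a chat and resuming it later in a brand-new context. A minimal sketch (the model path is a placeholder):

```typescript
import {LlamaModel, LlamaContext, LlamaChatSession, type ConversationInteraction} from "node-llama-cpp";

const history: ConversationInteraction[] = [];

// first session: record each prompt/response pair as it happens
{
    const model = new LlamaModel({modelPath: "models/codellama-13b.Q3_K_M.gguf"});
    const session = new LlamaChatSession({context: new LlamaContext({model})});

    const prompt = `Remember the number 6 as "The number"`;
    const response = await session.prompt(prompt);
    history.push({prompt, response});
}

// later: a fresh session picks up where the old one left off
const model = new LlamaModel({modelPath: "models/codellama-13b.Q3_K_M.gguf"});
const session = new LlamaChatSession({
    context: new LlamaContext({model}),
    conversationHistory: history
});

console.log(await session.prompt('What is "The number"?'));
```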
18 changes: 11 additions & 7 deletions src/llamaEvaluator/LlamaContext.ts
@@ -13,13 +13,19 @@ export type LlamaContextOptions = {

export class LlamaContext {
    private readonly _ctx: LLAMAContext;
-    private _prependBos: boolean;
+    private readonly _prependBos: boolean;
+    private _prependTokens: Token[];

    public constructor({model, grammar, prependBos = true}: LlamaContextOptions) {
        this._ctx = new LLAMAContext(model._model, removeNullFields({
            grammar: grammar?._grammar
        }));
        this._prependBos = prependBos;
+        this._prependTokens = [];
+
+        if (prependBos) {
+            this._prependTokens.unshift(this._ctx.tokenBos());
+        }
    }

    public encode(text: string): Uint32Array {
@@ -115,19 +121,18 @@
        return this._ctx.getTokenString(nlToken);
    }

-    public getContextSize() {
+    public getContextSize(): number {
        return this._ctx.getContextSize();
    }

    public async *evaluate(tokens: Uint32Array): AsyncGenerator<Token, void> {
        let evalTokens = tokens;

-        if (this._prependBos) {
-            const tokenArray: Token[] = Array.from(tokens);
-            tokenArray.unshift(this._ctx.tokenBos());
+        if (this._prependTokens.length > 0) {
+            const tokenArray: Token[] = this._prependTokens.concat(Array.from(tokens));

            evalTokens = Uint32Array.from(tokenArray);
-            this._prependBos = false;
+            this._prependTokens = [];
        }

        // eslint-disable-next-line no-constant-condition
@@ -145,5 +150,4 @@
            evalTokens = Uint32Array.from([nextToken]);
        }
    }
-
}
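Replacing the `_prependBos` flag with a `_prependTokens` queue generalizes the BOS handling: whatever tokens are queued get prepended to the first `evaluate()` call only, after which the queue is emptied. A sketch of the observable behavior (the model path is a placeholder):

```typescript
import {LlamaModel, LlamaContext} from "node-llama-cpp";

const model = new LlamaModel({modelPath: "models/codellama-13b.Q3_K_M.gguf"});
const context = new LlamaContext({model}); // prependBos defaults to true

// the first evaluate() call prepends the queued BOS token to its input
for await (const token of context.evaluate(context.encode("Hello"))) {
    console.log("first generated token:", token);
    break; // stop after one token; this is only a demonstration
}

// later calls evaluate their tokens as-is, since the queue was emptied
for await (const token of context.evaluate(context.encode(" world"))) {
    console.log("next generated token:", token);
    break;
}
```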
5 changes: 5 additions & 0 deletions src/types.ts
@@ -1 +1,6 @@
export type Token = number;

export type ConversationInteraction = {
    prompt: string,
    response: string
};
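A `ConversationInteraction` is a plain prompt/response pair. Note that a response must not contain any of the prompt wrapper's stop strings, or loading the history will throw (see `generateContextTextFromConversationHistory` above). For example:

```typescript
import {type ConversationInteraction} from "node-llama-cpp";

const history: ConversationInteraction[] = [
    {prompt: `Remember the number 6 as "The number"`, response: "OK. I'll remember it"},
    {prompt: 'What is "The number"?', response: "6"}
];
```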
