fix: handle stop words remainder properly in a chat session (#32)
giladgd authored Sep 2, 2023
1 parent dd49959 commit 9bdef11
Showing 2 changed files with 31 additions and 11 deletions.
35 changes: 26 additions & 9 deletions src/llamaEvaluator/LlamaChatSession.ts
@@ -105,7 +105,7 @@ export class LlamaChatSession {
         onToken, signal, maxTokens
     }: { onToken?(tokens: Token[]): void, signal?: AbortSignal, maxTokens?: number } = {}) {
         const stopStrings = this._promptWrapper.getStopStrings();
-        const stopStringIndexes = Array(stopStrings.length).fill(0);
+        const stopStringIndexes: number[] = Array(stopStrings.length).fill(0);
         const skippedChunksQueue: Token[] = [];
         const res: Token[] = [];

@@ -114,14 +114,32 @@ export class LlamaChatSession {
                 throw new AbortError();
 
             const tokenStr = this._ctx.decode(Uint32Array.from([chunk]));
-            const {shouldReturn, skipTokenEvent, stopString, stopStringSuffix} = this._checkStopString(tokenStr, stopStringIndexes);
+            const {
+                shouldReturn, skipTokenEvent, stopString, stopStringSuffix
+            } = this._checkStopString(tokenStr, stopStrings, stopStringIndexes);
 
+            if (shouldReturn) {
+                skippedChunksQueue.push(chunk);
+                const skippedChunksText = skippedChunksQueue.length > 0
+                    ? this._ctx.decode(Uint32Array.from(skippedChunksQueue))
+                    : "";
+
+                const [queuedTextBeforeStopString] = skippedChunksText.split(stopString);
+
+                if (queuedTextBeforeStopString.length > 0) {
+                    const beforeStopStringTokens: Token[] = Array.from(this._ctx.encode(queuedTextBeforeStopString));
+
+                    res.push(...beforeStopStringTokens);
+                    onToken?.(beforeStopStringTokens);
+                    skippedChunksQueue.length = 0;
+                }
+
-            if (shouldReturn)
                 return {
                     text: this._ctx.decode(Uint32Array.from(res)),
                     stopString,
                     stopStringSuffix
                 };
+            }
 
             // if the token is unknown, it means it's not complete character
             if (tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent) {
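
The heart of this hunk: once a stop string is detected, the tokens that were queued while the match was still ambiguous are decoded, and the text that precedes the stop string is pushed to the result and passed to onToken instead of being dropped. A minimal standalone sketch of that idea (remainderBeforeStop is an illustrative helper, not part of the library's API):

    // Illustrative only: mirrors skippedChunksText.split(stopString) from the diff above.
    // Returns the queued text that precedes the stop string so it can still be emitted.
    function remainderBeforeStop(queuedText: string, stopString: string): string {
        const [beforeStop] = queuedText.split(stopString);
        return beforeStop;
    }

    // e.g. while matching the stop string "[INST]", queued text "Hello [INST]"
    // yields "Hello ", the remainder that should still reach the caller.
    remainderBeforeStop("Hello [INST]", "[INST]"); // => "Hello "
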
@@ -149,32 +149,31 @@ export class LlamaChatSession {
             };
     }
 
-    private _checkStopString(tokenStr: string, stopStringIndexes: number[]){
-        const stopStrings = this._promptWrapper.getStopStrings();
-
+    private _checkStopString(tokenStr: string, stopStrings: string[], stopStringIndexes: number[]){
         let skipTokenEvent = false;
 
         for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) {
             const stopString = stopStrings[stopStringIndex];
 
             let localShouldSkipTokenEvent = false;
-            for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
+            let i = 0;
+            for (; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
                 if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) {
                     stopStringIndexes[stopStringIndex]++;
                     localShouldSkipTokenEvent = true;
                 } else {
                     stopStringIndexes[stopStringIndex] = 0;
                     localShouldSkipTokenEvent = false;
                     break;
                 }
             }
 
             if (stopStringIndexes[stopStringIndex] === stopString.length) {
                 return {
                     shouldReturn: true,
                     stopString,
-                    stopStringSuffix: tokenStr.length === stopString.length
+                    stopStringSuffix: tokenStr.length === i
                         ? null
-                        : tokenStr.slice(stopString.length)
+                        : tokenStr.slice(i)
                 };
             }
 
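The other half of the fix is in _checkStopString, which now receives the stop strings from its caller and computes stopStringSuffix from i (the number of characters of the current token consumed by the match) rather than from stopString.length, which is only correct when the whole stop string fits inside a single token. A self-contained sketch of that matching step (checkToken and state.matched are hypothetical names used for illustration, not the library's API):

    // Hypothetical, simplified version of the per-token incremental stop-string match.
    // `state.matched` plays the role of stopStringIndexes[stopStringIndex] in the diff.
    function checkToken(tokenStr: string, stopString: string, state: {matched: number}) {
        let i = 0;
        for (; i < tokenStr.length && state.matched !== stopString.length; i++) {
            if (tokenStr[i] === stopString[state.matched]) {
                state.matched++;
            } else {
                state.matched = 0;
                break;
            }
        }

        if (state.matched === stopString.length)
            return {
                done: true,
                // whatever part of this token comes after the matched characters
                suffix: tokenStr.length === i ? null : tokenStr.slice(i)
            };

        return {done: false, suffix: null};
    }

    // With the stop string "###" split across the tokens "##" and "#!", the match
    // completes after consuming 1 character of the second token, so the suffix is "!".
    // Slicing by stopString.length instead, as before this change, would return "" here.
    const state = {matched: 0};
    checkToken("##", "###", state); // => {done: false, suffix: null}
    checkToken("#!", "###", state); // => {done: true, suffix: "!"}
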
7 changes: 5 additions & 2 deletions src/llamaEvaluator/LlamaContext.ts
@@ -29,11 +29,14 @@ export class LlamaContext {
         return this._ctx.encode(text);
     }
 
-    public decode(tokens: Uint32Array): string {
+    public decode(tokens: Uint32Array | Token[]): string {
         if (tokens.length === 0)
             return "";
 
-        return this._ctx.decode(tokens);
+        if (tokens instanceof Uint32Array)
+            return this._ctx.decode(tokens);
+
+        return this._ctx.decode(Uint32Array.from(tokens));
     }
 
     public get prependBos() {
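With decode accepting either a Uint32Array or a plain Token[] and funneling both into the same underlying call, callers such as LlamaChatSession can pass a token queue directly without converting it first. A minimal sketch of that normalization pattern, with decodeImpl standing in for the underlying binding (an assumption for illustration):

    // Sketch of the union-type normalization used above: both input shapes end up
    // as a Uint32Array before reaching the underlying decode implementation.
    type Token = number;

    function decode(tokens: Uint32Array | Token[], decodeImpl: (t: Uint32Array) => string): string {
        if (tokens.length === 0)
            return "";

        if (tokens instanceof Uint32Array)
            return decodeImpl(tokens);

        return decodeImpl(Uint32Array.from(tokens));
    }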
