From 23c3a059284189ad7fc2cdb9f2943a3076e59593 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 24 May 2024 01:53:15 +0300 Subject: [PATCH 01/39] refactor: split `LlamaChat` implementation into smaller functions --- src/evaluator/LlamaChat/LlamaChat.ts | 1607 ++++++++++++++++---------- src/evaluator/LlamaModel.ts | 7 +- 2 files changed, 971 insertions(+), 643 deletions(-) diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index 2e7a1853..4d03b91b 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -8,7 +8,7 @@ import {removeNullFields} from "../../utils/removeNullFields.js"; import {LlamaGrammarEvaluationState} from "../LlamaGrammarEvaluationState.js"; import {LlamaText} from "../../utils/LlamaText.js"; import {StopGenerationDetector} from "../../utils/StopGenerationDetector.js"; -import {QueuedTokenReleaseLock, TokenStreamRegulator} from "../../utils/TokenStreamRegulator.js"; +import {QueuedTokenRelease, QueuedTokenReleaseLock, TokenStreamRegulator} from "../../utils/TokenStreamRegulator.js"; import {EvaluationPriority} from "../LlamaContext/types.js"; import {UNKNOWN_UNICODE_CHAR} from "../../consts.js"; import {getQueuedTokensBeforeStopTrigger} from "../../utils/getQueuedTokensBeforeStopTrigger.js"; @@ -174,6 +174,7 @@ const defaultContextShiftOptions: Required = { strategy: "eraseFirstResponseAndKeepFirstSystem", lastEvaluationMetadata: null }; +const defaultRepeatPenaltyLastTokens = 64; export class LlamaChat { @@ -260,6 +261,14 @@ export class LlamaChat { } public async generateResponse( + history: ChatHistoryItem[], + options: LLamaChatGenerateResponseOptions = {} + ): Promise> { + return this._generateResponse(history, options); + } + + /** @internal */ + private async _generateResponse( history: ChatHistoryItem[], { onToken, @@ -285,673 +294,106 @@ export class LlamaChat { } = {} }: LLamaChatGenerateResponseOptions = {} ): Promise> { - const functionsEnabled = (functions != null && Object.keys(functions).length > 0); - - if (grammar != null && functionsEnabled) - throw new Error("Using both grammar and functions is not supported yet"); - - if (signal?.aborted) - throw signal.reason; - - if (this._sequence == null) - throw new DisposedError(); - - let resolvedHistory = this._sequence.isLoadedToMemory - ? history.slice() - : history.map(removeRawFromHistoryItem); - - if (resolvedHistory.length === 0 || resolvedHistory[resolvedHistory.length - 1].type !== "model") - resolvedHistory.push({ - type: "model", - response: [] - }); - - const model = this._sequence.model; - const context = this._sequence.context; - const resolvedContextShift = { - ...defaultContextShiftOptions, - ...removeNullFields(contextShift) - }; - const { - lastTokens: repeatPenaltyLastTokens = 64, - punishTokensFilter, - penalizeNewLine, - penalty, - frequencyPenalty, - presencePenalty - }: LLamaContextualRepeatPenalty = repeatPenalty === false - ? {lastTokens: 0} - : repeatPenalty; - const lastModelResponse = getLastTextModelResponseFromChatHistory(resolvedHistory); - - const res: Token[] = []; - const pendingTokens: Token[] = []; - let ignoredStartTextTokens: Token[] = []; - const functionCallTokens: Token[] = []; - const repeatPenaltyEnabled = repeatPenaltyLastTokens > 0; - const grammarEvaluationState = grammar != null - ? new LlamaGrammarEvaluationState({grammar}) - : undefined; - let functionsGrammar = functionsEnabled - ? 
new FunctionCallGrammar(model._llama, functions as NonNullable, this._chatWrapper, false) - : undefined; - let functionsEvaluationState = (functionsEnabled && functionsGrammar != null) - ? new LlamaGrammarEvaluationState({ - grammar: functionsGrammar - }) - : undefined; - const streamRegulator = new TokenStreamRegulator(); - const stopGenerationDetector = new StopGenerationDetector(); - const customStopGenerationTriggersDetector = new StopGenerationDetector(); - const functionSyntaxStartDetector = new StopGenerationDetector(); - const functionSyntaxEndDetector = new StopGenerationDetector(); - const disengageInitiallyEngagedFunctionMode = new StopGenerationDetector(); - const ignoreStartTextDetector = new StopGenerationDetector(); - const locksToReleaseOnValidGeneration: QueuedTokenReleaseLock[] = []; - const functionCallTokenSyntaxLocks: QueuedTokenReleaseLock[] = []; - - let generatedTokens = 0; - let isFirstEvaluation = true; - let inFunctionEvaluationMode = false; - let initiallyEngagedFunctionMode = false; - let lastContextWindowHistory: ChatHistoryItem[] = resolvedHistory; - let lastHistoryCompressionMetadata: object | null | undefined = resolvedContextShift.lastEvaluationMetadata; - - const ensureNotAborted = () => { - if (signal?.aborted && (!stopOnAbortSignal || res.length === 0)) - throw signal.reason; - - if (this._sequence == null) - throw new DisposedError(); - }; - - const getPenaltyTokens = () => { - if (this._sequence == null) - throw new DisposedError(); - - let punishTokens = res.slice(-repeatPenaltyLastTokens); - - if (punishTokensFilter != null) - punishTokens = punishTokensFilter(punishTokens); - - if (penalizeNewLine == null || !penalizeNewLine) { - const nlToken = model.tokens.nl; - - if (nlToken != null) - punishTokens = punishTokens.filter(token => token !== nlToken); - } - - return punishTokens; - }; - - const getResolvedHistoryWithCurrentModelResponse = () => { - if (res.length === 0) - return resolvedHistory; - - let modelResponse = model.detokenize(res); - - if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) - modelResponse = modelResponse.trimEnd(); - - if (modelResponse === "") - return resolvedHistory; - - return setLastModelTextResponseInChatHistory( - resolvedHistory, - lastModelResponse + modelResponse - ); - }; - - const removeFoundStartIgnoreTextsFromPendingTokens = () => { - if (res.length === 0 && pendingTokens.length > 0) { - ignoreStartTextDetector.clearInProgressStops(); - ignoreStartTextDetector.clearTriggeredStops(); - - let mostExhaustiveTriggeredStops: ReturnType | null = null; - - for (let i = 0; i < pendingTokens.length; i++) { - ignoreStartTextDetector.recordGeneration({ - text: model.detokenize([pendingTokens[i]]), - tokens: [pendingTokens[i]], - startNewChecks: i === 0 - }); - - if (ignoreStartTextDetector.hasTriggeredStops) { - mostExhaustiveTriggeredStops = ignoreStartTextDetector.getTriggeredStops(); - ignoreStartTextDetector.clearTriggeredStops(); - } else if (!ignoreStartTextDetector.hasInProgressStops) - break; - } - - if (mostExhaustiveTriggeredStops != null) { - const [mostExhaustiveTriggeredStop] = mostExhaustiveTriggeredStops; - - if (mostExhaustiveTriggeredStop != null) { - ignoredStartTextTokens = mostExhaustiveTriggeredStop.stopTrigger - .map((stopTrigger) => { - if (typeof stopTrigger === "string") - return model.tokenize(stopTrigger, false, "trimLeadingSpace"); - else - return [stopTrigger]; - }) - .flat(1); - - const newPendingTokens = mostExhaustiveTriggeredStop.remainingGenerations - .map((generation) => { - if 
(typeof generation === "string") - return model.tokenize(generation, false, "trimLeadingSpace"); - else - return generation; - }) - .flat(1); - pendingTokens.length = 0; - pendingTokens.push(...newPendingTokens); - } + const generateResponseState = new GenerateResponseState( + this, + this._chatWrapper, + history, + { + onToken, + signal, + stopOnAbortSignal, + maxTokens, + temperature, + minP, + topK, + topP, + grammar: grammar as never, + trimWhitespaceSuffix, + repeatPenalty, + tokenBias, + evaluationPriority, + functions, + documentFunctionParams, + contextShift, + customStopTriggers, + lastEvaluationContextWindow: { + history: lastEvaluationContextWindowHistory, + minimumOverlapPercentageToPreventContextShift } } - }; - - if (customStopTriggers != null) - StopGenerationDetector.resolveStopTriggers(customStopTriggers, model.tokenizer) - .map((stopTrigger) => customStopGenerationTriggersDetector.addStopTrigger(stopTrigger)); + ); - if (grammar != null) - StopGenerationDetector.resolveStopTriggers(grammar.stopGenerationTriggers, model.tokenizer) - .map((stopTrigger) => stopGenerationDetector.addStopTrigger(stopTrigger)); + try { + generateResponseState.ensureLastHistoryItemIsModel(); - if (functions != null && Object.keys(functions).length > 0) - functionSyntaxStartDetector.addStopTrigger([this._chatWrapper.settings.functions.call.prefix]); + // eslint-disable-next-line no-constant-condition + while (true) { + generateResponseState.startTokenLoop(); + await generateResponseState.loadContextWindow(); - // eslint-disable-next-line no-constant-condition - while (true) { - ensureNotAborted(); + if (generateResponseState.generatedTokens === 0) { + generateResponseState.addIgnoreStartTextTriggersFromChatWrapper(); + generateResponseState.addFunctionSyntaxEndTriggersFromFunctionsGrammar(); - let shouldContextShift = false; - const queuedChunkTokens = streamRegulator.getAllQueuedChunkTokens(); - const { - history: contextWindowHistory, - stopGenerationTriggers, - tokens: contextWindowTokens, - newResolvedHistory, - newHistoryCompressionMetadata, - ignoreStartText, - functionCallInitiallyEngaged, - disengageInitiallyEngagedFunctionCall - } = await getContextWindow({ - resolvedHistory: getResolvedHistoryWithCurrentModelResponse(), - resolvedContextShift, - lastHistoryCompressionMetadata, - pendingTokensCount: ignoredStartTextTokens.length + pendingTokens.length + queuedChunkTokens.length, - isFirstEvaluation, - chatWrapper: this._chatWrapper, - lastEvaluationContextWindowHistory, - minimumOverlapPercentageToPreventContextShift, - sequence: this._sequence, - minFreeContextTokens: 1, - functions: functionsEnabled ? 
functions : undefined, - documentFunctionParams - }); - ensureNotAborted(); - - if (generatedTokens === 0) { - StopGenerationDetector.resolveStopTriggers(ignoreStartText, model.tokenizer) - .map((stopTrigger) => ignoreStartTextDetector.addStopTrigger(stopTrigger)); - - if (functionsEnabled) { - initiallyEngagedFunctionMode = functionCallInitiallyEngaged; - StopGenerationDetector.resolveStopTriggers(disengageInitiallyEngagedFunctionCall, model.tokenizer) - .map((stopTrigger) => disengageInitiallyEngagedFunctionMode.addStopTrigger(stopTrigger)); - - if (initiallyEngagedFunctionMode) { - inFunctionEvaluationMode = true; - functionsGrammar = new FunctionCallGrammar( - model._llama, - functions as NonNullable, - this._chatWrapper, - true - ); - functionsEvaluationState = new LlamaGrammarEvaluationState({ - grammar: functionsGrammar - }); + if (generateResponseState.functionsEnabled) { + generateResponseState.initFunctions(); } } - } - const tokens = [...contextWindowTokens, ...ignoredStartTextTokens, ...pendingTokens, ...queuedChunkTokens]; - resolvedHistory = newResolvedHistory; - lastHistoryCompressionMetadata = newHistoryCompressionMetadata; - lastContextWindowHistory = contextWindowHistory; - const contextWindowLastModelResponse = getLastTextModelResponseFromChatHistory(contextWindowHistory); - const contextWindowsRes: Token[] = []; + generateResponseState.addStopGenerationTriggersFromChatWrapper(); + await generateResponseState.alignCurrentSequenceStateWithCurrentTokens(); - StopGenerationDetector.resolveStopTriggers(stopGenerationTriggers, model.tokenizer) - .map((stopTrigger) => stopGenerationDetector.addStopTrigger(stopTrigger)); + await generateResponseState.createNewEvaluationIterator(); + while (await generateResponseState.iterateEvaluation()) { + generateResponseState.waitOnPartialCharactersOrWhiteSpaceTokens(); - if (functionsGrammar != null) - StopGenerationDetector.resolveStopTriggers(functionsGrammar.stopGenerationTriggers, model.tokenizer) - .map((stopTrigger) => functionSyntaxEndDetector.addStopTrigger(stopTrigger)); + generateResponseState.trackGenerationForDisengageInitiallyEngagedFunctionMode(); + generateResponseState.trackFunctionSyntaxStart(); - let {firstDifferentIndex} = this._sequence.compareContextTokens(tokens); + generateResponseState.handleInitiallyEngagedFunctionModeFunctionDetection(); + generateResponseState.handleFunctionSyntax(); - // we need to decode at least one token to generate a response - if (firstDifferentIndex === tokens.length && firstDifferentIndex > 0) - firstDifferentIndex -= 1; + const functionEndSyntaxRes = generateResponseState.detectFunctionEndSyntax(); + if (functionEndSyntaxRes != null) + return functionEndSyntaxRes; - tokens.splice(0, firstDifferentIndex); + generateResponseState.recordStopGenerationEvaluation(); - if (firstDifferentIndex < this._sequence.nextTokenIndex) { - await this._sequence.eraseContextTokenRanges([{ - start: firstDifferentIndex, - end: this._sequence.nextTokenIndex - }]); - ensureNotAborted(); - } + generateResponseState.popStreamRegulatorFreeTokens(); + generateResponseState.removeFoundStartIgnoreTextsFromPendingTokens(); + const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger(); + if (stopGenerationTriggerRes != null) + return stopGenerationTriggerRes; - const evaluationIterator = this._sequence.evaluate(tokens, removeNullFields({ - temperature, minP, topK, topP, - grammarEvaluationState: () => { - if (inFunctionEvaluationMode) - return functionsEvaluationState; + 
generateResponseState.spliceIgnoreStartTextDetectedTokens(); - return grammarEvaluationState; - }, - repeatPenalty: !repeatPenaltyEnabled ? undefined : { - punishTokens: getPenaltyTokens, - penalty, - frequencyPenalty, - presencePenalty - }, - tokenBias, - evaluationPriority, - yieldEogToken: true - })); - - try { - let currentIteration = await evaluationIterator.next(); - while (currentIteration.done !== true) { - const token = currentIteration.value; - let replacementToken: Token | undefined = undefined; - - ensureNotAborted(); - generatedTokens++; - - const tokens = [token]; - const text = model.detokenize([token]); - const queuedTokenRelease = streamRegulator.addChunk({tokens, text}); - - if (initiallyEngagedFunctionMode) - disengageInitiallyEngagedFunctionMode.recordGeneration({text, tokens, startNewChecks: generatedTokens === 1}); - - if (text === UNKNOWN_UNICODE_CHAR || ( - (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) && text.trim() === "" - )) { - locksToReleaseOnValidGeneration.push(queuedTokenRelease.createTextIndexLock(0)); - } else { - while (locksToReleaseOnValidGeneration.length > 0) - locksToReleaseOnValidGeneration.shift()!.dispose(); - } + generateResponseState.moveFreePendingTokensToRes(); - functionSyntaxStartDetector.recordGeneration({text, tokens, queuedTokenRelease}); - - if (initiallyEngagedFunctionMode && disengageInitiallyEngagedFunctionMode.hasTriggeredStops) { - initiallyEngagedFunctionMode = false; - - let shouldStopFunctionEvaluationMode = !functionSyntaxStartDetector.hasTriggeredStops; - - if (!shouldStopFunctionEvaluationMode && functionsEnabled && functionsGrammar != null) { - const functionCallText = model.detokenize([...functionCallTokens, ...tokens]); - - try { - const functionName = functionsGrammar.parseFunctionNameFromPartialCall(functionCallText, { - enableInternalBuiltinFunctions: true, - initialFunctionCallEngaged: true - }); - - const internalBuiltinFunctions = - this._chatWrapper.getInternalBuiltinFunctions({initialFunctionCallEngaged: true}); - if (internalBuiltinFunctions[functionName] != null) { - shouldStopFunctionEvaluationMode = true; - } - } catch (err) { - if (!(err instanceof LlamaFunctionCallValidationError)) - throw err; - } - } - - if (shouldStopFunctionEvaluationMode) { - inFunctionEvaluationMode = false; - functionsGrammar = new FunctionCallGrammar( - model._llama, - functions as NonNullable, - this._chatWrapper, false - ); - functionsEvaluationState = new LlamaGrammarEvaluationState({ - grammar: functionsGrammar - }); - - functionCallTokens.length = 0; - - while (functionCallTokenSyntaxLocks.length > 0) - functionCallTokenSyntaxLocks.shift()!.dispose(); - - functionSyntaxStartDetector.clearInProgressStops(); - functionSyntaxStartDetector.clearTriggeredStops(); - - functionSyntaxEndDetector.clearInProgressStops(); - functionSyntaxEndDetector.clearTriggeredStops(); - } - } - - if (!inFunctionEvaluationMode && functionsEnabled && functionsGrammar != null && - functionSyntaxStartDetector.hasTriggeredStops && functionsEvaluationState != null - ) { - inFunctionEvaluationMode = true; - functionCallTokenSyntaxLocks.push(queuedTokenRelease.createTextIndexLock(0)); - - stopGenerationDetector.clearTriggeredStops(); - stopGenerationDetector.clearInProgressStops(); - customStopGenerationTriggersDetector.clearTriggeredStops(); - customStopGenerationTriggersDetector.clearInProgressStops(); - - pendingTokens.push(...streamRegulator.popFreeChunkTokens()); - - const triggeredStops = functionSyntaxStartDetector.getTriggeredStops(); - const 
partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk(model.tokenizer); - - const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger( - triggeredStops, - partiallyFreeTokens, - model.tokenizer - ); - pendingTokens.push(...queuedTokensBeforeStopTrigger); - - const [firstRemainingGenerationAfterStop] = triggeredStops - .map((stopTrigger) => stopTrigger.remainingGenerations) - .filter((remainingGenerations) => remainingGenerations.length > 0) - .flat(1); - - const remainingTextAfterStop = - (firstRemainingGenerationAfterStop == null || firstRemainingGenerationAfterStop.length === 0) - ? "" - : typeof firstRemainingGenerationAfterStop === "string" - ? firstRemainingGenerationAfterStop - : model.detokenize(firstRemainingGenerationAfterStop); - - functionCallTokens.push(...model.tokenize(this._chatWrapper.settings.functions.call.prefix, false, "trimLeadingSpace")); - - for (const functionCallToken of functionCallTokens) - context._acceptTokenOnGrammarEvaluationState(functionsEvaluationState, functionCallToken); - - // these tokens have to be verified that they match the function calling syntax grammar before they can be accepted, - // or the context state should be modified to not include the incompatible tokens - const remainingTextTokens = model.tokenize(remainingTextAfterStop, false, "trimLeadingSpace"); - let unfitTokens: Token[] = []; - - for (let i = 0; i < remainingTextTokens.length; i++) { - const remainingToken = remainingTextTokens[i]; - const canBeNextToken = context._canBeNextTokenForGrammarEvaluationState( - functionsEvaluationState, - remainingToken - ); - - if (!canBeNextToken) { - unfitTokens = remainingTextTokens.slice(i); - break; - } - - context._acceptTokenOnGrammarEvaluationState(functionsEvaluationState, remainingToken); - functionCallTokens.push(remainingToken); - } - - if (unfitTokens.length > 0) { - const unfitTokensText = model.detokenize(unfitTokens); // the current token text must end with it - const currentTokenText = queuedTokenRelease.text; - let replacementTokens: Token[]; - - if (!currentTokenText.endsWith(unfitTokensText)) { - console.warn(getConsoleLogPrefix() + "The current token text does not end with the unfit function call syntax tokens text"); - replacementTokens = remainingTextTokens.slice(0, -unfitTokens.length); - } else { - const newCurrentTokensText = currentTokenText.slice(0, -unfitTokensText.length); - replacementTokens = model.tokenize(newCurrentTokensText, false, "trimLeadingSpace"); - } - - if (replacementTokens.length > 0) { - replacementToken = replacementTokens[0]; - queuedTokenRelease.modifyTokensAndText(replacementTokens, model.detokenize([replacementToken])); - } - } - } else if (inFunctionEvaluationMode) { - functionCallTokens.push(...tokens); - functionCallTokenSyntaxLocks.push(queuedTokenRelease.createTextIndexLock(0)); - functionSyntaxEndDetector.recordGeneration({text, tokens, queuedTokenRelease}); - } - - if (inFunctionEvaluationMode && functionSyntaxEndDetector.hasTriggeredStops && functionsGrammar != null) { - const functionCallText = model.detokenize(functionCallTokens); - const functionCall = functionsGrammar.parseFunctionCall(functionCallText); - - let modelResponse = model.detokenize(res); - let contextWindowModelResponse = model.detokenize(contextWindowsRes); - - if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) { - modelResponse = modelResponse.trimEnd(); - contextWindowModelResponse = contextWindowModelResponse.trimEnd(); - } - - return { - response: modelResponse, - lastEvaluation: { - 
contextWindow: setLastModelTextResponseInChatHistory( - lastContextWindowHistory, - contextWindowLastModelResponse + contextWindowModelResponse - ), - cleanHistory: setLastModelTextResponseInChatHistory( - resolvedHistory, - lastModelResponse + modelResponse - ), - contextShiftMetadata: lastHistoryCompressionMetadata - }, - - // prevent infinite TS type instantiation - functionCall: functionCall satisfies LlamaChatResponseFunctionCall> as any, - - metadata: { - stopReason: "functionCall" - } - }; - } - - if (!inFunctionEvaluationMode) { - stopGenerationDetector.recordGeneration({text, tokens, queuedTokenRelease}); - customStopGenerationTriggersDetector.recordGeneration({text, tokens, queuedTokenRelease}); - } - - pendingTokens.push(...streamRegulator.popFreeChunkTokens()); - - removeFoundStartIgnoreTextsFromPendingTokens(); - - if (stopGenerationDetector.hasTriggeredStops || customStopGenerationTriggersDetector.hasTriggeredStops || - model.isEogToken(token) - ) { - stopGenerationDetector.clearInProgressStops(); - customStopGenerationTriggersDetector.clearInProgressStops(); - pendingTokens.push(...streamRegulator.popFreeChunkTokens()); - - const triggeredStops = stopGenerationDetector.hasTriggeredStops - ? stopGenerationDetector.getTriggeredStops() - : customStopGenerationTriggersDetector.getTriggeredStops(); - - const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk(model.tokenizer); - - const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger( - triggeredStops, - partiallyFreeTokens, - model.tokenizer - ); - pendingTokens.push(...queuedTokensBeforeStopTrigger); - - const [firstRemainingGenerationAfterStop] = triggeredStops - .map((stopTrigger) => stopTrigger.remainingGenerations) - .filter((remainingGenerations) => remainingGenerations.length > 0) - .flat(1); - - removeFoundStartIgnoreTextsFromPendingTokens(); - - if (pendingTokens.length > 0) - onToken?.(pendingTokens.slice()); - - res.push(...pendingTokens); - contextWindowsRes.push(...pendingTokens); - pendingTokens.length = 0; - - let modelResponse = model.detokenize(res); - let contextWindowModelResponse = model.detokenize(contextWindowsRes); - - if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) { - modelResponse = modelResponse.trimEnd(); - contextWindowModelResponse = contextWindowModelResponse.trimEnd(); - } - - const lastEvaluation = { - contextWindow: setLastModelTextResponseInChatHistory( - lastContextWindowHistory, - contextWindowLastModelResponse + contextWindowModelResponse - ), - cleanHistory: setLastModelTextResponseInChatHistory( - resolvedHistory, - lastModelResponse + modelResponse - ), - contextShiftMetadata: lastHistoryCompressionMetadata - }; - const isEogToken = model.isEogToken(token); - - if (isEogToken || stopGenerationDetector.hasTriggeredStops) { - return { - response: modelResponse, - lastEvaluation, - metadata: { - remainingGenerationAfterStop: firstRemainingGenerationAfterStop, - stopReason: isEogToken - ? 
"eogToken" - : "stopGenerationTrigger" - } - }; - } - - return { - response: modelResponse, - lastEvaluation, - metadata: { - remainingGenerationAfterStop: firstRemainingGenerationAfterStop, - stopReason: "customStopTrigger", - customStopTrigger: triggeredStops[0].stopTrigger - } - }; - } - - const maxTokensTriggered = maxTokens != null && maxTokens > 0 && generatedTokens >= maxTokens; - - if (res.length === 0) { - ignoreStartTextDetector.clearInProgressStops(); - ignoreStartTextDetector.clearTriggeredStops(); + const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger(); + if (maxTokensTriggerRes != null) + return maxTokensTriggerRes; - ignoreStartTextDetector.recordGeneration({ - text: model.detokenize(pendingTokens), - tokens: pendingTokens - }); - } - - if (pendingTokens.length > 0 && (maxTokensTriggered || !ignoreStartTextDetector.hasInProgressStops)) { - removeFoundStartIgnoreTextsFromPendingTokens(); - - if (pendingTokens.length > 0) { - onToken?.(pendingTokens.slice()); - res.push(...pendingTokens); - contextWindowsRes.push(...pendingTokens); - pendingTokens.length = 0; - } - } - - if (maxTokensTriggered) { - let modelResponse = model.detokenize(res); - let contextWindowModelResponse = model.detokenize(contextWindowsRes); - - if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) { - modelResponse = modelResponse.trimEnd(); - contextWindowModelResponse = contextWindowModelResponse.trimEnd(); - } - - return { - response: modelResponse, - lastEvaluation: { - contextWindow: setLastModelTextResponseInChatHistory( - lastContextWindowHistory, - contextWindowLastModelResponse + contextWindowModelResponse - ), - cleanHistory: setLastModelTextResponseInChatHistory( - resolvedHistory, - lastModelResponse + modelResponse - ), - contextShiftMetadata: lastHistoryCompressionMetadata - }, - metadata: { - stopReason: "maxTokens" - } - }; - } - - if (this._sequence.nextTokenIndex >= context.contextSize - 1) { - shouldContextShift = true; + if (generateResponseState.updateShouldContextShift()) break; - } - if (signal?.aborted && stopOnAbortSignal) { - if (res.length === 0) - throw signal.reason; - - let modelResponse = model.detokenize(res); - let contextWindowModelResponse = model.detokenize(contextWindowsRes); - - if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) { - modelResponse = modelResponse.trimEnd(); - contextWindowModelResponse = contextWindowModelResponse.trimEnd(); - } - - return { - response: modelResponse, - lastEvaluation: { - contextWindow: setLastModelTextResponseInChatHistory( - lastContextWindowHistory, - contextWindowLastModelResponse + contextWindowModelResponse - ), - cleanHistory: setLastModelTextResponseInChatHistory( - resolvedHistory, - lastModelResponse + modelResponse - ), - contextShiftMetadata: lastHistoryCompressionMetadata - }, - metadata: { - stopReason: "abort" - } - }; - } - currentIteration = await evaluationIterator.next(replacementToken); + const abortRes = generateResponseState.handleAbortTrigger(); + if (abortRes != null) + return abortRes; } - } finally { - await evaluationIterator.return(); - } - isFirstEvaluation = false; + generateResponseState.isFirstEvaluation = false; - if (shouldContextShift) - continue; + if (generateResponseState.shouldContextShift) + continue; - break; - } + break; + } - throw new Error("The context size is too small to generate a response"); + throw new Error("The context size is too small to generate a response"); + } finally { + generateResponseState.dispose(); + } } } @@ -1303,3 +745,886 @@ async 
function getContextWindow({ disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [] }; } + +class GenerateResponseState { + private readonly llamaChat: LlamaChat; + private readonly chatWrapper: ChatWrapper; + + private readonly history: ChatHistoryItem[]; + private readonly onToken: LLamaChatGenerateResponseOptions["onToken"]; + private readonly signal: LLamaChatGenerateResponseOptions["signal"]; + private readonly stopOnAbortSignal: LLamaChatGenerateResponseOptions["stopOnAbortSignal"]; + private readonly maxTokens: LLamaChatGenerateResponseOptions["maxTokens"]; + private readonly temperature: LLamaChatGenerateResponseOptions["temperature"]; + private readonly minP: LLamaChatGenerateResponseOptions["minP"]; + private readonly topK: LLamaChatGenerateResponseOptions["topK"]; + private readonly topP: LLamaChatGenerateResponseOptions["topP"]; + private readonly grammar: LLamaChatGenerateResponseOptions["grammar"]; + private readonly trimWhitespaceSuffix: LLamaChatGenerateResponseOptions["trimWhitespaceSuffix"]; + private readonly tokenBias: LLamaChatGenerateResponseOptions["tokenBias"]; + private readonly evaluationPriority: LLamaChatGenerateResponseOptions["evaluationPriority"]; + private readonly functions: LLamaChatGenerateResponseOptions["functions"]; + private readonly documentFunctionParams: LLamaChatGenerateResponseOptions["documentFunctionParams"]; + private readonly contextShift: LLamaChatGenerateResponseOptions["contextShift"]; + private readonly customStopTriggers: LLamaChatGenerateResponseOptions["customStopTriggers"]; + private readonly lastEvaluationContextWindowHistory: Exclude["lastEvaluationContextWindow"], undefined>["history"]; + private readonly minimumOverlapPercentageToPreventContextShift: Exclude["lastEvaluationContextWindow"], undefined>["minimumOverlapPercentageToPreventContextShift"], undefined>; + + public readonly functionsEnabled: boolean; + private readonly repeatPenaltyEnabled: boolean; + private readonly resolvedContextShift: Required; + private readonly resolvedRepeatPenalty: LLamaContextualRepeatPenalty & { + lastTokens: number + }; + private readonly lastModelResponse: string; + private readonly grammarEvaluationState: LlamaGrammarEvaluationState | undefined; + private functionsGrammar: FunctionCallGrammar> | undefined; + private functionsEvaluationState: LlamaGrammarEvaluationState | undefined; + + private readonly streamRegulator = new TokenStreamRegulator(); + private readonly stopGenerationDetector = new StopGenerationDetector(); + private readonly customStopGenerationTriggersDetector = new StopGenerationDetector(); + private readonly functionSyntaxStartDetector = new StopGenerationDetector(); + private readonly functionSyntaxEndDetector = new StopGenerationDetector(); + private readonly disengageInitiallyEngagedFunctionMode = new StopGenerationDetector(); + private readonly ignoreStartTextDetector = new StopGenerationDetector(); + private readonly locksToReleaseOnValidGeneration: QueuedTokenReleaseLock[] = []; + private readonly functionCallTokenSyntaxLocks: QueuedTokenReleaseLock[] = []; + + public resolvedHistory: ChatHistoryItem[]; + + public readonly res: Token[] = []; + public readonly pendingTokens: Token[] = []; + public ignoredStartTextTokens: Token[] = []; + public readonly functionCallTokens: Token[] = []; + + public generatedTokens = 0; + public isFirstEvaluation = true; + public inFunctionEvaluationMode = false; + public initiallyEngagedFunctionMode = false; + public lastContextWindowHistory: ChatHistoryItem[]; 
+ public lastHistoryCompressionMetadata: object | null | undefined; + + // context shift loop + public shouldContextShift = false; + public queuedChunkTokens: Token[] = []; + + public contextWindowHistory: ChatHistoryItem[] = []; + public stopGenerationTriggers: LlamaText[] = []; + public contextWindowTokens: Token[] = []; + public newResolvedHistory: ChatHistoryItem[] = []; + public newHistoryCompressionMetadata: object | null | undefined = undefined; + public ignoreStartText: LlamaText[] = []; + public functionCallInitiallyEngaged: boolean = false; + public disengageInitiallyEngagedFunctionCall: LlamaText[] = []; + + public tokens: Token[] = []; + public contextWindowLastModelResponse: string = ""; + public contextWindowsRes: Token[] = []; + + // token evaluation loop + public evaluationIterator?: AsyncGenerator; + public currentIteration?: IteratorResult; + public currentIterationReplacementToken?: Token; + public currentToken?: Token; + public currentTokens: Token[] = []; + public currentText: string = ""; + public currentQueuedTokenRelease?: QueuedTokenRelease; + + public constructor( + llamaChat: LlamaChat, + chatWrapper: ChatWrapper, + history: ChatHistoryItem[], + { + onToken, + signal, + stopOnAbortSignal = false, + maxTokens, + temperature, + minP, + topK, + topP, + grammar, + trimWhitespaceSuffix = false, + repeatPenalty = {}, + tokenBias, + evaluationPriority = 5, + functions, + documentFunctionParams, + contextShift = defaultContextShiftOptions, + customStopTriggers, + lastEvaluationContextWindow: { + history: lastEvaluationContextWindowHistory, + minimumOverlapPercentageToPreventContextShift = 0.5 + } = {} + }: LLamaChatGenerateResponseOptions = {} + ) { + this.llamaChat = llamaChat; + this.chatWrapper = chatWrapper; + + this.history = history; + this.onToken = onToken; + this.signal = signal; + this.stopOnAbortSignal = stopOnAbortSignal; + this.maxTokens = maxTokens; + this.temperature = temperature; + this.minP = minP; + this.topK = topK; + this.topP = topP; + this.grammar = grammar; + this.trimWhitespaceSuffix = trimWhitespaceSuffix; + this.tokenBias = tokenBias; + this.evaluationPriority = evaluationPriority; + this.functions = functions; + this.documentFunctionParams = documentFunctionParams; + this.contextShift = contextShift; + this.customStopTriggers = customStopTriggers; + this.lastEvaluationContextWindowHistory = lastEvaluationContextWindowHistory; + this.minimumOverlapPercentageToPreventContextShift = minimumOverlapPercentageToPreventContextShift; + + this.functionsEnabled = (this.functions != null && Object.keys(this.functions).length > 0); + + if (this.grammar != null && this.functionsEnabled) + throw new Error("Using both grammar and functions is not supported yet"); + + if (this.signal?.aborted) + throw this.signal.reason; + + if (this.llamaChat.disposed) + throw new DisposedError(); + + this.resolvedHistory = this.llamaChat.sequence.isLoadedToMemory + ? this.history.slice() + : this.history.map(removeRawFromHistoryItem); + this.resolvedContextShift = { + ...defaultContextShiftOptions, + ...removeNullFields(this.contextShift) + }; + this.resolvedRepeatPenalty = repeatPenalty === false + ? {lastTokens: 0} + : { + ...(repeatPenalty ?? {}), + lastTokens: repeatPenalty?.lastTokens ?? defaultRepeatPenaltyLastTokens + }; + this.lastModelResponse = getLastTextModelResponseFromChatHistory(this.resolvedHistory); + this.repeatPenaltyEnabled = this.resolvedRepeatPenalty.lastTokens > 0; + this.grammarEvaluationState = this.grammar != null + ? 
new LlamaGrammarEvaluationState({grammar: this.grammar}) + : undefined; + this.functionsGrammar = this.functionsEnabled + ? new FunctionCallGrammar(this.llamaChat.model._llama, this.functions as NonNullable, this.chatWrapper, false) + : undefined; + this.functionsEvaluationState = (this.functionsEnabled && this.functionsGrammar != null) + ? new LlamaGrammarEvaluationState({ + grammar: this.functionsGrammar + }) + : undefined; + + this.lastContextWindowHistory = this.resolvedHistory; + this.lastHistoryCompressionMetadata = this.resolvedContextShift; + + if (this.customStopTriggers != null) + StopGenerationDetector.resolveStopTriggers(this.customStopTriggers, this.llamaChat.model.tokenizer) + .map((stopTrigger) => this.customStopGenerationTriggersDetector.addStopTrigger(stopTrigger)); + + if (this.grammar != null) + StopGenerationDetector.resolveStopTriggers(this.grammar.stopGenerationTriggers, this.llamaChat.model.tokenizer) + .map((stopTrigger) => this.stopGenerationDetector.addStopTrigger(stopTrigger)); + + if (this.functions != null && Object.keys(this.functions).length > 0) + this.functionSyntaxStartDetector.addStopTrigger([this.chatWrapper.settings.functions.call.prefix]); + + this.getPenaltyTokens = this.getPenaltyTokens.bind(this); + } + + public dispose() { + + } + + public [Symbol.dispose]() { + this.dispose(); + } + + public ensureLastHistoryItemIsModel() { + if (this.resolvedHistory.length === 0 || this.resolvedHistory[this.resolvedHistory.length - 1].type !== "model") + this.resolvedHistory.push({ + type: "model", + response: [] + }); + } + + public ensureLastHistoryItemIsUser() { + if (this.resolvedHistory.length === 0 || this.resolvedHistory[this.resolvedHistory.length - 1].type !== "user") + this.resolvedHistory.push({ + type: "user", + text: "" + }); + } + + public ensureNotAborted() { + if (this.signal?.aborted && (!this.stopOnAbortSignal || this.res.length === 0)) + throw this.signal.reason; + + if (this.llamaChat.disposed) + throw new DisposedError(); + } + + public getPenaltyTokens() { + if (this.llamaChat.disposed) + throw new DisposedError(); + + let punishTokens = this.res.slice(-this.resolvedRepeatPenalty.lastTokens); + + if (this.resolvedRepeatPenalty.punishTokensFilter != null) + punishTokens = this.resolvedRepeatPenalty.punishTokensFilter(punishTokens); + + if (this.resolvedRepeatPenalty.penalizeNewLine == null || !this.resolvedRepeatPenalty.penalizeNewLine) { + const nlToken = this.llamaChat.model.tokens.nl; + + if (nlToken != null) + punishTokens = punishTokens.filter(token => token !== nlToken); + } + + return punishTokens; + } + + public getResolvedHistoryWithCurrentModelResponse() { + if (this.res.length === 0) + return this.resolvedHistory; + + let modelResponse = this.llamaChat.model.detokenize(this.res); + + if (this.grammar?.trimWhitespaceSuffix || this.trimWhitespaceSuffix) + modelResponse = modelResponse.trimEnd(); + + if (modelResponse === "") + return this.resolvedHistory; + + return setLastModelTextResponseInChatHistory( + this.resolvedHistory, + this.lastModelResponse + modelResponse + ); + } + + public removeFoundStartIgnoreTextsFromPendingTokens() { + if (this.res.length === 0 && this.pendingTokens.length > 0) { + this.ignoreStartTextDetector.clearInProgressStops(); + this.ignoreStartTextDetector.clearTriggeredStops(); + + let mostExhaustiveTriggeredStops: ReturnType | null = null; + + for (let i = 0; i < this.pendingTokens.length; i++) { + this.ignoreStartTextDetector.recordGeneration({ + text: 
this.llamaChat.model.detokenize([this.pendingTokens[i]]), + tokens: [this.pendingTokens[i]], + startNewChecks: i === 0 + }); + + if (this.ignoreStartTextDetector.hasTriggeredStops) { + mostExhaustiveTriggeredStops = this.ignoreStartTextDetector.getTriggeredStops(); + this.ignoreStartTextDetector.clearTriggeredStops(); + } else if (!this.ignoreStartTextDetector.hasInProgressStops) + break; + } + + if (mostExhaustiveTriggeredStops != null) { + const [mostExhaustiveTriggeredStop] = mostExhaustiveTriggeredStops; + + if (mostExhaustiveTriggeredStop != null) { + this.ignoredStartTextTokens = mostExhaustiveTriggeredStop.stopTrigger + .map((stopTrigger) => { + if (typeof stopTrigger === "string") + return this.llamaChat.model.tokenize(stopTrigger, false, "trimLeadingSpace"); + else + return [stopTrigger]; + }) + .flat(1); + + const newPendingTokens = mostExhaustiveTriggeredStop.remainingGenerations + .map((generation) => { + if (typeof generation === "string") + return this.llamaChat.model.tokenize(generation, false, "trimLeadingSpace"); + else + return generation; + }) + .flat(1); + this.pendingTokens.length = 0; + this.pendingTokens.push(...newPendingTokens); + } + } + } + } + + public startTokenLoop() { + this.ensureNotAborted(); + this.shouldContextShift = false; + this.queuedChunkTokens = this.streamRegulator.getAllQueuedChunkTokens(); + } + + public async loadContextWindow() { + const { + history: contextWindowHistory, + stopGenerationTriggers, + tokens: contextWindowTokens, + newResolvedHistory, + newHistoryCompressionMetadata, + ignoreStartText, + functionCallInitiallyEngaged, + disengageInitiallyEngagedFunctionCall + } = await getContextWindow({ + resolvedHistory: this.getResolvedHistoryWithCurrentModelResponse(), + resolvedContextShift: this.resolvedContextShift, + lastHistoryCompressionMetadata: this.lastHistoryCompressionMetadata, + pendingTokensCount: this.ignoredStartTextTokens.length + this.pendingTokens.length + this.queuedChunkTokens.length, + isFirstEvaluation: this.isFirstEvaluation, + chatWrapper: this.chatWrapper, + lastEvaluationContextWindowHistory: this.lastEvaluationContextWindowHistory, + minimumOverlapPercentageToPreventContextShift: this.minimumOverlapPercentageToPreventContextShift, + sequence: this.llamaChat.sequence, + minFreeContextTokens: 1, + functions: this.functionsEnabled ? 
this.functions : undefined, + documentFunctionParams: this.documentFunctionParams + }); + + this.contextWindowHistory = contextWindowHistory; + this.stopGenerationTriggers = stopGenerationTriggers; + this.contextWindowTokens = contextWindowTokens; + this.newResolvedHistory = newResolvedHistory; + this.newHistoryCompressionMetadata = newHistoryCompressionMetadata; + this.ignoreStartText = ignoreStartText; + this.functionCallInitiallyEngaged = functionCallInitiallyEngaged; + this.disengageInitiallyEngagedFunctionCall = disengageInitiallyEngagedFunctionCall; + + this.ensureNotAborted(); + + this.tokens = [...this.contextWindowTokens, ...this.ignoredStartTextTokens, ...this.pendingTokens, ...this.queuedChunkTokens]; + this.resolvedHistory = this.newResolvedHistory; + this.lastHistoryCompressionMetadata = this.newHistoryCompressionMetadata; + this.lastContextWindowHistory = this.contextWindowHistory; + this.contextWindowLastModelResponse = getLastTextModelResponseFromChatHistory(this.contextWindowHistory); + this.contextWindowsRes = []; + } + + public addIgnoreStartTextTriggersFromChatWrapper() { + StopGenerationDetector.resolveStopTriggers(this.ignoreStartText, this.llamaChat.model.tokenizer) + .map((stopTrigger) => this.ignoreStartTextDetector.addStopTrigger(stopTrigger)); + } + + public addFunctionSyntaxEndTriggersFromFunctionsGrammar() { + if (this.functionsGrammar != null) + StopGenerationDetector.resolveStopTriggers(this.functionsGrammar.stopGenerationTriggers, this.llamaChat.model.tokenizer) + .map((stopTrigger) => this.functionSyntaxEndDetector.addStopTrigger(stopTrigger)); + } + + public addStopGenerationTriggersFromChatWrapper() { + StopGenerationDetector.resolveStopTriggers(this.stopGenerationTriggers, this.llamaChat.model.tokenizer) + .map((stopTrigger) => this.stopGenerationDetector.addStopTrigger(stopTrigger)); + } + + public initFunctions() { + this.initiallyEngagedFunctionMode = this.functionCallInitiallyEngaged; + StopGenerationDetector.resolveStopTriggers(this.disengageInitiallyEngagedFunctionCall, this.llamaChat.model.tokenizer) + .map((stopTrigger) => this.disengageInitiallyEngagedFunctionMode.addStopTrigger(stopTrigger)); + + if (this.initiallyEngagedFunctionMode) { + this.inFunctionEvaluationMode = true; + this.functionsGrammar = new FunctionCallGrammar( + this.llamaChat.model._llama, + this.functions as NonNullable, + this.chatWrapper, + true + ); + this.functionsEvaluationState = new LlamaGrammarEvaluationState({ + grammar: this.functionsGrammar + }); + } + } + + public async alignCurrentSequenceStateWithCurrentTokens() { + let {firstDifferentIndex} = this.llamaChat.sequence.compareContextTokens(this.tokens); + + // we need to decode at least one token to generate a response + if (firstDifferentIndex === this.tokens.length && firstDifferentIndex > 0) + firstDifferentIndex -= 1; + + this.tokens.splice(0, firstDifferentIndex); + + if (firstDifferentIndex < this.llamaChat.sequence.nextTokenIndex) { + await this.llamaChat.sequence.eraseContextTokenRanges([{ + start: firstDifferentIndex, + end: this.llamaChat.sequence.nextTokenIndex + }]); + this.ensureNotAborted(); + } + } + + public async createNewEvaluationIterator() { + if (this.evaluationIterator != null) + await this.evaluationIterator.return(); + + this.currentIterationReplacementToken = undefined; + this.evaluationIterator = this.llamaChat.sequence.evaluate(this.tokens, removeNullFields({ + temperature: this.temperature, + minP: this.minP, + topK: this.topK, + topP: this.topP, + grammarEvaluationState: () => { + if 
(this.inFunctionEvaluationMode) + return this.functionsEvaluationState; + + return this.grammarEvaluationState; + }, + repeatPenalty: !this.repeatPenaltyEnabled ? undefined : { + punishTokens: this.getPenaltyTokens, + penalty: this.resolvedRepeatPenalty.penalty, + frequencyPenalty: this.resolvedRepeatPenalty.frequencyPenalty, + presencePenalty: this.resolvedRepeatPenalty.presencePenalty + }, + tokenBias: this.tokenBias, + evaluationPriority: this.evaluationPriority, + yieldEogToken: true + })); + } + + public async iterateEvaluation() { + this.currentIteration = await this.evaluationIterator?.next(this.currentIterationReplacementToken); + this.currentIterationReplacementToken = undefined; + + this.ensureNotAborted(); + this.generatedTokens++; + + if (this.currentIteration != null && this.currentIteration?.done !== true) { + this.currentToken = this.currentIteration.value; + this.currentTokens = [this.currentToken]; + this.currentText = this.llamaChat.model.detokenize(this.currentTokens); + this.currentQueuedTokenRelease = this.streamRegulator.addChunk({ + tokens: this.currentTokens, + text: this.currentText + }); + + return true; + } + + return false; + } + + public waitOnPartialCharactersOrWhiteSpaceTokens() { + if (this.currentText === UNKNOWN_UNICODE_CHAR || ( + (this.grammar?.trimWhitespaceSuffix || this.trimWhitespaceSuffix) && this.currentText?.trim() === "" + )) { + if (this.currentQueuedTokenRelease != null) + this.locksToReleaseOnValidGeneration.push(this.currentQueuedTokenRelease.createTextIndexLock(0)); + } else { + while (this.locksToReleaseOnValidGeneration.length > 0) + this.locksToReleaseOnValidGeneration.shift()!.dispose(); + } + } + + public trackGenerationForDisengageInitiallyEngagedFunctionMode() { + if (this.initiallyEngagedFunctionMode) + this.disengageInitiallyEngagedFunctionMode.recordGeneration({ + text: this.currentText, + tokens: this.currentTokens, + startNewChecks: this.generatedTokens === 1 + }); + } + + public trackFunctionSyntaxStart() { + this.functionSyntaxStartDetector.recordGeneration({ + text: this.currentText, + tokens: this.currentTokens, + queuedTokenRelease: this.currentQueuedTokenRelease + }); + } + + public handleInitiallyEngagedFunctionModeFunctionDetection() { + if (this.initiallyEngagedFunctionMode && this.disengageInitiallyEngagedFunctionMode.hasTriggeredStops) { + this.initiallyEngagedFunctionMode = false; + + let shouldStopFunctionEvaluationMode = !this.functionSyntaxStartDetector.hasTriggeredStops; + + if (!shouldStopFunctionEvaluationMode && this.functionsEnabled && this.functionsGrammar != null) { + const functionCallText = this.llamaChat.model.detokenize([...this.functionCallTokens, ...this.currentTokens]); + + try { + const functionName = this.functionsGrammar.parseFunctionNameFromPartialCall(functionCallText, { + enableInternalBuiltinFunctions: true, + initialFunctionCallEngaged: true + }); + + const internalBuiltinFunctions = + this.chatWrapper.getInternalBuiltinFunctions({initialFunctionCallEngaged: true}); + if (internalBuiltinFunctions[functionName] != null) { + shouldStopFunctionEvaluationMode = true; + } + } catch (err) { + if (!(err instanceof LlamaFunctionCallValidationError)) + throw err; + } + } + + if (shouldStopFunctionEvaluationMode) { + this.inFunctionEvaluationMode = false; + this.functionsGrammar = new FunctionCallGrammar( + this.llamaChat.model._llama, + this.functions as NonNullable, + this.chatWrapper, + false + ); + this.functionsEvaluationState = new LlamaGrammarEvaluationState({ + grammar: this.functionsGrammar + 
}); + + this.functionCallTokens.length = 0; + + while (this.functionCallTokenSyntaxLocks.length > 0) + this.functionCallTokenSyntaxLocks.shift()!.dispose(); + + this.functionSyntaxStartDetector.clearInProgressStops(); + this.functionSyntaxStartDetector.clearTriggeredStops(); + + this.functionSyntaxEndDetector.clearInProgressStops(); + this.functionSyntaxEndDetector.clearTriggeredStops(); + } + } + } + + public handleFunctionSyntax() { + if (this.currentQueuedTokenRelease != null && !this.inFunctionEvaluationMode && this.functionsEnabled && + this.functionsGrammar != null && this.functionSyntaxStartDetector.hasTriggeredStops && this.functionsEvaluationState != null + ) { + this.inFunctionEvaluationMode = true; + this.functionCallTokenSyntaxLocks.push(this.currentQueuedTokenRelease.createTextIndexLock(0)); + + this.stopGenerationDetector.clearTriggeredStops(); + this.stopGenerationDetector.clearInProgressStops(); + this.customStopGenerationTriggersDetector.clearTriggeredStops(); + this.customStopGenerationTriggersDetector.clearInProgressStops(); + + this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens()); + + const triggeredStops = this.functionSyntaxStartDetector.getTriggeredStops(); + const partiallyFreeTokens = this.streamRegulator.getPartiallyFreeChunk(this.llamaChat.model.tokenizer); + + const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger( + triggeredStops, + partiallyFreeTokens, + this.llamaChat.model.tokenizer + ); + this.pendingTokens.push(...queuedTokensBeforeStopTrigger); + + const [firstRemainingGenerationAfterStop] = triggeredStops + .map((stopTrigger) => stopTrigger.remainingGenerations) + .filter((remainingGenerations) => remainingGenerations.length > 0) + .flat(1); + + const remainingTextAfterStop = + (firstRemainingGenerationAfterStop == null || firstRemainingGenerationAfterStop.length === 0) + ? "" + : typeof firstRemainingGenerationAfterStop === "string" + ? 
firstRemainingGenerationAfterStop + : this.llamaChat.model.detokenize(firstRemainingGenerationAfterStop); + + this.functionCallTokens.push(...this.llamaChat.model.tokenize(this.chatWrapper.settings.functions.call.prefix, false, "trimLeadingSpace")); + + for (const functionCallToken of this.functionCallTokens) + this.llamaChat.context._acceptTokenOnGrammarEvaluationState(this.functionsEvaluationState, functionCallToken); + + // these tokens have to be verified that they match the function calling syntax grammar before they can be accepted, + // or the context state should be modified to not include the incompatible tokens + const remainingTextTokens = this.llamaChat.model.tokenize(remainingTextAfterStop, false, "trimLeadingSpace"); + let unfitTokens: Token[] = []; + + for (let i = 0; i < remainingTextTokens.length; i++) { + const remainingToken = remainingTextTokens[i]; + const canBeNextToken = this.llamaChat.context._canBeNextTokenForGrammarEvaluationState( + this.functionsEvaluationState, + remainingToken + ); + + if (!canBeNextToken) { + unfitTokens = remainingTextTokens.slice(i); + break; + } + + this.llamaChat.context._acceptTokenOnGrammarEvaluationState(this.functionsEvaluationState, remainingToken); + this.functionCallTokens.push(remainingToken); + } + + if (unfitTokens.length > 0) { + const unfitTokensText = this.llamaChat.model.detokenize(unfitTokens); // the current token text must end with it + const currentTokenText = this.currentQueuedTokenRelease.text; + let replacementTokens: Token[]; + + if (!currentTokenText.endsWith(unfitTokensText)) { + console.warn(getConsoleLogPrefix() + "The current token text does not end with the unfit function call syntax tokens text"); + replacementTokens = remainingTextTokens.slice(0, -unfitTokens.length); + } else { + const newCurrentTokensText = currentTokenText.slice(0, -unfitTokensText.length); + replacementTokens = this.llamaChat.model.tokenize(newCurrentTokensText, false, "trimLeadingSpace"); + } + + if (replacementTokens.length > 0) { + this.currentIterationReplacementToken = replacementTokens[0]; + this.currentQueuedTokenRelease.modifyTokensAndText( + replacementTokens, + this.llamaChat.model.detokenize([this.currentIterationReplacementToken]) + ); + } + } + } else if (this.inFunctionEvaluationMode) { + this.functionCallTokens.push(...this.currentTokens); + + if (this.currentQueuedTokenRelease != null) + this.functionCallTokenSyntaxLocks.push(this.currentQueuedTokenRelease.createTextIndexLock(0)); + + this.functionSyntaxEndDetector.recordGeneration({ + text: this.currentText, + tokens: this.currentTokens, + queuedTokenRelease: this.currentQueuedTokenRelease + }); + } + } + + public detectFunctionEndSyntax(): LlamaChatResponse | undefined { + if (this.inFunctionEvaluationMode && this.functionSyntaxEndDetector.hasTriggeredStops && this.functionsGrammar != null) { + const functionCallText = this.llamaChat.model.detokenize(this.functionCallTokens); + const functionCall = this.functionsGrammar.parseFunctionCall(functionCallText); + + let modelResponse = this.llamaChat.model.detokenize(this.res); + let contextWindowModelResponse = this.llamaChat.model.detokenize(this.contextWindowsRes); + + if (this.grammar?.trimWhitespaceSuffix || this.trimWhitespaceSuffix) { + modelResponse = modelResponse.trimEnd(); + contextWindowModelResponse = contextWindowModelResponse.trimEnd(); + } + + return { + response: modelResponse, + lastEvaluation: { + contextWindow: setLastModelTextResponseInChatHistory( + this.lastContextWindowHistory, + 
this.contextWindowLastModelResponse + contextWindowModelResponse + ), + cleanHistory: setLastModelTextResponseInChatHistory( + this.resolvedHistory, + this.lastModelResponse + modelResponse + ), + contextShiftMetadata: this.lastHistoryCompressionMetadata + }, + + // prevent infinite TS type instantiation + functionCall: functionCall satisfies LlamaChatResponseFunctionCall> as any, + + metadata: { + stopReason: "functionCall" + } + }; + } + + return undefined; + } + + public recordStopGenerationEvaluation() { + if (!this.inFunctionEvaluationMode) { + this.stopGenerationDetector.recordGeneration({ + text: this.currentText, + tokens: this.currentTokens, + queuedTokenRelease: this.currentQueuedTokenRelease + }); + this.customStopGenerationTriggersDetector.recordGeneration({ + text: this.currentText, + tokens: this.currentTokens, + queuedTokenRelease: this.currentQueuedTokenRelease + }); + } + } + + public popStreamRegulatorFreeTokens() { + this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens()); + } + + public handleStopGenerationTrigger(): LlamaChatResponse | undefined { + if (this.stopGenerationDetector.hasTriggeredStops || this.customStopGenerationTriggersDetector.hasTriggeredStops || + this.llamaChat.model.isEogToken(this.currentToken) + ) { + this.stopGenerationDetector.clearInProgressStops(); + this.customStopGenerationTriggersDetector.clearInProgressStops(); + this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens()); + + const triggeredStops = this.stopGenerationDetector.hasTriggeredStops + ? this.stopGenerationDetector.getTriggeredStops() + : this.customStopGenerationTriggersDetector.getTriggeredStops(); + + const partiallyFreeTokens = this.streamRegulator.getPartiallyFreeChunk(this.llamaChat.model.tokenizer); + + const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger( + triggeredStops, + partiallyFreeTokens, + this.llamaChat.model.tokenizer + ); + this.pendingTokens.push(...queuedTokensBeforeStopTrigger); + + const [firstRemainingGenerationAfterStop] = triggeredStops + .map((stopTrigger) => stopTrigger.remainingGenerations) + .filter((remainingGenerations) => remainingGenerations.length > 0) + .flat(1); + + this.removeFoundStartIgnoreTextsFromPendingTokens(); + + if (this.pendingTokens.length > 0) + this.onToken?.(this.pendingTokens.slice()); + + this.res.push(...this.pendingTokens); + this.contextWindowsRes.push(...this.pendingTokens); + this.pendingTokens.length = 0; + + let modelResponse = this.llamaChat.model.detokenize(this.res); + let contextWindowModelResponse = this.llamaChat.model.detokenize(this.contextWindowsRes); + + if (this.grammar?.trimWhitespaceSuffix || this.trimWhitespaceSuffix) { + modelResponse = modelResponse.trimEnd(); + contextWindowModelResponse = contextWindowModelResponse.trimEnd(); + } + + const lastEvaluation = { + contextWindow: setLastModelTextResponseInChatHistory( + this.lastContextWindowHistory, + this.contextWindowLastModelResponse + contextWindowModelResponse + ), + cleanHistory: setLastModelTextResponseInChatHistory( + this.resolvedHistory, + this.lastModelResponse + modelResponse + ), + contextShiftMetadata: this.lastHistoryCompressionMetadata + }; + const isEogToken = this.llamaChat.model.isEogToken(this.currentToken); + + if (isEogToken || this.stopGenerationDetector.hasTriggeredStops) { + return { + response: modelResponse, + lastEvaluation, + metadata: { + remainingGenerationAfterStop: firstRemainingGenerationAfterStop, + stopReason: isEogToken + ? 
"eogToken" + : "stopGenerationTrigger" + } + }; + } + + return { + response: modelResponse, + lastEvaluation, + metadata: { + remainingGenerationAfterStop: firstRemainingGenerationAfterStop, + stopReason: "customStopTrigger", + customStopTrigger: triggeredStops[0].stopTrigger + } + }; + } + + return undefined; + } + + public spliceIgnoreStartTextDetectedTokens() { + if (this.res.length === 0) { + this.ignoreStartTextDetector.clearInProgressStops(); + this.ignoreStartTextDetector.clearTriggeredStops(); + + this.ignoreStartTextDetector.recordGeneration({ + text: this.llamaChat.model.detokenize(this.pendingTokens), + tokens: this.pendingTokens + }); + } + } + + public isMaxTokensTriggered() { + return this.maxTokens != null && this.maxTokens > 0 && this.generatedTokens >= this.maxTokens; + } + + public moveFreePendingTokensToRes() { + if (this.pendingTokens.length > 0 && (this.isMaxTokensTriggered() || !this.ignoreStartTextDetector.hasInProgressStops)) { + this.removeFoundStartIgnoreTextsFromPendingTokens(); + + if (this.pendingTokens.length > 0) { + this.onToken?.(this.pendingTokens.slice()); + this.res.push(...this.pendingTokens); + this.contextWindowsRes.push(...this.pendingTokens); + this.pendingTokens.length = 0; + } + } + } + + public handleMaxTokensTrigger(): LlamaChatResponse | undefined { + if (this.isMaxTokensTriggered()) { + let modelResponse = this.llamaChat.model.detokenize(this.res); + let contextWindowModelResponse = this.llamaChat.model.detokenize(this.contextWindowsRes); + + if (this.grammar?.trimWhitespaceSuffix || this.trimWhitespaceSuffix) { + modelResponse = modelResponse.trimEnd(); + contextWindowModelResponse = contextWindowModelResponse.trimEnd(); + } + + return { + response: modelResponse, + lastEvaluation: { + contextWindow: setLastModelTextResponseInChatHistory( + this.lastContextWindowHistory, + this.contextWindowLastModelResponse + contextWindowModelResponse + ), + cleanHistory: setLastModelTextResponseInChatHistory( + this.resolvedHistory, + this.lastModelResponse + modelResponse + ), + contextShiftMetadata: this.lastHistoryCompressionMetadata + }, + metadata: { + stopReason: "maxTokens" + } + }; + } + + return undefined; + } + + public updateShouldContextShift() { + this.shouldContextShift = this.llamaChat.sequence.nextTokenIndex >= this.llamaChat.context.contextSize - 1; + return this.shouldContextShift; + } + + public handleAbortTrigger(): LlamaChatResponse | undefined { + if (this.signal?.aborted && this.stopOnAbortSignal) { + if (this.res.length === 0) + throw this.signal.reason; + + let modelResponse = this.llamaChat.model.detokenize(this.res); + let contextWindowModelResponse = this.llamaChat.model.detokenize(this.contextWindowsRes); + + if (this.grammar?.trimWhitespaceSuffix || this.trimWhitespaceSuffix) { + modelResponse = modelResponse.trimEnd(); + contextWindowModelResponse = contextWindowModelResponse.trimEnd(); + } + + return { + response: modelResponse, + lastEvaluation: { + contextWindow: setLastModelTextResponseInChatHistory( + this.lastContextWindowHistory, + this.contextWindowLastModelResponse + contextWindowModelResponse + ), + cleanHistory: setLastModelTextResponseInChatHistory( + this.resolvedHistory, + this.lastModelResponse + modelResponse + ), + contextShiftMetadata: this.lastHistoryCompressionMetadata + }, + metadata: { + stopReason: "abort" + } + }; + } + + return undefined; + } +} diff --git a/src/evaluator/LlamaModel.ts b/src/evaluator/LlamaModel.ts index 4aea5131..5157158d 100644 --- a/src/evaluator/LlamaModel.ts +++ 
b/src/evaluator/LlamaModel.ts @@ -362,14 +362,17 @@ export class LlamaModel { } /** Check whether the given token is a special token (a control-type token) */ - public isSpecialToken(token: Token): boolean { + public isSpecialToken(token: Token | undefined): boolean { + if (token == null) + return false; + const tokenType = this.getTokenType(token); return tokenType === GgufMetadataTokenizerTokenType.control; } /** Check whether the given token is an EOG (End Of Generation) token, like EOS or EOT. */ - public isEogToken(token: Token): boolean { + public isEogToken(token: Token | undefined): boolean { if (token == null) return false; From 86e86ac7992e634379ef5585cf60c3263e5aa152 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 24 May 2024 19:26:16 +0300 Subject: [PATCH 02/39] feat: preload prompt and complete a preloaded prompt --- src/evaluator/LlamaChat/LlamaChat.ts | 592 ++++++++++++++---- .../LlamaChatSession/LlamaChatSession.ts | 173 ++++- src/evaluator/LlamaContext/LlamaContext.ts | 2 - src/index.ts | 9 +- .../modelDependent/llama3/chatSession.test.ts | 44 ++ 5 files changed, 705 insertions(+), 115 deletions(-) diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index 4d03b91b..7d0af385 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -1,7 +1,9 @@ -import {DisposeAggregator, DisposedError, EventRelay} from "lifecycle-utils"; +import {DisposeAggregator, DisposedError, EventRelay, withLock} from "lifecycle-utils"; import {ChatWrapper} from "../../ChatWrapper.js"; import {LlamaContextSequence} from "../LlamaContext/LlamaContext.js"; -import {ChatHistoryItem, ChatModelFunctions, ChatModelResponse, LLamaContextualRepeatPenalty, Token, Tokenizer} from "../../types.js"; +import { + ChatHistoryItem, ChatModelFunctions, ChatModelResponse, ChatUserMessage, LLamaContextualRepeatPenalty, Token, Tokenizer +} from "../../types.js"; import {GbnfJsonSchemaToType} from "../../utils/gbnfJson/types.js"; import {LlamaGrammar} from "../LlamaGrammar.js"; import {removeNullFields} from "../../utils/removeNullFields.js"; @@ -90,7 +92,8 @@ export type LLamaChatGenerateResponseOptions = { + /** + * Complete the given user prompt without adding it or the completion to the returned context window. + */ + initialUserPrompt?: string, + + /** + * When a completion already started being generated and then the signal is aborted, + * the generation will stop and the completion will be returned as is instead of throwing an error. + * + * Defaults to `false`. 
+ */ + stopOnAbortSignal?: boolean, + + onToken?: LLamaChatGenerateResponseOptions["onToken"], + signal?: LLamaChatGenerateResponseOptions["signal"], + maxTokens?: LLamaChatGenerateResponseOptions["maxTokens"], + temperature?: LLamaChatGenerateResponseOptions["temperature"], + minP?: LLamaChatGenerateResponseOptions["minP"], + topK?: LLamaChatGenerateResponseOptions["topK"], + topP?: LLamaChatGenerateResponseOptions["topP"], + trimWhitespaceSuffix?: LLamaChatGenerateResponseOptions["trimWhitespaceSuffix"], + repeatPenalty?: LLamaChatGenerateResponseOptions["repeatPenalty"], + tokenBias?: LLamaChatGenerateResponseOptions["tokenBias"], + evaluationPriority?: LLamaChatGenerateResponseOptions["evaluationPriority"], + contextShift?: LLamaChatGenerateResponseOptions["contextShift"], + customStopTriggers?: LLamaChatGenerateResponseOptions["customStopTriggers"], + lastEvaluationContextWindow?: LLamaChatGenerateResponseOptions["lastEvaluationContextWindow"], + + grammar?: LlamaGrammar, + + /** + * Functions are not used by the model here, + * but are used for keeping the instructions given to the model about the functions in the current context state, + * to avoid context shifts. + * + * It's best to provide the same functions that were used for the previous prompt here. + */ + functions?: Functions | ChatModelFunctions, + + /** + * Functions are not used by the model here, + * but are used for keeping the instructions given to the model about the functions in the current context state, + * to avoid context shifts. + * + * It's best to provide the same value that was used for the previous prompt here. + */ + documentFunctionParams?: boolean +}; + export type LLamaChatContextShiftOptions = { /** * The number of tokens to delete from the context window to make space for new ones. 
@@ -175,12 +230,15 @@ const defaultContextShiftOptions: Required = { lastEvaluationMetadata: null }; const defaultRepeatPenaltyLastTokens = 64; +const defaultTrimWhitespaceSuffix = false; +const defaultEvaluationPriority: EvaluationPriority = 5; export class LlamaChat { /** @internal */ private readonly _chatWrapper: ChatWrapper; /** @internal */ private readonly _disposeAggregator = new DisposeAggregator(); /** @internal */ private readonly _autoDisposeSequence: boolean; + /** @internal */ private readonly _chatLock = {}; /** @internal */ private _sequence: LlamaContextSequence | null; public readonly onDispose = new EventRelay(); @@ -264,13 +322,7 @@ export class LlamaChat { history: ChatHistoryItem[], options: LLamaChatGenerateResponseOptions = {} ): Promise> { - return this._generateResponse(history, options); - } - - /** @internal */ - private async _generateResponse( - history: ChatHistoryItem[], - { + const { onToken, signal, stopOnAbortSignal = false, @@ -280,10 +332,10 @@ export class LlamaChat { topK, topP, grammar, - trimWhitespaceSuffix = false, + trimWhitespaceSuffix = defaultTrimWhitespaceSuffix, repeatPenalty = {}, tokenBias, - evaluationPriority = 5, + evaluationPriority = defaultEvaluationPriority, functions, documentFunctionParams, contextShift = defaultContextShiftOptions, @@ -292,8 +344,8 @@ export class LlamaChat { history: lastEvaluationContextWindowHistory, minimumOverlapPercentageToPreventContextShift = 0.5 } = {} - }: LLamaChatGenerateResponseOptions = {} - ): Promise> { + } = options; + const generateResponseState = new GenerateResponseState( this, this._chatWrapper, @@ -307,7 +359,7 @@ export class LlamaChat { minP, topK, topP, - grammar: grammar as never, + grammar: grammar as undefined, // this is a workaround to allow passing both `functions` and `grammar` trimWhitespaceSuffix, repeatPenalty, tokenBias, @@ -323,77 +375,263 @@ export class LlamaChat { } ); - try { - generateResponseState.ensureLastHistoryItemIsModel(); + if (generateResponseState.grammar != null && generateResponseState.functionsEnabled) + throw new Error("Using both grammar and functions is not supported yet"); + + return await withLock(this._chatLock, "evaluate", signal, async (): Promise> => { + try { + generateResponseState.ensureLastHistoryItemIsModel(); - // eslint-disable-next-line no-constant-condition - while (true) { - generateResponseState.startTokenLoop(); - await generateResponseState.loadContextWindow(); + // eslint-disable-next-line no-constant-condition + while (true) { + generateResponseState.startTokenLoop(); + await generateResponseState.loadContextWindow( + generateResponseState.getResolvedHistoryWithCurrentModelResponse(), + false + ); - if (generateResponseState.generatedTokens === 0) { - generateResponseState.addIgnoreStartTextTriggersFromChatWrapper(); - generateResponseState.addFunctionSyntaxEndTriggersFromFunctionsGrammar(); + if (generateResponseState.generatedTokens === 0) { + generateResponseState.addIgnoreStartTextTriggersFromChatWrapper(); + generateResponseState.addFunctionSyntaxEndTriggersFromFunctionsGrammar(); - if (generateResponseState.functionsEnabled) { - generateResponseState.initFunctions(); + if (generateResponseState.functionsEnabled) { + generateResponseState.initFunctions(); + } } - } - generateResponseState.addStopGenerationTriggersFromChatWrapper(); - await generateResponseState.alignCurrentSequenceStateWithCurrentTokens(); + generateResponseState.addStopGenerationTriggersFromChatWrapper(); + await 
generateResponseState.alignCurrentSequenceStateWithCurrentTokens(); + + await generateResponseState.createNewEvaluationIterator(); + while (await generateResponseState.iterateEvaluation()) { + generateResponseState.waitOnPartialCharactersOrWhiteSpaceTokens(); + + generateResponseState.trackGenerationForDisengageInitiallyEngagedFunctionMode(); + generateResponseState.trackFunctionSyntaxStart(); - await generateResponseState.createNewEvaluationIterator(); - while (await generateResponseState.iterateEvaluation()) { - generateResponseState.waitOnPartialCharactersOrWhiteSpaceTokens(); + generateResponseState.handleInitiallyEngagedFunctionModeFunctionDetection(); + generateResponseState.handleFunctionSyntax(); - generateResponseState.trackGenerationForDisengageInitiallyEngagedFunctionMode(); - generateResponseState.trackFunctionSyntaxStart(); + const functionEndSyntaxRes = generateResponseState.detectFunctionEndSyntax(); + if (functionEndSyntaxRes != null) + return functionEndSyntaxRes; - generateResponseState.handleInitiallyEngagedFunctionModeFunctionDetection(); - generateResponseState.handleFunctionSyntax(); + generateResponseState.recordStopGenerationEvaluation(); - const functionEndSyntaxRes = generateResponseState.detectFunctionEndSyntax(); - if (functionEndSyntaxRes != null) - return functionEndSyntaxRes; + generateResponseState.popStreamRegulatorFreeTokens(); + generateResponseState.removeFoundStartIgnoreTextsFromPendingTokens(); - generateResponseState.recordStopGenerationEvaluation(); + const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger(); + if (stopGenerationTriggerRes != null) + return stopGenerationTriggerRes; - generateResponseState.popStreamRegulatorFreeTokens(); - generateResponseState.removeFoundStartIgnoreTextsFromPendingTokens(); + generateResponseState.spliceIgnoreStartTextDetectedTokens(); - const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger(); - if (stopGenerationTriggerRes != null) - return stopGenerationTriggerRes; + generateResponseState.moveFreePendingTokensToRes(); - generateResponseState.spliceIgnoreStartTextDetectedTokens(); + const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger(); + if (maxTokensTriggerRes != null) + return maxTokensTriggerRes; - generateResponseState.moveFreePendingTokensToRes(); + if (generateResponseState.updateShouldContextShift()) + break; - const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger(); - if (maxTokensTriggerRes != null) - return maxTokensTriggerRes; + const abortRes = generateResponseState.handleAbortTrigger(); + if (abortRes != null) + return abortRes; + } + + generateResponseState.isFirstEvaluation = false; - if (generateResponseState.updateShouldContextShift()) - break; + if (generateResponseState.shouldContextShift) + continue; - const abortRes = generateResponseState.handleAbortTrigger(); - if (abortRes != null) - return abortRes; + break; } - generateResponseState.isFirstEvaluation = false; + throw new Error("The context size is too small to generate a response"); + } finally { + generateResponseState.dispose(); + } + }); + } - if (generateResponseState.shouldContextShift) - continue; + public async loadChatAndCompleteUserMessage( + history: ChatHistoryItem[], + options: LLamaChatLoadAndCompleteUserMessageOptions = {} + ): Promise { + const { + initialUserPrompt = "", + stopOnAbortSignal = false, + onToken, + signal, + maxTokens = Math.min(256, Math.ceil(this.context.contextSize / 2)), + temperature, + minP, + topK, + topP, + 
grammar, + trimWhitespaceSuffix = defaultTrimWhitespaceSuffix, + repeatPenalty = {}, + tokenBias, + evaluationPriority = defaultEvaluationPriority, + functions, + documentFunctionParams, + contextShift = defaultContextShiftOptions, + customStopTriggers, + lastEvaluationContextWindow: { + history: lastEvaluationContextWindowHistory, + minimumOverlapPercentageToPreventContextShift = 0.8 + } = {} + } = options; - break; + const generateResponseState = new GenerateResponseState( + this, + this._chatWrapper, + history, + { + onToken, + signal, + stopOnAbortSignal, + maxTokens, + temperature, + minP, + topK, + topP, + grammar: grammar as undefined, // this is a workaround to allow passing both `functions` and `grammar` + trimWhitespaceSuffix, + repeatPenalty, + tokenBias, + evaluationPriority, + functions, + documentFunctionParams, + contextShift, + customStopTriggers, + lastEvaluationContextWindow: { + history: lastEvaluationContextWindowHistory, + minimumOverlapPercentageToPreventContextShift + } } + ); - throw new Error("The context size is too small to generate a response"); - } finally { - generateResponseState.dispose(); - } + return await withLock(this._chatLock, "evaluate", signal, async (): Promise => { + try { + generateResponseState.ensureLastHistoryItemIsUser(); + const lastResolvedHistoryItem = generateResponseState.resolvedHistory[generateResponseState.resolvedHistory.length - 1]; + const initialUserMessage = lastResolvedHistoryItem?.type === "user" + ? lastResolvedHistoryItem.text + : ""; + + // eslint-disable-next-line no-constant-condition + while (true) { + generateResponseState.startTokenLoop(); + const {userTextSuffix} = await generateResponseState.loadContextWindow( + setLastUserTextInChatHistory( + generateResponseState.resolvedHistory, + initialUserMessage + initialUserPrompt + this.model.detokenize(generateResponseState.res) + ), + true + ); + generateResponseState.inFunctionEvaluationMode = false; + + generateResponseState.addStopGenerationTriggersFromChatWrapper(); + + if (userTextSuffix != null && userTextSuffix.values.length > 0) + generateResponseState.stopGenerationDetector.addStopTrigger( + StopGenerationDetector.resolveLlamaTextTrigger(userTextSuffix, this.model.tokenizer) + ); + + await generateResponseState.alignCurrentSequenceStateWithCurrentTokens(); + + if (generateResponseState.maxTokens === 0) { + await generateResponseState.evaluateWithoutGeneratingNewTokens(); + + return { + completion: "", + lastEvaluation: { + contextWindow: setLastUserTextInChatHistory( + generateResponseState.contextWindowHistory, + initialUserMessage + ), + contextShiftMetadata: generateResponseState.lastHistoryCompressionMetadata + }, + metadata: { + stopReason: "maxTokens" + } + }; + } + + await generateResponseState.createNewEvaluationIterator(); + while (await generateResponseState.iterateEvaluation()) { + generateResponseState.waitOnPartialCharactersOrWhiteSpaceTokens(); + + generateResponseState.recordStopGenerationEvaluation(); + + generateResponseState.popStreamRegulatorFreeTokens(); + + const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger(); + if (stopGenerationTriggerRes != null) + return { + completion: stopGenerationTriggerRes.response, + lastEvaluation: { + contextWindow: setLastUserTextInChatHistory( + generateResponseState.contextWindowHistory, + initialUserMessage + ), + contextShiftMetadata: stopGenerationTriggerRes.lastEvaluation.contextShiftMetadata + }, + metadata: stopGenerationTriggerRes.metadata.stopReason === "customStopTrigger" + 
? stopGenerationTriggerRes.metadata + : stopGenerationTriggerRes.metadata + }; + + generateResponseState.moveFreePendingTokensToRes(false); + + const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger(); + if (maxTokensTriggerRes != null) + return { + completion: maxTokensTriggerRes.response, + lastEvaluation: { + contextWindow: setLastUserTextInChatHistory( + generateResponseState.contextWindowHistory, + initialUserMessage + ), + contextShiftMetadata: maxTokensTriggerRes.lastEvaluation.contextShiftMetadata + }, + metadata: maxTokensTriggerRes.metadata + }; + + if (generateResponseState.updateShouldContextShift()) + break; + + const abortRes = generateResponseState.handleAbortTrigger(); + if (abortRes != null) + return { + completion: abortRes.response, + lastEvaluation: { + contextWindow: setLastUserTextInChatHistory( + generateResponseState.contextWindowHistory, + initialUserMessage + ), + contextShiftMetadata: abortRes.lastEvaluation.contextShiftMetadata + }, + metadata: abortRes.metadata + }; + } + + generateResponseState.isFirstEvaluation = false; + + if (generateResponseState.shouldContextShift) + continue; + + break; + } + + throw new Error("The context size is too small to generate a completion"); + } finally { + generateResponseState.dispose(); + } + }); } } @@ -429,6 +667,26 @@ export type LlamaChatResponseFunctionCall< raw: string }; +export type LlamaChatLoadAndCompleteUserResponse = { + completion: string, + lastEvaluation: { + /** + * The completion and initial user prompt are not added to this context window result, + * but are loaded to the current context sequence state as tokens + */ + contextWindow: ChatHistoryItem[], + contextShiftMetadata: any + }, + metadata: { + remainingGenerationAfterStop?: string | Token[], + stopReason: "eogToken" | "stopGenerationTrigger" | "maxTokens" | "abort" + } | { + remainingGenerationAfterStop?: string | Token[], + stopReason: "customStopTrigger", + customStopTrigger: (string | Token)[] + } +}; + function removeRawFromHistoryItem(historyItem: Item): Item { if (historyItem.type === "model") { const newHistoryItem: ChatModelResponse = {...historyItem}; @@ -559,6 +817,13 @@ function getLastTextModelResponseFromChatHistory(chatHistory: ChatHistoryItem[]) return ""; } +function getLastUserTextFromChatHistory(chatHistory: ChatHistoryItem[]) { + if (chatHistory.length === 0 || chatHistory[chatHistory.length - 1].type !== "user") + return ""; + + return (chatHistory[chatHistory.length - 1] as ChatUserMessage).text; +} + function setLastModelTextResponseInChatHistory(chatHistory: ChatHistoryItem[], textResponse: string) { const newChatHistory = chatHistory.slice(); if (newChatHistory.length === 0 || newChatHistory[newChatHistory.length - 1].type !== "model") @@ -585,22 +850,96 @@ function setLastModelTextResponseInChatHistory(chatHistory: ChatHistoryItem[], t return newChatHistory; } +function setLastUserTextInChatHistory(chatHistory: ChatHistoryItem[], textResponse: string) { + const newChatHistory = chatHistory.slice(); + if (newChatHistory.length === 0 || newChatHistory[newChatHistory.length - 1].type !== "user") + newChatHistory.push({ + type: "user", + text: "" + }); + + const lastUserItem = newChatHistory[newChatHistory.length - 1] as ChatUserMessage; + const newLastUserItem = {...lastUserItem}; + newChatHistory[newChatHistory.length - 1] = newLastUserItem; + + newLastUserItem.text = textResponse; + + return newChatHistory; +} + +function generateContextText( + endWithUserText: boolean, + chatWrapper: ChatWrapper, + 
chatHistory: ChatHistoryItem[], + options?: Parameters[1] +): ReturnType { + if (endWithUserText) + return generateContextTextThatEndsWithUserText(chatWrapper, chatHistory, options); + + return chatWrapper.generateContextText(chatHistory, options); +} + +function generateContextTextThatEndsWithUserText( + chatWrapper: ChatWrapper, chatHistory: ChatHistoryItem[], options?: Parameters[1] +): ReturnType & { + userTextSuffix?: LlamaText +} { + const lastUserText = getLastUserTextFromChatHistory(chatHistory); + const randomId = "W" + (Math.random() + .toString(36) + .slice(2)) + "W"; + const {contextText, ...rest} = chatWrapper.generateContextText( + setLastUserTextInChatHistory(chatHistory, lastUserText + randomId), + options + ); + let newContextText = contextText; + + for (let i = 0; i < newContextText.values.length; i++) { + const item = newContextText.values[i]; + if (typeof item !== "string") + continue; + + const randomTextIndex = item.indexOf(randomId); + if (randomTextIndex < 0) + continue; + + const newValue = item.slice(0, randomTextIndex); + newContextText = LlamaText([ + ...newContextText.values.slice(0, i), + newValue + ]); + return { + contextText: newContextText, + userTextSuffix: LlamaText([ + item.slice(randomTextIndex + randomId.length), + ...newContextText.values.slice(i + 1) + ]), + ...rest + }; + } + + throw new Error("The random ID was not found in the context text. " + + `There might be an issue with the chat wrapper "${chatWrapper.wrapperName}" ` + + "where not all user messages are properly added to the the result LlamaText" + ); +} + async function getContextWindow({ resolvedHistory, resolvedContextShift, lastHistoryCompressionMetadata, pendingTokensCount = 0, isFirstEvaluation, chatWrapper, lastEvaluationContextWindowHistory, minimumOverlapPercentageToPreventContextShift, - sequence, minFreeContextTokens = 1, functions, documentFunctionParams + sequence, minFreeContextTokens = 1, functions, documentFunctionParams, endWithUserText }: { resolvedHistory: ChatHistoryItem[], resolvedContextShift: Required, lastHistoryCompressionMetadata: object | null | undefined, pendingTokensCount: number, isFirstEvaluation: boolean, chatWrapper: ChatWrapper, lastEvaluationContextWindowHistory?: ChatHistoryItem[], minimumOverlapPercentageToPreventContextShift: number, sequence?: LlamaContextSequence, minFreeContextTokens?: number, functions?: ChatModelFunctions, - documentFunctionParams?: boolean + documentFunctionParams?: boolean, endWithUserText: boolean }): Promise<{ history: ChatHistoryItem[], stopGenerationTriggers: LlamaText[], tokens: Token[], newResolvedHistory: ChatHistoryItem[], newHistoryCompressionMetadata: object | null | undefined, ignoreStartText: LlamaText[], functionCallInitiallyEngaged: boolean, - disengageInitiallyEngagedFunctionCall: LlamaText[] + disengageInitiallyEngagedFunctionCall: LlamaText[], userTextSuffix?: LlamaText }> { if (sequence == null) throw new DisposedError(); @@ -617,10 +956,15 @@ async function getContextWindow({ response: [] }); - const {contextText, stopGenerationTriggers, ignoreStartText, functionCall} = chatWrapper.generateContextText(newContextWindow, { - availableFunctions: functions, - documentFunctionParams - }); + const {contextText, stopGenerationTriggers, ignoreStartText, functionCall, userTextSuffix} = generateContextText( + endWithUserText, + chatWrapper, + newContextWindow, + { + availableFunctions: functions, + documentFunctionParams + } + ); const tokens = contextText.tokenize(model.tokenizer); if (tokens.length + 
pendingTokensCount + minFreeContextTokens < context.contextSize) { const {firstDifferentIndex} = sequence.compareContextTokens(tokens); @@ -636,7 +980,8 @@ async function getContextWindow({ newHistoryCompressionMetadata: lastHistoryCompressionMetadata, ignoreStartText: ignoreStartText ?? [], functionCallInitiallyEngaged: functionCall?.initiallyEngaged ?? false, - disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [] + disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [], + userTextSuffix }; } } @@ -665,10 +1010,15 @@ async function getContextWindow({ documentFunctionParams }); - const {contextText, stopGenerationTriggers, ignoreStartText, functionCall} = chatWrapper.generateContextText(compressedHistory, { - availableFunctions: functions, - documentFunctionParams - }); + const {contextText, stopGenerationTriggers, ignoreStartText, functionCall, userTextSuffix} = generateContextText( + endWithUserText, + chatWrapper, + compressedHistory, + { + availableFunctions: functions, + documentFunctionParams + } + ); return { history: compressedHistory, @@ -678,15 +1028,21 @@ async function getContextWindow({ newHistoryCompressionMetadata: metadata, ignoreStartText: ignoreStartText ?? [], functionCallInitiallyEngaged: functionCall?.initiallyEngaged ?? false, - disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [] + disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [], + userTextSuffix }; } { - const {contextText, stopGenerationTriggers, ignoreStartText, functionCall} = chatWrapper.generateContextText(resolvedHistory, { - availableFunctions: functions, - documentFunctionParams - }); + const {contextText, stopGenerationTriggers, ignoreStartText, functionCall, userTextSuffix} = generateContextText( + endWithUserText, + chatWrapper, + resolvedHistory, + { + availableFunctions: functions, + documentFunctionParams + } + ); const tokens = contextText.tokenize(model.tokenizer); if (tokens.length + pendingTokensCount + minFreeContextTokens < context.contextSize) @@ -698,7 +1054,8 @@ async function getContextWindow({ newHistoryCompressionMetadata: lastHistoryCompressionMetadata, ignoreStartText: ignoreStartText ?? [], functionCallInitiallyEngaged: functionCall?.initiallyEngaged ?? false, - disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [] + disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [], + userTextSuffix }; } @@ -729,10 +1086,15 @@ async function getContextWindow({ documentFunctionParams }); - const {contextText, stopGenerationTriggers, ignoreStartText, functionCall} = chatWrapper.generateContextText(compressedHistory, { - availableFunctions: functions, - documentFunctionParams - }); + const {contextText, stopGenerationTriggers, ignoreStartText, functionCall, userTextSuffix} = generateContextText( + endWithUserText, + chatWrapper, + compressedHistory, + { + availableFunctions: functions, + documentFunctionParams + } + ); return { history: compressedHistory, @@ -742,7 +1104,8 @@ async function getContextWindow({ newHistoryCompressionMetadata: metadata, ignoreStartText: ignoreStartText ?? [], functionCallInitiallyEngaged: functionCall?.initiallyEngaged ?? false, - disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? [] + disengageInitiallyEngagedFunctionCall: functionCall?.disengageInitiallyEngaged ?? 
[], + userTextSuffix }; } @@ -754,12 +1117,12 @@ class GenerateResponseState["onToken"]; private readonly signal: LLamaChatGenerateResponseOptions["signal"]; private readonly stopOnAbortSignal: LLamaChatGenerateResponseOptions["stopOnAbortSignal"]; - private readonly maxTokens: LLamaChatGenerateResponseOptions["maxTokens"]; + public readonly maxTokens: LLamaChatGenerateResponseOptions["maxTokens"]; private readonly temperature: LLamaChatGenerateResponseOptions["temperature"]; private readonly minP: LLamaChatGenerateResponseOptions["minP"]; private readonly topK: LLamaChatGenerateResponseOptions["topK"]; private readonly topP: LLamaChatGenerateResponseOptions["topP"]; - private readonly grammar: LLamaChatGenerateResponseOptions["grammar"]; + public readonly grammar: LLamaChatGenerateResponseOptions["grammar"]; private readonly trimWhitespaceSuffix: LLamaChatGenerateResponseOptions["trimWhitespaceSuffix"]; private readonly tokenBias: LLamaChatGenerateResponseOptions["tokenBias"]; private readonly evaluationPriority: LLamaChatGenerateResponseOptions["evaluationPriority"]; @@ -782,7 +1145,7 @@ class GenerateResponseState 0); - if (this.grammar != null && this.functionsEnabled) - throw new Error("Using both grammar and functions is not supported yet"); - if (this.signal?.aborted) throw this.signal.reason; @@ -1061,7 +1421,7 @@ class GenerateResponseState | undefined { + public handleStopGenerationTrigger() { if (this.stopGenerationDetector.hasTriggeredStops || this.customStopGenerationTriggersDetector.hasTriggeredStops || this.llamaChat.model.isEogToken(this.currentToken) ) { @@ -1510,7 +1885,7 @@ class GenerateResponseState; } return { @@ -1521,7 +1896,7 @@ class GenerateResponseState; } return undefined; @@ -1543,9 +1918,10 @@ class GenerateResponseState 0 && this.generatedTokens >= this.maxTokens; } - public moveFreePendingTokensToRes() { + public moveFreePendingTokensToRes(removeFoundStartIgnoreTextsFromPendingTokens: boolean = true) { if (this.pendingTokens.length > 0 && (this.isMaxTokensTriggered() || !this.ignoreStartTextDetector.hasInProgressStops)) { - this.removeFoundStartIgnoreTextsFromPendingTokens(); + if (removeFoundStartIgnoreTextsFromPendingTokens) + this.removeFoundStartIgnoreTextsFromPendingTokens(); if (this.pendingTokens.length > 0) { this.onToken?.(this.pendingTokens.slice()); @@ -1556,7 +1932,7 @@ class GenerateResponseState | undefined { + public handleMaxTokensTrigger() { if (this.isMaxTokensTriggered()) { let modelResponse = this.llamaChat.model.detokenize(this.res); let contextWindowModelResponse = this.llamaChat.model.detokenize(this.contextWindowsRes); @@ -1582,7 +1958,7 @@ class GenerateResponseState; } return undefined; @@ -1593,7 +1969,7 @@ class GenerateResponseState | undefined { + public handleAbortTrigger() { if (this.signal?.aborted && this.stopOnAbortSignal) { if (this.res.length === 0) throw this.signal.reason; @@ -1622,7 +1998,7 @@ class GenerateResponseState; } return undefined; diff --git a/src/evaluator/LlamaChatSession/LlamaChatSession.ts b/src/evaluator/LlamaChatSession/LlamaChatSession.ts index 1c8a2492..315ea186 100644 --- a/src/evaluator/LlamaChatSession/LlamaChatSession.ts +++ b/src/evaluator/LlamaChatSession/LlamaChatSession.ts @@ -130,6 +130,55 @@ export type LLamaChatPromptOptions { + return await withLock(this._chatLock, "evaluation", signal, async () => { this._ensureNotDisposed(); if (this._chat == null) @@ -368,7 +418,7 @@ export class LlamaChatSession { evaluationPriority, lastEvaluationContextWindow: { history: 
newContextWindowChatHistory, - minimumOverlapPercentageToPreventContextShift: 0.01 + minimumOverlapPercentageToPreventContextShift: 0.5 } }); this._ensureNotDisposed(); @@ -443,6 +493,125 @@ export class LlamaChatSession { }); } + /** + * Preload a user prompt into the current context sequence state to make later inference of the model response begin sooner + * and feel faster. + * + * If `maxTokens` is set to a value greater than `0`, + * a completion for the given user prompt will be generated up to the given number of tokens. + * + * > **Note:** Preloading a long user prompt and completing a user prompt with a high number of `maxTokens` can incur context shifts, + * > so consider limiting the length of prompts you preload. + * > + * > Also, it's recommended to limit the number of tokens generated to a reasonable amount. + * + * Defaults to `0`. + * @param prompt - the prompt to preload + * @param [options] + */ + public async preloadPrompt( + prompt: string, + options: LLamaChatPreloadPromptOptions & { + maxTokens?: MaxTokens + } = {} + ): Promise<0 | undefined extends MaxTokens ? void : string> { + const {completion} = await this.preloadPromptWithMeta(prompt, options); + + if (options?.maxTokens == null || options?.maxTokens === 0) + return undefined as (0 | undefined extends MaxTokens ? void : string); + + return completion as (0 | undefined extends MaxTokens ? void : string); + } + + /** + * See `preloadPrompt` for more information. + * @param prompt + * @param [options] + */ + public async preloadPromptWithMeta(prompt: string, { + maxTokens = 0, + stopOnAbortSignal = false, + + functions, + documentFunctionParams, + onToken, + signal, + temperature, + minP, + topK, + topP, + grammar, + trimWhitespaceSuffix = false, + repeatPenalty, + tokenBias, + customStopTriggers, + evaluationPriority + }: LLamaChatPreloadPromptOptions = {}) { + this._ensureNotDisposed(); + + if (grammar != null && grammar._llama !== this.model._llama) + throw new Error("The LlamaGrammar used by passed to this function was created with a different Llama instance than the one used by this sequence's model. 
Make sure you use the same Llama instance for both the model and the grammar."); + + return await withLock(this._chatLock, "evaluation", signal, async () => { + this._ensureNotDisposed(); + + if (this._chat == null) + throw new DisposedError(); + + const {completion, lastEvaluation, metadata} = await this._chat.loadChatAndCompleteUserMessage(this._chatHistory, { + initialUserPrompt: prompt, + functions, + documentFunctionParams, + grammar, + onToken, + signal, + stopOnAbortSignal: true, + repeatPenalty, + minP, + topK, + topP, + tokenBias, + customStopTriggers, + maxTokens, + temperature, + trimWhitespaceSuffix, + contextShift: { + ...this._contextShift, + lastEvaluationMetadata: this._lastEvaluation?.contextShiftMetadata + }, + evaluationPriority, + lastEvaluationContextWindow: { + history: this._lastEvaluation?.contextWindow, + minimumOverlapPercentageToPreventContextShift: 0.8 + } + }); + this._ensureNotDisposed(); + + this._lastEvaluation = { + cleanHistory: this._chatHistory, + contextWindow: lastEvaluation.contextWindow, + contextShiftMetadata: lastEvaluation.contextShiftMetadata + }; + + if (!stopOnAbortSignal && metadata.stopReason === "abort" && signal?.aborted) + throw signal.reason; + + if (metadata.stopReason === "customStopTrigger") + return { + completion: completion, + stopReason: metadata.stopReason, + customStopTrigger: metadata.customStopTrigger, + remainingGenerationAfterStop: metadata.remainingGenerationAfterStop + }; + + return { + completion: completion, + stopReason: metadata.stopReason, + remainingGenerationAfterStop: metadata.remainingGenerationAfterStop + }; + }); + } + public getChatHistory() { return structuredClone(this._chatHistory); } diff --git a/src/evaluator/LlamaContext/LlamaContext.ts b/src/evaluator/LlamaContext/LlamaContext.ts index d811ff7f..8a8763af 100644 --- a/src/evaluator/LlamaContext/LlamaContext.ts +++ b/src/evaluator/LlamaContext/LlamaContext.ts @@ -861,8 +861,6 @@ export class LlamaContextSequence { strategy: contextShiftStrategy = this._contextShift.strategy } = {} }: { - grammarEvaluationState?: LlamaGrammarEvaluationState, - /** * When a lot of tokens are queued for the next batch, more than the configured `batchSize`, the tokens for each sequence will be * evaluated based on the strategy chosen for the context. 
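A rough usage sketch of the prompt preloading API added in this patch (not part of the diff itself; the model loading step and the literal prompts are placeholders, mirroring the chatSession tests further below):

    // sketch only: assumes `model` was already loaded via llama.loadModel(...)
    const context = await model.createContext({contextSize: 2048});
    const session = new LlamaChatSession({contextSequence: context.getSequence()});

    // preload the user prompt so the later .prompt() call starts responding sooner
    await session.preloadPrompt("Describe the appearance of a llama");

    // or additionally generate a short completion for the preloaded prompt
    const completion = await session.preloadPrompt(
        "Describe the appearance of a llama and explain what",
        {maxTokens: 40}
    );
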
diff --git a/src/index.ts b/src/index.ts index 9fe6c331..f2f8f25b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,12 +18,12 @@ import { import {TokenBias} from "./evaluator/TokenBias.js"; import { LlamaChatSession, type LlamaChatSessionOptions, type LlamaChatSessionContextShiftOptions, - type LLamaChatPromptOptions, type LlamaChatSessionRepeatPenalty + type LLamaChatPromptOptions, type LLamaChatPreloadPromptOptions, type LlamaChatSessionRepeatPenalty } from "./evaluator/LlamaChatSession/LlamaChatSession.js"; import {defineChatSessionFunction} from "./evaluator/LlamaChatSession/utils/defineChatSessionFunction.js"; import { - LlamaChat, type LlamaChatOptions, type LLamaChatGenerateResponseOptions, type LLamaChatContextShiftOptions, - type LlamaChatResponse, type LlamaChatResponseFunctionCall + LlamaChat, type LlamaChatOptions, type LLamaChatGenerateResponseOptions, type LLamaChatLoadAndCompleteUserMessageOptions, + type LLamaChatContextShiftOptions, type LlamaChatResponse, type LlamaChatResponseFunctionCall, type LlamaChatLoadAndCompleteUserResponse } from "./evaluator/LlamaChat/LlamaChat.js"; import { LlamaCompletion, type LlamaCompletionOptions, type LlamaCompletionGenerationOptions, type LlamaInfillGenerationOptions @@ -116,14 +116,17 @@ export { type LlamaChatSessionOptions, type LlamaChatSessionContextShiftOptions, type LLamaChatPromptOptions, + type LLamaChatPreloadPromptOptions, type LlamaChatSessionRepeatPenalty, LlamaChat, type LlamaChatOptions, type LLamaChatGenerateResponseOptions, + type LLamaChatLoadAndCompleteUserMessageOptions, type LLamaChatContextShiftOptions, type LLamaContextualRepeatPenalty, type LlamaChatResponse, type LlamaChatResponseFunctionCall, + type LlamaChatLoadAndCompleteUserResponse, LlamaCompletion, type LlamaCompletionOptions, type LlamaCompletionGenerationOptions, diff --git a/test/modelDependent/llama3/chatSession.test.ts b/test/modelDependent/llama3/chatSession.test.ts index 4e545edf..b12043fd 100644 --- a/test/modelDependent/llama3/chatSession.test.ts +++ b/test/modelDependent/llama3/chatSession.test.ts @@ -63,6 +63,50 @@ describe("llama 3", () => { expect(res.responseText.toLowerCase()).to.not.include("llama"); }); + test("preloading a prompt works", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 2048 + }); + const chatSession = new LlamaChatSession({ + contextSequence: context.getSequence() + }); + + expect(chatSession.chatWrapper).to.be.an.instanceof(Llama3ChatWrapper); + + const prompt = "Describe the appearance of a llama"; + await chatSession.preloadPrompt(prompt); + expect(model.detokenize(chatSession.sequence.contextTokens).endsWith(prompt)).to.eql(true); + }); + + test("completing a prompt works", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 2048 + }); + const chatSession = new LlamaChatSession({ + contextSequence: context.getSequence() + }); + + expect(chatSession.chatWrapper).to.be.an.instanceof(Llama3ChatWrapper); + + const prompt = "Describe the appearance of a llama and explain what"; + const completion = await chatSession.preloadPrompt(prompt, { + maxTokens: 40 + }); + 
expect(completion).to.eql(" it is."); + }); + // disabled due to getting timeout in the CI due to taking too long test.skip("context shift works correctly", {timeout: 1000 * 60 * 60 * 2}, async () => { const contextSize = 2048; From 21629bce34464512bbd7aac529ad7a31f07a564c Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 24 May 2024 19:27:04 +0300 Subject: [PATCH 03/39] chore: remove redundant setting in script --- test/utils/setupAndTestOnPaperspace.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/utils/setupAndTestOnPaperspace.sh b/test/utils/setupAndTestOnPaperspace.sh index a3417f3e..b0804777 100644 --- a/test/utils/setupAndTestOnPaperspace.sh +++ b/test/utils/setupAndTestOnPaperspace.sh @@ -178,10 +178,10 @@ while true; do node ./dist/cli/cli.js inspect gpu echo "Running tests using CUDA..." - NODE_LLAMA_CPP_GPU=cuda NODE_LLAMA_CPP_LOG_LEVEL=warn npm run --silent test + NODE_LLAMA_CPP_GPU=cuda npm run --silent test echo "Running tests using Vulkan..." - NODE_LLAMA_CPP_GPU=vulkan NODE_LLAMA_CPP_LOG_LEVEL=warn npm run --silent test + NODE_LLAMA_CPP_GPU=vulkan npm run --silent test echo "" echo "Done running tests" From c578ddb81e034fa7d34a640b285d3d02fd646ba9 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 24 May 2024 19:27:14 +0300 Subject: [PATCH 04/39] fix: bug --- .../src/App/components/ChatHistory/ChatHistory.css | 2 ++ 1 file changed, 2 insertions(+) diff --git a/templates/electron-typescript-react/src/App/components/ChatHistory/ChatHistory.css b/templates/electron-typescript-react/src/App/components/ChatHistory/ChatHistory.css index 72194661..a6cb6e0d 100644 --- a/templates/electron-typescript-react/src/App/components/ChatHistory/ChatHistory.css +++ b/templates/electron-typescript-react/src/App/components/ChatHistory/ChatHistory.css @@ -15,6 +15,7 @@ margin-inline-start: 48px; margin-inline-end: 12px; color: var(--user-message-text-color); + white-space: pre-wrap; &:not(:first-child) { margin-top: 36px; @@ -25,6 +26,7 @@ align-self: flex-start; margin-inline-end: 48px; padding-inline-start: 24px; + white-space: pre-wrap; &.active { &:after { From 2d38a7ea9526f7798c5a4d8f707dac5777eb863b Mon Sep 17 00:00:00 2001 From: Gilad S Date: Sat, 25 May 2024 05:25:13 +0300 Subject: [PATCH 05/39] feat: prompt completion engine --- .../generic/JinjaTemplateChatWrapper.ts | 2 +- .../generic/TemplateChatWrapper.ts | 2 +- src/evaluator/LlamaChat/LlamaChat.ts | 91 ++++-- .../LlamaChatSession/LlamaChatSession.ts | 188 +++++++----- .../LlamaChatSessionPromptCompletionEngine.ts | 282 ++++++++++++++++++ src/index.ts | 9 +- src/utils/LruCache.ts | 58 ++++ src/utils/getConsoleLogPrefix.ts | 1 - src/utils/wrapAbortSignal.ts | 10 + 9 files changed, 537 insertions(+), 106 deletions(-) create mode 100644 src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts create mode 100644 src/utils/LruCache.ts create mode 100644 src/utils/wrapAbortSignal.ts diff --git a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts index c8986412..da213c4c 100644 --- a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts +++ b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts @@ -102,7 +102,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { this.trimLeadingWhitespaceInResponses = trimLeadingWhitespaceInResponses; this.settings = { - ...super.settings, + ...ChatWrapper.defaultSetting, functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? 
ChatWrapper.defaultSetting.functions }; diff --git a/src/chatWrappers/generic/TemplateChatWrapper.ts b/src/chatWrappers/generic/TemplateChatWrapper.ts index 6ad6d930..112250e7 100644 --- a/src/chatWrappers/generic/TemplateChatWrapper.ts +++ b/src/chatWrappers/generic/TemplateChatWrapper.ts @@ -86,7 +86,7 @@ export class TemplateChatWrapper extends ChatWrapper { this._parsedChatHistoryTemplate = parseChatHistoryTemplate(historyTemplate); this.settings = { - ...super.settings, + ...ChatWrapper.defaultSetting, functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSetting.functions }; } diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index 7d0af385..2aa54e65 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -412,7 +412,7 @@ export class LlamaChat { generateResponseState.handleInitiallyEngagedFunctionModeFunctionDetection(); generateResponseState.handleFunctionSyntax(); - const functionEndSyntaxRes = generateResponseState.detectFunctionEndSyntax(); + const functionEndSyntaxRes = generateResponseState.detectFunctionEndSyntax("model"); if (functionEndSyntaxRes != null) return functionEndSyntaxRes; @@ -421,7 +421,7 @@ export class LlamaChat { generateResponseState.popStreamRegulatorFreeTokens(); generateResponseState.removeFoundStartIgnoreTextsFromPendingTokens(); - const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger(); + const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger("model"); if (stopGenerationTriggerRes != null) return stopGenerationTriggerRes; @@ -429,14 +429,14 @@ export class LlamaChat { generateResponseState.moveFreePendingTokensToRes(); - const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger(); + const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger("model"); if (maxTokensTriggerRes != null) return maxTokensTriggerRes; if (generateResponseState.updateShouldContextShift()) break; - const abortRes = generateResponseState.handleAbortTrigger(); + const abortRes = generateResponseState.handleAbortTrigger("model"); if (abortRes != null) return abortRes; } @@ -485,6 +485,13 @@ export class LlamaChat { } = {} } = options; + const lastEvaluationContextWindowHistoryItem = lastEvaluationContextWindowHistory == null + ? null + : lastEvaluationContextWindowHistory[lastEvaluationContextWindowHistory.length - 1]; + const lastEvaluationContextWindowUserMessage = lastEvaluationContextWindowHistoryItem?.type === "user" + ? lastEvaluationContextWindowHistoryItem.text + : ""; + const generateResponseState = new GenerateResponseState( this, this._chatWrapper, @@ -508,7 +515,12 @@ export class LlamaChat { contextShift, customStopTriggers, lastEvaluationContextWindow: { - history: lastEvaluationContextWindowHistory, + history: lastEvaluationContextWindowHistory == null + ? 
undefined + : setLastUserTextInChatHistory( + lastEvaluationContextWindowHistory, + lastEvaluationContextWindowUserMessage + initialUserPrompt + ), minimumOverlapPercentageToPreventContextShift } } @@ -550,7 +562,7 @@ export class LlamaChat { completion: "", lastEvaluation: { contextWindow: setLastUserTextInChatHistory( - generateResponseState.contextWindowHistory, + generateResponseState.lastContextWindowHistory, initialUserMessage ), contextShiftMetadata: generateResponseState.lastHistoryCompressionMetadata @@ -569,13 +581,13 @@ export class LlamaChat { generateResponseState.popStreamRegulatorFreeTokens(); - const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger(); + const stopGenerationTriggerRes = generateResponseState.handleStopGenerationTrigger("user"); if (stopGenerationTriggerRes != null) return { completion: stopGenerationTriggerRes.response, lastEvaluation: { contextWindow: setLastUserTextInChatHistory( - generateResponseState.contextWindowHistory, + generateResponseState.lastContextWindowHistory, initialUserMessage ), contextShiftMetadata: stopGenerationTriggerRes.lastEvaluation.contextShiftMetadata @@ -587,13 +599,13 @@ export class LlamaChat { generateResponseState.moveFreePendingTokensToRes(false); - const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger(); + const maxTokensTriggerRes = generateResponseState.handleMaxTokensTrigger("user"); if (maxTokensTriggerRes != null) return { completion: maxTokensTriggerRes.response, lastEvaluation: { contextWindow: setLastUserTextInChatHistory( - generateResponseState.contextWindowHistory, + generateResponseState.lastContextWindowHistory, initialUserMessage ), contextShiftMetadata: maxTokensTriggerRes.lastEvaluation.contextShiftMetadata @@ -604,13 +616,13 @@ export class LlamaChat { if (generateResponseState.updateShouldContextShift()) break; - const abortRes = generateResponseState.handleAbortTrigger(); + const abortRes = generateResponseState.handleAbortTrigger("user"); if (abortRes != null) return { completion: abortRes.response, lastEvaluation: { contextWindow: setLastUserTextInChatHistory( - generateResponseState.contextWindowHistory, + generateResponseState.lastContextWindowHistory, initialUserMessage ), contextShiftMetadata: abortRes.lastEvaluation.contextShiftMetadata @@ -850,7 +862,7 @@ function setLastModelTextResponseInChatHistory(chatHistory: ChatHistoryItem[], t return newChatHistory; } -function setLastUserTextInChatHistory(chatHistory: ChatHistoryItem[], textResponse: string) { +function setLastUserTextInChatHistory(chatHistory: ChatHistoryItem[], userText: string) { const newChatHistory = chatHistory.slice(); if (newChatHistory.length === 0 || newChatHistory[newChatHistory.length - 1].type !== "user") newChatHistory.push({ @@ -862,11 +874,18 @@ function setLastUserTextInChatHistory(chatHistory: ChatHistoryItem[], textRespon const newLastUserItem = {...lastUserItem}; newChatHistory[newChatHistory.length - 1] = newLastUserItem; - newLastUserItem.text = textResponse; + newLastUserItem.text = userText; return newChatHistory; } +function setLastTextInChatHistory(itemType: "user" | "model", chatHistory: ChatHistoryItem[], text: string) { + if (itemType === "user") + return setLastUserTextInChatHistory(chatHistory, text); + else + return setLastModelTextResponseInChatHistory(chatHistory, text); +} + function generateContextText( endWithUserText: boolean, chatWrapper: ChatWrapper, @@ -950,7 +969,13 @@ async function getContextWindow({ if (isFirstEvaluation && 
lastEvaluationContextWindowHistory != null && sequence.isLoadedToMemory) { const newContextWindow = lastEvaluationContextWindowHistory.slice(); - if (newContextWindow.length === 0 || newContextWindow[newContextWindow.length - 1].type !== "model") + if (endWithUserText) { + if (newContextWindow.length === 0 || newContextWindow[newContextWindow.length - 1].type !== "user") + newContextWindow.push({ + type: "user", + text: "" + }); + } else if (newContextWindow.length === 0 || newContextWindow[newContextWindow.length - 1].type !== "model") newContextWindow.push({ type: "model", response: [] @@ -1172,7 +1197,7 @@ class GenerateResponseState | undefined { + public detectFunctionEndSyntax(lastHistoryItemType: "user" | "model"): LlamaChatResponse | undefined { if (this.inFunctionEvaluationMode && this.functionSyntaxEndDetector.hasTriggeredStops && this.functionsGrammar != null) { const functionCallText = this.llamaChat.model.detokenize(this.functionCallTokens); const functionCall = this.functionsGrammar.parseFunctionCall(functionCallText); @@ -1777,11 +1802,13 @@ class GenerateResponseState(); public readonly onDispose = new EventRelay(); @@ -366,8 +379,10 @@ export class LlamaChatSession { if (grammar != null && grammar._llama !== this.model._llama) throw new Error("The LlamaGrammar used by passed to this function was created with a different Llama instance than the one used by this sequence's model. Make sure you use the same Llama instance for both the model and the grammar."); + this._stopAllPreloadAndPromptCompletions(); return await withLock(this._chatLock, "evaluation", signal, async () => { this._ensureNotDisposed(); + this._stopAllPreloadAndPromptCompletions(); if (this._chat == null) throw new DisposedError(); @@ -468,6 +483,7 @@ export class LlamaChatSession { this._lastEvaluation = lastEvaluation; this._chatHistory = newChatHistory; + this._chatHistoryStateRef = {}; const lastModelResponseItem = getLastModelResponseItem(newChatHistory); const responseText = lastModelResponseItem.response @@ -497,39 +513,51 @@ export class LlamaChatSession { * Preload a user prompt into the current context sequence state to make later inference of the model response begin sooner * and feel faster. * - * If `maxTokens` is set to a value greater than `0`, - * a completion for the given user prompt will be generated up to the given number of tokens. + * > **Note:** Preloading a long user prompt can incur context shifts, so consider limiting the length of prompts you preload + * @param prompt - the prompt to preload + * @param [options] + */ + public async preloadPrompt(prompt: string, options: LLamaChatPreloadPromptOptions = {}): Promise { + await this.completePromptWithMeta(prompt, { + ...options, + maxTokens: 0 + }); + } + + /** + * Preload a user prompt into the current context sequence state and generate a completion for it. * * > **Note:** Preloading a long user prompt and completing a user prompt with a high number of `maxTokens` can incur context shifts, * > so consider limiting the length of prompts you preload. * > - * > Also, it's recommended to limit the number of tokens generated to a reasonable amount. - * - * Defaults to `0`. + * > Also, it's recommended to limit the number of tokens generated to a reasonable amount by configuring `maxTokens`. * @param prompt - the prompt to preload * @param [options] */ - public async preloadPrompt( - prompt: string, - options: LLamaChatPreloadPromptOptions & { - maxTokens?: MaxTokens - } = {} - ): Promise<0 | undefined extends MaxTokens ? 
void : string> { - const {completion} = await this.preloadPromptWithMeta(prompt, options); - - if (options?.maxTokens == null || options?.maxTokens === 0) - return undefined as (0 | undefined extends MaxTokens ? void : string); - - return completion as (0 | undefined extends MaxTokens ? void : string); + public async completePrompt(prompt: string, options: LLamaChatCompletePromptOptions = {}): Promise { + const {completion} = await this.completePromptWithMeta(prompt, options); + + return completion; + } + + /** + * Create a smart completion engine that caches the prompt completions + * and reuses them when the user prompt matches the beginning of the cached prompt or completion. + * + * All completions are made and cache is used only for the current chat session state. + * You can create a single completion engine for an entire chat session. + */ + public createPromptCompletionEngine(options?: LLamaChatPromptCompletionEngineOptions) { + return LlamaChatSessionPromptCompletionEngine._create(this, options); } /** - * See `preloadPrompt` for more information. + * See `completePrompt` for more information. * @param prompt * @param [options] */ - public async preloadPromptWithMeta(prompt: string, { - maxTokens = 0, + public async completePromptWithMeta(prompt: string, { + maxTokens, stopOnAbortSignal = false, functions, @@ -546,70 +574,77 @@ export class LlamaChatSession { tokenBias, customStopTriggers, evaluationPriority - }: LLamaChatPreloadPromptOptions = {}) { + }: LLamaChatCompletePromptOptions = {}) { this._ensureNotDisposed(); if (grammar != null && grammar._llama !== this.model._llama) throw new Error("The LlamaGrammar used by passed to this function was created with a different Llama instance than the one used by this sequence's model. Make sure you use the same Llama instance for both the model and the grammar."); - return await withLock(this._chatLock, "evaluation", signal, async () => { - this._ensureNotDisposed(); + const abortController = wrapAbortSignal(signal); + this._preloadAndCompleteAbortControllers.add(abortController); - if (this._chat == null) - throw new DisposedError(); + try { + return await withLock(this._chatLock, "evaluation", abortController.signal, async () => { + this._ensureNotDisposed(); - const {completion, lastEvaluation, metadata} = await this._chat.loadChatAndCompleteUserMessage(this._chatHistory, { - initialUserPrompt: prompt, - functions, - documentFunctionParams, - grammar, - onToken, - signal, - stopOnAbortSignal: true, - repeatPenalty, - minP, - topK, - topP, - tokenBias, - customStopTriggers, - maxTokens, - temperature, - trimWhitespaceSuffix, - contextShift: { - ...this._contextShift, - lastEvaluationMetadata: this._lastEvaluation?.contextShiftMetadata - }, - evaluationPriority, - lastEvaluationContextWindow: { - history: this._lastEvaluation?.contextWindow, - minimumOverlapPercentageToPreventContextShift: 0.8 - } - }); - this._ensureNotDisposed(); + if (this._chat == null) + throw new DisposedError(); - this._lastEvaluation = { - cleanHistory: this._chatHistory, - contextWindow: lastEvaluation.contextWindow, - contextShiftMetadata: lastEvaluation.contextShiftMetadata - }; + const {completion, lastEvaluation, metadata} = await this._chat.loadChatAndCompleteUserMessage(this._chatHistory, { + initialUserPrompt: prompt, + functions, + documentFunctionParams, + grammar, + onToken, + signal: abortController.signal, + stopOnAbortSignal: true, + repeatPenalty, + minP, + topK, + topP, + tokenBias, + customStopTriggers, + maxTokens, + temperature, + 
trimWhitespaceSuffix, + contextShift: { + ...this._contextShift, + lastEvaluationMetadata: this._lastEvaluation?.contextShiftMetadata + }, + evaluationPriority, + lastEvaluationContextWindow: { + history: this._lastEvaluation?.contextWindow, + minimumOverlapPercentageToPreventContextShift: 0.8 + } + }); + this._ensureNotDisposed(); - if (!stopOnAbortSignal && metadata.stopReason === "abort" && signal?.aborted) - throw signal.reason; + this._lastEvaluation = { + cleanHistory: this._chatHistory, + contextWindow: lastEvaluation.contextWindow, + contextShiftMetadata: lastEvaluation.contextShiftMetadata + }; + + if (!stopOnAbortSignal && metadata.stopReason === "abort" && abortController.signal?.aborted) + throw abortController.signal.reason; + + if (metadata.stopReason === "customStopTrigger") + return { + completion: completion, + stopReason: metadata.stopReason, + customStopTrigger: metadata.customStopTrigger, + remainingGenerationAfterStop: metadata.remainingGenerationAfterStop + }; - if (metadata.stopReason === "customStopTrigger") return { completion: completion, stopReason: metadata.stopReason, - customStopTrigger: metadata.customStopTrigger, remainingGenerationAfterStop: metadata.remainingGenerationAfterStop }; - - return { - completion: completion, - stopReason: metadata.stopReason, - remainingGenerationAfterStop: metadata.remainingGenerationAfterStop - }; - }); + }); + } finally { + this._preloadAndCompleteAbortControllers.delete(abortController); + } } public getChatHistory() { @@ -625,9 +660,18 @@ export class LlamaChatSession { public setChatHistory(chatHistory: ChatHistoryItem[]) { this._chatHistory = structuredClone(chatHistory); + this._chatHistoryStateRef = {}; this._lastEvaluation = undefined; } + /** @internal */ + private _stopAllPreloadAndPromptCompletions() { + for (const abortController of this._preloadAndCompleteAbortControllers) + abortController.abort(); + + this._preloadAndCompleteAbortControllers.clear(); + } + /** @internal */ private _ensureNotDisposed() { if (this.disposed) diff --git a/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts b/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts new file mode 100644 index 00000000..29d61891 --- /dev/null +++ b/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts @@ -0,0 +1,282 @@ +import {DisposeAggregator, DisposedError} from "lifecycle-utils"; +import {Token} from "../../../types.js"; +import {getConsoleLogPrefix} from "../../../utils/getConsoleLogPrefix.js"; +import {LruCache} from "../../../utils/LruCache.js"; +import type {LLamaChatCompletePromptOptions, LlamaChatSession} from "../LlamaChatSession.js"; + +export type LLamaChatPromptCompletionEngineOptions = { + /** + * Max tokens to allow for preloading a prompt and generating a completion for it. + * + * Defaults to `256` or half of the context size, whichever is smaller. + */ + maxPreloadTokens?: number, + onGeneration?(prompt: string, completion: string): void, + + /** + * Max number of completions to cache. + * + * Defaults to `100`. 
+ */ + maxCachedCompletions?: number, + + temperature?: LLamaChatCompletePromptOptions["temperature"], + minP?: LLamaChatCompletePromptOptions["minP"], + topK?: LLamaChatCompletePromptOptions["topK"], + topP?: LLamaChatCompletePromptOptions["topP"], + trimWhitespaceSuffix?: LLamaChatCompletePromptOptions["trimWhitespaceSuffix"], + evaluationPriority?: LLamaChatCompletePromptOptions["evaluationPriority"], + repeatPenalty?: LLamaChatCompletePromptOptions["repeatPenalty"], + tokenBias?: LLamaChatCompletePromptOptions["tokenBias"], + customStopTriggers?: LLamaChatCompletePromptOptions["customStopTriggers"], + grammar?: LLamaChatCompletePromptOptions["grammar"], + functions?: LLamaChatCompletePromptOptions["functions"], + documentFunctionParams?: LLamaChatCompletePromptOptions["documentFunctionParams"] +}; + +const defaultMaxPreloadTokens = 256; +const defaultMaxCachedCompletions = 100; + +export class LlamaChatSessionPromptCompletionEngine { + /** @internal */ private readonly _chatSession: LlamaChatSession; + /** @internal */ private readonly _maxPreloadTokens: number; + /** @internal */ private readonly _maxCachedCompletions: number; + /** @internal */ private readonly _onGeneration?: LLamaChatPromptCompletionEngineOptions["onGeneration"]; + /** @internal */ private readonly _completionOptions: LLamaChatCompletePromptOptions; + /** @internal */ private readonly _completionCaches = new WeakMap(); + /** @internal */ private readonly _disposeAggregator = new DisposeAggregator(); + /** @internal */ private _currentCompletionAbortController = new AbortController(); + /** @internal */ private _lastPrompt?: string; + /** @internal */ private _disposed = false; + + private constructor(chatSession: LlamaChatSession, { + maxPreloadTokens = defaultMaxPreloadTokens, + onGeneration, + maxCachedCompletions = defaultMaxCachedCompletions, + ...options + }: LLamaChatPromptCompletionEngineOptions) { + this._chatSession = chatSession; + this._maxPreloadTokens = Math.max(1, maxPreloadTokens); + this._maxCachedCompletions = Math.max(1, maxCachedCompletions); + this._onGeneration = onGeneration; + this._completionOptions = options; + + this.dispose = this.dispose.bind(this); + + this._disposeAggregator.add( + this._chatSession.onDispose.createListener(this.dispose) + ); + this._disposeAggregator.add(() => { + this._disposed = true; + this._currentCompletionAbortController.abort(); + }); + } + + public dispose() { + if (this._disposed) + return; + + this._disposeAggregator.dispose(); + } + + /** + * Get completion for the prompt from the cache, + * and begin preloading this prompt into the context sequence and completing it. + * + * On completion progress, `onGeneration` (configured for this engine instance) will be called. + */ + public complete(prompt: string): string { + if (this._disposed) + throw new DisposedError(); + + const completionCache = this._getCurrentCompletionCache(); + + const completion = completionCache.getCompletion(prompt); + + if (this._lastPrompt == null || !(this._lastPrompt + (completion ?? "")).startsWith(prompt)) { + this._lastPrompt = prompt; + this._restartCompletion(completionCache); + } + + this._lastPrompt = prompt; + + return completion ?? 
""; + } + + /** @internal */ + private _getCurrentCompletionCache() { + const completionCache = this._completionCaches.get(this._chatSession._chatHistoryStateRef); + + if (completionCache != null) + return completionCache; + + const newCompletionCache = new CompletionCache(this._maxCachedCompletions); + this._completionCaches.set(this._chatSession._chatHistoryStateRef, newCompletionCache); + return newCompletionCache; + } + + /** @internal */ + private _restartCompletion(completionCache: CompletionCache) { + if (this._disposed) + return; + + this._currentCompletionAbortController.abort(); + this._currentCompletionAbortController = new AbortController(); + const prompt = this._lastPrompt; + + if (prompt == null) + return; + + const existingCompletion = completionCache.getCompletion(prompt); + const promptToComplete = prompt + (existingCompletion ?? ""); + + const currentPromptTokens = this._chatSession.model.tokenize(promptToComplete).length; + const leftTokens = Math.max(0, this._maxPreloadTokens - currentPromptTokens); + + if (leftTokens === 0) + return; + + const currentAbortController = this._currentCompletionAbortController; + const currentAbortSignal = this._currentCompletionAbortController.signal; + const currentCompletion: Token[] = []; + void this._chatSession.completePrompt(promptToComplete, { + ...this._completionOptions, + stopOnAbortSignal: false, + maxTokens: leftTokens, + signal: currentAbortSignal, + onToken: (chunk) => { + currentCompletion.push(...chunk); + const completion = (existingCompletion ?? "") + this._chatSession.model.detokenize(currentCompletion); + completionCache.putCompletion(prompt, completion); + + if (this._getCurrentCompletionCache() !== completionCache) { + currentAbortController.abort(); + return; + } + + try { + if (this._lastPrompt === prompt && this._onGeneration != null) + this._onGeneration(prompt, completion); + } catch (err) { + console.error(err); + } + } + }) + .then(() => { + if (this._lastPrompt !== prompt && this._getCurrentCompletionCache() === completionCache) + return this._restartCompletion(completionCache); + }) + .catch((err) => { + if (currentAbortSignal.aborted && err === currentAbortSignal.reason) + return; + + console.error(getConsoleLogPrefix(false, false), err); + }); + } + + /** @internal */ + public static _create(chatSession: LlamaChatSession, options: LLamaChatPromptCompletionEngineOptions = {}) { + return new LlamaChatSessionPromptCompletionEngine(chatSession, options); + } +} + +class CompletionCache { + /** @internal */ private readonly _cache: LruCache; + /** @internal */ private readonly _rootNode: InputNode = [new Map()]; + + public constructor(maxInputs: number) { + this._cache = new LruCache(maxInputs, { + onDelete: (key) => { + this._deleteInput(key); + } + }); + } + + public get maxInputs() { + return this._cache.maxSize; + } + + public getCompletion(input: string): string | null { + let node: InputNode | undefined = this._rootNode; + + for (let i = 0; i < input.length; i++) { + if (node == null) + return null; + + const [next, completion]: InputNode = node; + const char = input[i]; + + if (!next.has(char)) { + if (completion != null && completion.startsWith(input.slice(i))) { + this._cache.get(input.slice(0, i)); + return completion.slice(input.length - i); + } + } + + node = next.get(char); + } + + if (node == null) + return null; + + const [, possibleCompletion] = node; + if (possibleCompletion != null) { + this._cache.get(input); + return possibleCompletion; + } + + return null; + } + + public 
putCompletion(input: string, completion: string): string { + this._cache.set(input, null); + + let node = this._rootNode; + for (let i = 0; i < input.length; i++) { + const [next] = node; + const char = input[i]; + + if (!next.has(char)) + next.set(char, [new Map()]); + + node = next.get(char)!; + } + + const currentCompletion = node[1]; + if (currentCompletion != null && currentCompletion.startsWith(completion)) + return currentCompletion; + + node[1] = completion; + return completion; + } + + /** @internal */ + private _deleteInput(input: string) { + let lastNodeWithMultipleChildren: InputNode = this._rootNode; + let lastNodeWithMultipleChildrenDeleteChar: string = input[0]; + + let node = this._rootNode; + for (let i = 0; i < input.length; i++) { + const [next] = node; + const char = input[i]; + + if (next.size > 1) { + lastNodeWithMultipleChildren = node; + lastNodeWithMultipleChildrenDeleteChar = char; + } + + if (!next.has(char)) + return; + + node = next.get(char)!; + } + + if (lastNodeWithMultipleChildrenDeleteChar !== "") + lastNodeWithMultipleChildren[0].delete(lastNodeWithMultipleChildrenDeleteChar); + } +} + +type InputNode = [ + next: Map, + completion?: string +]; diff --git a/src/index.ts b/src/index.ts index f2f8f25b..35a8bbae 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,13 +18,16 @@ import { import {TokenBias} from "./evaluator/TokenBias.js"; import { LlamaChatSession, type LlamaChatSessionOptions, type LlamaChatSessionContextShiftOptions, - type LLamaChatPromptOptions, type LLamaChatPreloadPromptOptions, type LlamaChatSessionRepeatPenalty + type LLamaChatPromptOptions, type LLamaChatCompletePromptOptions, type LlamaChatSessionRepeatPenalty } from "./evaluator/LlamaChatSession/LlamaChatSession.js"; import {defineChatSessionFunction} from "./evaluator/LlamaChatSession/utils/defineChatSessionFunction.js"; import { LlamaChat, type LlamaChatOptions, type LLamaChatGenerateResponseOptions, type LLamaChatLoadAndCompleteUserMessageOptions, type LLamaChatContextShiftOptions, type LlamaChatResponse, type LlamaChatResponseFunctionCall, type LlamaChatLoadAndCompleteUserResponse } from "./evaluator/LlamaChat/LlamaChat.js"; +import { + LlamaChatSessionPromptCompletionEngine, type LLamaChatPromptCompletionEngineOptions +} from "./evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.js"; import { LlamaCompletion, type LlamaCompletionOptions, type LlamaCompletionGenerationOptions, type LlamaInfillGenerationOptions } from "./evaluator/LlamaCompletion.js"; @@ -116,7 +119,7 @@ export { type LlamaChatSessionOptions, type LlamaChatSessionContextShiftOptions, type LLamaChatPromptOptions, - type LLamaChatPreloadPromptOptions, + type LLamaChatCompletePromptOptions, type LlamaChatSessionRepeatPenalty, LlamaChat, type LlamaChatOptions, @@ -127,6 +130,8 @@ export { type LlamaChatResponse, type LlamaChatResponseFunctionCall, type LlamaChatLoadAndCompleteUserResponse, + LlamaChatSessionPromptCompletionEngine, + type LLamaChatPromptCompletionEngineOptions, LlamaCompletion, type LlamaCompletionOptions, type LlamaCompletionGenerationOptions, diff --git a/src/utils/LruCache.ts b/src/utils/LruCache.ts new file mode 100644 index 00000000..a77f87e0 --- /dev/null +++ b/src/utils/LruCache.ts @@ -0,0 +1,58 @@ +export class LruCache { + public readonly maxSize: number; + /** @internal */ private readonly _cache = new Map(); + /** @internal */ private readonly _onDelete?: (key: Key, value: Value) => void; + + public constructor(maxSize: number, { + onDelete + }: { + onDelete?(key: Key, 
value: Value): void + } = {}) { + this.maxSize = maxSize; + this._onDelete = onDelete; + } + + public get(key: Key) { + if (!this._cache.has(key)) + return undefined; + + // move the key to the end of the cache + const item = this._cache.get(key)!; + this._cache.delete(key); + this._cache.set(key, item); + return item; + } + + public set(key: Key, value: Value) { + if (this._cache.has(key)) + this._cache.delete(key); + else if (this._cache.size >= this.maxSize) { + const firstKey = this.firstKey; + + if (this._onDelete != null) + this._onDelete(firstKey, this._cache.get(firstKey)!); + + this._cache.delete(firstKey); + } + + this._cache.set(key, value); + return this; + } + + public get firstKey() { + return this._cache.keys() + .next().value; + } + + public clear() { + this._cache.clear(); + } + + public keys() { + return this._cache.keys(); + } + + public delete(key: Key) { + this._cache.delete(key); + } +} diff --git a/src/utils/getConsoleLogPrefix.ts b/src/utils/getConsoleLogPrefix.ts index 6ba1c6ab..fc8f7282 100644 --- a/src/utils/getConsoleLogPrefix.ts +++ b/src/utils/getConsoleLogPrefix.ts @@ -10,4 +10,3 @@ export function getConsoleLogPrefix(forcePrefix: boolean = false, padEnd: boolea return ""; } - diff --git a/src/utils/wrapAbortSignal.ts b/src/utils/wrapAbortSignal.ts new file mode 100644 index 00000000..cce2dac4 --- /dev/null +++ b/src/utils/wrapAbortSignal.ts @@ -0,0 +1,10 @@ +export function wrapAbortSignal(abortSignal?: AbortSignal) { + const controller = new AbortController(); + + if (abortSignal != null) + abortSignal.addEventListener("abort", () => { + controller.abort(abortSignal.reason); + }); + + return controller; +} From 2ea5265137b25444f5fe29593d1bd8ba4b33b22c Mon Sep 17 00:00:00 2001 From: Gilad S Date: Sat, 25 May 2024 05:28:16 +0300 Subject: [PATCH 06/39] feat: add prompt completion to the Electron example --- .../electron/rpc/llmRpc.ts | 1 + .../electron/state/llmState.ts | 99 ++++++++++++----- .../electron-typescript-react/src/App/App.tsx | 7 ++ .../src/App/components/Header/Header.tsx | 2 +- .../src/App/components/InputRow/InputRow.css | 96 ++++++++++++++--- .../src/App/components/InputRow/InputRow.tsx | 102 ++++++++++++++---- .../src/state/llmState.ts | 6 +- 7 files changed, 246 insertions(+), 67 deletions(-) diff --git a/templates/electron-typescript-react/electron/rpc/llmRpc.ts b/templates/electron-typescript-react/electron/rpc/llmRpc.ts index 4731a4ed..00177022 100644 --- a/templates/electron-typescript-react/electron/rpc/llmRpc.ts +++ b/templates/electron-typescript-react/electron/rpc/llmRpc.ts @@ -43,6 +43,7 @@ export class ElectronLlmRpc { getState() { return llmState.state; }, + setDraftPrompt: llmFunctions.chatSession.setDraftPrompt, prompt: llmFunctions.chatSession.prompt, stopActivePrompt: llmFunctions.chatSession.stopActivePrompt, resetChatHistory: llmFunctions.chatSession.resetChatHistory diff --git a/templates/electron-typescript-react/electron/state/llmState.ts b/templates/electron-typescript-react/electron/state/llmState.ts index f2472859..7d3e3586 100644 --- a/templates/electron-typescript-react/electron/state/llmState.ts +++ b/templates/electron-typescript-react/electron/state/llmState.ts @@ -1,5 +1,5 @@ import path from "node:path"; -import {getLlama, Llama, LlamaChatSession, LlamaContext, LlamaContextSequence, LlamaModel, Token} from "node-llama-cpp"; +import {getLlama, Llama, LlamaChatSession, LlamaChatSessionPromptCompletionEngine, LlamaContext, LlamaContextSequence, LlamaModel, Token} from "node-llama-cpp"; import {withLock, State} from 
"lifecycle-utils"; export const llmState = new State({ @@ -18,7 +18,11 @@ export const llmState = new State({ chatSession: { loaded: false, generatingResult: false, - simplifiedChat: [] + simplifiedChat: [], + draftPrompt: { + prompt: "", + completion: "" + } } }); @@ -45,7 +49,11 @@ export type LlmState = { chatSession: { loaded: boolean, generatingResult: boolean, - simplifiedChat: SimplifiedChatItem[] + simplifiedChat: SimplifiedChatItem[], + draftPrompt: { + prompt: string, + completion: string + } } }; @@ -60,6 +68,7 @@ let context: LlamaContext | null = null; let contextSequence: LlamaContextSequence | null = null; let chatSession: LlamaChatSession | null = null; +let chatSessionCompletionEngine: LlamaChatSessionPromptCompletionEngine | null = null; let promptAbortController: AbortController | null = null; const inProgressResponse: Token[] = []; @@ -256,6 +265,7 @@ export const llmFunctions = { try { chatSession.dispose(); chatSession = null; + chatSessionCompletionEngine = null; } catch (err) { console.error("Failed to dispose chat session", err); } @@ -267,32 +277,12 @@ export const llmFunctions = { chatSession: { loaded: false, generatingResult: false, - simplifiedChat: [] + simplifiedChat: [], + draftPrompt: llmState.state.chatSession.draftPrompt } }; - chatSession = new LlamaChatSession({ - contextSequence - }); - llmState.state = { - ...llmState.state, - chatSession: { - loaded: true, - generatingResult: false, - simplifiedChat: [] - } - }; - - chatSession.onDispose.createListener(() => { - llmState.state = { - ...llmState.state, - chatSession: { - loaded: false, - generatingResult: false, - simplifiedChat: [] - } - }; - }); + llmFunctions.chatSession.resetChatHistory(); } catch (err) { console.error("Failed to create chat session", err); llmState.state = { @@ -300,7 +290,8 @@ export const llmFunctions = { chatSession: { loaded: false, generatingResult: false, - simplifiedChat: [] + simplifiedChat: [], + draftPrompt: llmState.state.chatSession.draftPrompt } }; } @@ -359,15 +350,65 @@ export const llmFunctions = { if (contextSequence == null) return; + chatSession?.dispose(); chatSession = new LlamaChatSession({ - contextSequence + contextSequence, + autoDisposeSequence: false + }); + chatSessionCompletionEngine = chatSession.createPromptCompletionEngine({ + onGeneration(prompt, completion) { + if (llmState.state.chatSession.draftPrompt.prompt === prompt) { + llmState.state = { + ...llmState.state, + chatSession: { + ...llmState.state.chatSession, + draftPrompt: { + prompt, + completion + } + } + }; + } + } }); + llmState.state = { + ...llmState.state, + chatSession: { + loaded: true, + generatingResult: false, + simplifiedChat: [], + draftPrompt: { + prompt: llmState.state.chatSession.draftPrompt.prompt, + completion: chatSessionCompletionEngine.complete(llmState.state.chatSession.draftPrompt.prompt) + } + } + }; + + chatSession.onDispose.createListener(() => { + llmState.state = { + ...llmState.state, + chatSession: { + loaded: false, + generatingResult: false, + simplifiedChat: [], + draftPrompt: llmState.state.chatSession.draftPrompt + } + }; + }); + }, + setDraftPrompt(prompt: string) { + if (chatSessionCompletionEngine == null) + return; + llmState.state = { ...llmState.state, chatSession: { ...llmState.state.chatSession, - simplifiedChat: [] + draftPrompt: { + prompt: prompt, + completion: chatSessionCompletionEngine.complete(prompt) + } } }; } diff --git a/templates/electron-typescript-react/src/App/App.tsx b/templates/electron-typescript-react/src/App/App.tsx index 
a82d4aa5..a6cff083 100644 --- a/templates/electron-typescript-react/src/App/App.tsx +++ b/templates/electron-typescript-react/src/App/App.tsx @@ -67,6 +67,10 @@ export function App() { void electronLlmRpc.prompt(prompt); }, [generatingResult]); + const onPromptInput = useCallback((currentText: string) => { + void electronLlmRpc.setDraftPrompt(currentText); + }, []); + const error = state.llama.error ?? state.model.error ?? state.context.error ?? state.contextSequence.error; const showMessage = state.selectedModelFilePath == null || error != null || state.chatSession.simplifiedChat.length === 0; @@ -121,9 +125,12 @@ export function App() { ? stopActivePrompt : undefined } + onPromptInput={onPromptInput} sendPrompt={sendPrompt} generatingResult={generatingResult} contextSequenceLoaded={state.contextSequence.loaded} + autocompleteInputDraft={state.chatSession.draftPrompt.prompt} + autocompleteCompletion={state.chatSession.draftPrompt.completion} /> ; } diff --git a/templates/electron-typescript-react/src/App/components/Header/Header.tsx b/templates/electron-typescript-react/src/App/components/Header/Header.tsx index cc0ff213..17f59385 100644 --- a/templates/electron-typescript-react/src/App/components/Header/Header.tsx +++ b/templates/electron-typescript-react/src/App/components/Header/Header.tsx @@ -12,7 +12,7 @@ export function Header({modelName, onLoadClick, loadPercentage, onResetChatClick
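(Editorial note, not part of the patch.) The Electron wiring in this commit routes each keystroke through IPC (`setDraftPrompt`) into the new prompt completion engine and mirrors the resulting completion back into `llmState` for the input row to render. Outside Electron, the same pattern reduces to a few direct calls against the API added earlier in this series. The following is a minimal sketch under assumptions: the model path is a placeholder, error handling and disposal are omitted, and it only illustrates the `createPromptCompletionEngine` / `complete` flow introduced by this patch.

    import {getLlama, LlamaChatSession} from "node-llama-cpp";

    const llama = await getLlama();
    // placeholder path - point this at a real .gguf model file
    const model = await llama.loadModel({modelPath: "path/to/model.gguf"});
    const context = await model.createContext();
    const session = new LlamaChatSession({contextSequence: context.getSequence()});

    // one engine per chat session; completions are cached for the current chat state
    const completionEngine = session.createPromptCompletionEngine({
        onGeneration(prompt, completion) {
            // called as more of the completion is generated in the background
            console.log(`completion for "${prompt}":`, completion);
        }
    });

    // returns whatever is already cached (possibly an empty string) and starts
    // preloading and completing the draft prompt in the background
    const cachedCompletion = completionEngine.complete("Hi there! How");
    console.log(cachedCompletion);

In the Electron template this is exactly what `chatSessionCompletionEngine.complete(prompt)` does inside `setDraftPrompt`, with `onGeneration` pushing updated completions into the shared state instead of logging them.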
diff --git a/templates/electron-typescript-react/src/App/components/InputRow/InputRow.css b/templates/electron-typescript-react/src/App/components/InputRow/InputRow.css index 51491f6d..4a9bb3f5 100644 --- a/templates/electron-typescript-react/src/App/components/InputRow/InputRow.css +++ b/templates/electron-typescript-react/src/App/components/InputRow/InputRow.css @@ -13,24 +13,90 @@ z-index: 10; align-items: flex-end; - > .input { + > .inputContainer { flex: 1; - border: none; - resize: none; - box-sizing: border-box; - max-height: 160px; - height: 55px; - outline: none; - padding: 12px 24px; - background-color: transparent; - font: inherit; - align-content: center; - align-self: stretch; - color: var(--panel-text-color); + display: flex; + flex-direction: row; + overflow: hidden; + position: relative; + isolation: isolate; + max-height: 400px; + min-height: var(--min-height); + --min-height: 55px; - &::placeholder { + > .input { + flex: 1; + border: none; + resize: none; + box-sizing: border-box; + max-height: 160px; + min-height: var(--min-height); + height: 55px; + outline: none; + padding: calc((var(--min-height) - 1lh) / 2) 24px; + background-color: transparent; + font: inherit; + align-content: center; + align-self: stretch; color: var(--panel-text-color); - opacity: 0.4; + z-index: 2; + unicode-bidi: plaintext; + overflow: auto; + + &::placeholder { + color: var(--panel-text-color); + opacity: 0.4; + } + } + + > .autocomplete { + position: absolute; + inset: 0px; + z-index: 1; + display: flex; + overflow: hidden; + pointer-events: none; + user-select: none; + + > .content { + flex: 1; + flex-shrink: 0; + font: inherit; + padding: calc((var(--min-height) - 1lh) / 2) 24px; + text-align: initial; + unicode-bidi: plaintext; + overflow: hidden; + opacity: 0.36; + + &.hide { + opacity: 0; + } + + > .currentText { + opacity: 0; + display: inline; + white-space: pre-wrap; + word-break: break-word; + } + + > .completion { + display: inline; + white-space: pre-wrap; + word-break: break-word; + } + + > .pressTab { + display: inline-block; + margin: -1px 8px; + opacity: 0.8; + border: solid 1px color-mix(in srgb, currentColor, transparent 64%); + border-bottom-width: 2px; + border-radius: 8px; + padding: 0.1em 0.4em; + font-size: 0.8em; + vertical-align: top; + } + } } } diff --git a/templates/electron-typescript-react/src/App/components/InputRow/InputRow.tsx b/templates/electron-typescript-react/src/App/components/InputRow/InputRow.tsx index 44dc1db5..a8ebdb37 100644 --- a/templates/electron-typescript-react/src/App/components/InputRow/InputRow.tsx +++ b/templates/electron-typescript-react/src/App/components/InputRow/InputRow.tsx @@ -1,20 +1,47 @@ -import {useCallback, useRef, useState} from "react"; +import {useCallback, useMemo, useRef, useState} from "react"; +import classNames from "classnames"; import {AddMessageIconSVG} from "../../../icons/AddMessageIconSVG.tsx"; import {AbortIconSVG} from "../../../icons/AbortIconSVG.tsx"; import "./InputRow.css"; -export function InputRow({stopGeneration, sendPrompt, generatingResult, contextSequenceLoaded}: InputRowProps) { - const [inputEmpty, setInputEmpty] = useState(true); +export function InputRow({ + stopGeneration, sendPrompt, onPromptInput, autocompleteInputDraft, autocompleteCompletion, generatingResult, contextSequenceLoaded +}: InputRowProps) { + const [inputText, setInputText] = useState(""); const inputRef = useRef(null); + const autocompleteRef = useRef(null); + const autocompleteCurrentTextRef = useRef(null); + + const 
autocompleteText = useMemo(() => { + const fullText = (autocompleteInputDraft ?? "") + (autocompleteCompletion ?? ""); + if (fullText.startsWith(inputText)) + return fullText.slice(inputText.length); + + return ""; + }, [inputText, autocompleteInputDraft, autocompleteCompletion]); + + const setInputValue = useCallback((value: string) => { + if (inputRef.current != null) + inputRef.current.value = value; + + if (autocompleteCurrentTextRef.current != null) + autocompleteCurrentTextRef.current.innerText = value; + + setInputText(value); + }, []); const resizeInput = useCallback(() => { if (inputRef.current == null) return; - inputRef.current.style.minHeight = ""; - inputRef.current.style.minHeight = inputRef.current.scrollHeight + "px"; + inputRef.current.style.height = ""; + inputRef.current.style.height = inputRef.current.scrollHeight + "px"; + + if (autocompleteRef.current != null) { + autocompleteRef.current.scrollTop = inputRef.current.scrollTop; + } }, []); const submitPrompt = useCallback(() => { @@ -25,35 +52,65 @@ export function InputRow({stopGeneration, sendPrompt, generatingResult, contextS if (message.length === 0) return; - inputRef.current.value = ""; + setInputValue(""); resizeInput(); sendPrompt(message); - }, [generatingResult, resizeInput, sendPrompt]); + onPromptInput?.(""); + }, [setInputValue, generatingResult, resizeInput, sendPrompt, onPromptInput]); const onInput = useCallback(() => { - setInputEmpty(inputRef.current?.value.length === 0); + setInputText(inputRef.current?.value ?? ""); resizeInput(); - }, [resizeInput]); + + if (autocompleteCurrentTextRef.current != null && inputRef.current != null) + autocompleteCurrentTextRef.current.innerText = inputRef.current?.value; + + if (inputRef.current != null && onPromptInput != null) + onPromptInput(inputRef.current?.value); + }, [resizeInput, onPromptInput]); const onInputKeyDown = useCallback((event: React.KeyboardEvent) => { if (event.key === "Enter" && !event.shiftKey) { event.preventDefault(); submitPrompt(); + resizeInput(); + } else if (event.key === "Tab" && !event.shiftKey && !event.ctrlKey && !event.metaKey && !event.altKey) { + event.preventDefault(); + if (inputRef.current != null && autocompleteText !== "") { + setInputValue(inputRef.current.value + autocompleteText); + inputRef.current.scrollTop = inputRef.current.scrollHeight; + onPromptInput?.(inputRef.current.value); + } + resizeInput(); } - }, [submitPrompt]); + }, [submitPrompt, setInputValue, onPromptInput, resizeInput, autocompleteText]); return
-