From 07ebbd06486bee11f24aaa767dfead2acb40c4a6 Mon Sep 17 00:00:00 2001 From: Gagik Amaryan Date: Mon, 9 Dec 2024 15:39:29 +0100 Subject: [PATCH] feat(participant): filter message history when it goes over maxInputTokens VSCODE-653 (#894) --- src/participant/participant.ts | 2 +- src/participant/prompts/promptBase.ts | 86 +++++++--- src/participant/prompts/promptHistory.ts | 37 +++-- src/participant/sampleDocuments.ts | 8 +- .../suite/participant/participant.test.ts | 154 ++++++++++++++---- 5 files changed, 214 insertions(+), 73 deletions(-) diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 4f336f1a2..bc5393369 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -1577,7 +1577,7 @@ export default class ParticipantController { log.info('Docs chatbot created for chatId', chatId); } - const history = PromptHistory.getFilteredHistoryForDocs({ + const history = await PromptHistory.getFilteredHistoryForDocs({ connectionNames: this._getConnectionNames(), context: context, }); diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 56ed32f67..68a3ae4e7 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -5,6 +5,7 @@ import type { ParticipantPromptProperties, } from '../../telemetry/telemetryService'; import { PromptHistory } from './promptHistory'; +import { getCopilotModel } from '../model'; import type { ParticipantCommandType } from '../participantTypes'; export interface PromptArgsBase { @@ -94,34 +95,76 @@ export function isContentEmpty( return true; } -export abstract class PromptBase { - protected abstract getAssistantPrompt(args: TArgs): string; +export abstract class PromptBase { + protected abstract getAssistantPrompt(args: PromptArgs): string; protected get internalPurposeForTelemetry(): InternalPromptPurpose { return undefined; } - protected getUserPrompt(args: TArgs): Promise { + protected getUserPrompt({ + request, + }: PromptArgs): Promise { return Promise.resolve({ - prompt: args.request.prompt, + prompt: request.prompt, hasSampleDocs: false, }); } - async buildMessages(args: TArgs): Promise { - let historyMessages = PromptHistory.getFilteredHistory({ - history: args.context?.history, - ...args, + private async _countRemainingTokens({ + model, + assistantPrompt, + requestPrompt, + }: { + model: vscode.LanguageModelChat | undefined; + assistantPrompt: vscode.LanguageModelChatMessage; + requestPrompt: string; + }): Promise { + if (model) { + const [assistantPromptTokens, userPromptTokens] = await Promise.all([ + model.countTokens(assistantPrompt), + model.countTokens(requestPrompt), + ]); + return model.maxInputTokens - (assistantPromptTokens + userPromptTokens); + } + return undefined; + } + + async buildMessages(args: PromptArgs): Promise { + const { context, request, databaseName, collectionName, connectionNames } = + args; + + const model = await getCopilotModel(); + + // eslint-disable-next-line new-cap + const assistantPrompt = vscode.LanguageModelChatMessage.Assistant( + this.getAssistantPrompt(args) + ); + + const tokenLimit = await this._countRemainingTokens({ + model, + assistantPrompt, + requestPrompt: request.prompt, + }); + + let historyMessages = await PromptHistory.getFilteredHistory({ + history: context?.history, + model, + tokenLimit, + namespaceIsKnown: + databaseName !== undefined && collectionName !== undefined, + connectionNames, }); + // If the current user's prompt is a connection name, and the last // message was 
to connect. We want to use the last // message they sent before the connection name as their prompt. - if (args.connectionNames?.includes(args.request.prompt)) { - const history = args.context?.history; + if (connectionNames?.includes(request.prompt)) { + const history = context?.history; if (!history) { return { messages: [], - stats: this.getStats([], args, false), + stats: this.getStats([], { request, context }, false), }; } const previousResponse = history[ @@ -132,13 +175,11 @@ export abstract class PromptBase { // Go through the history in reverse order to find the last user message. for (let i = history.length - 1; i >= 0; i--) { if (history[i] instanceof vscode.ChatRequestTurn) { + request.prompt = (history[i] as vscode.ChatRequestTurn).prompt; // Rewrite the arguments so that the prompt is the last user message from history args = { ...args, - request: { - ...args.request, - prompt: (history[i] as vscode.ChatRequestTurn).prompt, - }, + request, }; // Remove the item from the history messages array. @@ -150,23 +191,20 @@ export abstract class PromptBase { } const { prompt, hasSampleDocs } = await this.getUserPrompt(args); - const messages = [ - // eslint-disable-next-line new-cap - vscode.LanguageModelChatMessage.Assistant(this.getAssistantPrompt(args)), - ...historyMessages, - // eslint-disable-next-line new-cap - vscode.LanguageModelChatMessage.User(prompt), - ]; + // eslint-disable-next-line new-cap + const userPrompt = vscode.LanguageModelChatMessage.User(prompt); + + const messages = [assistantPrompt, ...historyMessages, userPrompt]; return { messages, - stats: this.getStats(messages, args, hasSampleDocs), + stats: this.getStats(messages, { request, context }, hasSampleDocs), }; } protected getStats( messages: vscode.LanguageModelChatMessage[], - { request, context }: TArgs, + { request, context }: Pick, hasSampleDocs: boolean ): ParticipantPromptProperties { return { diff --git a/src/participant/prompts/promptHistory.ts b/src/participant/prompts/promptHistory.ts index fa042d703..86226d71d 100644 --- a/src/participant/prompts/promptHistory.ts +++ b/src/participant/prompts/promptHistory.ts @@ -106,26 +106,28 @@ export class PromptHistory { /** When passing the history to the model we only want contextual messages to be passed. This function parses through the history and returns the messages that are valuable to keep. */ - static getFilteredHistory({ + static async getFilteredHistory({ + model, + tokenLimit, connectionNames, history, - databaseName, - collectionName, + namespaceIsKnown, }: { + model?: vscode.LanguageModelChat | undefined; + tokenLimit?: number; connectionNames?: string[]; // Used to scrape the connecting messages from the history. 
history?: vscode.ChatContext['history']; - databaseName?: string; - collectionName?: string; - }): vscode.LanguageModelChatMessage[] { + namespaceIsKnown: boolean; + }): Promise { const messages: vscode.LanguageModelChatMessage[] = []; if (!history) { return []; } - const namespaceIsKnown = - databaseName !== undefined && collectionName !== undefined; - for (let i = 0; i < history.length; i++) { + let totalUsedTokens = 0; + + for (let i = history.length - 1; i >= 0; i--) { const currentTurn = history[i]; let addedMessage: vscode.LanguageModelChatMessage | undefined; @@ -147,16 +149,23 @@ export class PromptHistory { }); } if (addedMessage) { + if (tokenLimit) { + totalUsedTokens += (await model?.countTokens(addedMessage)) || 0; + if (totalUsedTokens > tokenLimit) { + break; + } + } + messages.push(addedMessage); } } - return messages; + return messages.reverse(); } /** The docs chatbot keeps its own history so we avoid any * we need to include history only since last docs message. */ - static getFilteredHistoryForDocs({ + static async getFilteredHistoryForDocs({ connectionNames, context, databaseName, @@ -166,7 +175,7 @@ export class PromptHistory { context?: vscode.ChatContext; databaseName?: string; collectionName?: string; - }): vscode.LanguageModelChatMessage[] { + }): Promise { if (!context) { return []; } @@ -192,8 +201,8 @@ export class PromptHistory { return this.getFilteredHistory({ connectionNames, history: historySinceLastDocs.reverse(), - databaseName, - collectionName, + namespaceIsKnown: + databaseName !== undefined && collectionName !== undefined, }); } } diff --git a/src/participant/sampleDocuments.ts b/src/participant/sampleDocuments.ts index 4945839c0..56c4730b6 100644 --- a/src/participant/sampleDocuments.ts +++ b/src/participant/sampleDocuments.ts @@ -59,9 +59,11 @@ export async function getStringifiedSampleDocuments({ const stringifiedDocuments = toJSString(additionToPrompt); - // TODO: model.countTokens will sometimes return undefined - at least in tests. We should investigate why. - promptInputTokens = - (await model.countTokens(prompt + stringifiedDocuments)) || 0; + // Re-evaluate promptInputTokens with less documents if necessary. + if (promptInputTokens > model.maxInputTokens) { + promptInputTokens = + (await model.countTokens(prompt + stringifiedDocuments)) || 0; + } // Add sample documents to the prompt only when it fits in the context window. if (promptInputTokens <= model.maxInputTokens) { diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index ba8a710ca..ebef18ee2 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -106,8 +106,12 @@ suite('Participant Controller Test Suite', function () { button: sinon.SinonSpy; }; let chatTokenStub; - let countTokensStub; + let countTokensStub: sinon.SinonStub; let sendRequestStub: sinon.SinonStub; + let getCopilotModelStub: SinonStub< + [], + Promise + >; let telemetryTrackStub: SinonSpy; const invokeChatHandler = async ( @@ -231,21 +235,23 @@ suite('Participant Controller Test Suite', function () { chatTokenStub = { onCancellationRequested: sinon.fake(), }; - countTokensStub = sinon.stub(); + /** Resolves to 0 by default to prevent undefined being returned. + Resolve to other values to test different count limits. */ + countTokensStub = sinon.stub().resolves(0); // The model returned by vscode.lm.selectChatModels is always undefined in tests. 
sendRequestStub = sinon.stub(); - sinon.replace(model, 'getCopilotModel', () => - Promise.resolve({ - id: 'modelId', - vendor: 'copilot', - family: 'gpt-4o', - version: 'gpt-4o-date', - name: 'GPT 4o (date)', - maxInputTokens: MAX_TOTAL_PROMPT_LENGTH_MOCK, - countTokens: countTokensStub, - sendRequest: sendRequestStub, - }) - ); + getCopilotModelStub = sinon.stub(model, 'getCopilotModel'); + + getCopilotModelStub.resolves({ + id: 'modelId', + vendor: 'copilot', + family: 'gpt-4o', + version: 'gpt-4o-date', + name: 'GPT 4o (date)', + maxInputTokens: MAX_TOTAL_PROMPT_LENGTH_MOCK, + countTokens: countTokensStub, + sendRequest: sendRequestStub, + }); sinon.replace(testTelemetryService, 'track', telemetryTrackStub); }); @@ -832,8 +838,7 @@ suite('Participant Controller Test Suite', function () { }); test('includes 1 sample document as an object', async function () { - countTokensStub.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); - sampleStub.resolves([ + const sampleDocs = [ { _id: new ObjectId('63ed1d522d8573fa5c203660'), field: { @@ -852,13 +857,29 @@ suite('Participant Controller Test Suite', function () { ], }, }, - ]); + ]; + + // This is to offset the previous countTokens calls buildMessages gets called twice for namespace so it is adjusted accordingly + // 1. called calculating user's request prompt when buildMessages get called in _getNamespaceFromChat + // 2. called calculating assistant prompt when buildMessages get called in _getNamespaceFromChat + // 3. called calculating user's request prompt when buildMessages get called as part of the query request handling + // 4. called calculating assistant prompt when buildMessages get called as part of the query request handling + const countTokenCallsOffset = 4; + + // Called when including sample documents + countTokensStub + .onCall(countTokenCallsOffset) + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); + + sampleStub.resolves(sampleDocs); + const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', references: [], }; await invokeChatHandler(chatRequestMock); + const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.include( @@ -893,11 +914,7 @@ suite('Participant Controller Test Suite', function () { }); test('includes 1 sample documents when 3 make prompt too long', async function () { - countTokensStub - .onCall(0) - .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); - countTokensStub.onCall(1).resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); - sampleStub.resolves([ + const sampleDocs = [ { _id: new ObjectId('63ed1d522d8573fa5c203661'), field: { @@ -916,13 +933,32 @@ suite('Participant Controller Test Suite', function () { stringField: 'Text 3', }, }, - ]); + ]; + + // This is to offset the previous countTokens calls buildMessages gets called twice for namespace so it is adjusted accordingly + // 1. called calculating user's request prompt when buildMessages get called in _getNamespaceFromChat + // 2. called calculating assistant prompt when buildMessages get called in _getNamespaceFromChat + // 3. called calculating user's request prompt when buildMessages get called as part of the query request handling + // 4. 
called calculating assistant prompt when buildMessages get called as part of the query request handling + const countTokenCallsOffset = 4; + + // Called when including sample documents + countTokensStub + .onCall(countTokenCallsOffset) + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); + countTokensStub + .onCall(countTokenCallsOffset + 1) + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); + + sampleStub.resolves(sampleDocs); + const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', references: [], }; await invokeChatHandler(chatRequestMock); + const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.include( @@ -952,13 +988,22 @@ suite('Participant Controller Test Suite', function () { }); test('does not include sample documents when even 1 makes prompt too long', async function () { + // This is to offset the previous countTokens calls buildMessages gets called twice for namespace so it is adjusted accordingly + // 1. called calculating user's request prompt when buildMessages get called in _getNamespaceFromChat + // 2. called calculating assistant prompt when buildMessages get called in _getNamespaceFromChat + // 3. called calculating user's request prompt when buildMessages get called as part of the query request handling + // 4. called calculating assistant prompt when buildMessages get called as part of the query request handling + const countTokenCallsOffset = 4; + + // Called when including sample documents countTokensStub - .onCall(0) + .onCall(countTokenCallsOffset) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); countTokensStub - .onCall(1) + .onCall(countTokenCallsOffset + 1) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); - sampleStub.resolves([ + + const sampleDocs = [ { _id: new ObjectId('63ed1d522d8573fa5c203661'), field: { @@ -977,13 +1022,17 @@ suite('Participant Controller Test Suite', function () { stringField: 'Text 3', }, }, - ]); + ]; + + sampleStub.resolves(sampleDocs); + const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', references: [], }; await invokeChatHandler(chatRequestMock); + const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.not.include( @@ -2114,6 +2163,49 @@ Schema: }); suite('prompt builders', function () { + suite('prompt history', function () { + test('gets filtered once history goes over maxInputTokens', async function () { + const expectedMaxMessages = 10; + + const mockedMessages = Array.from( + { length: 20 }, + (_, index) => `Message ${index}` + ); + + getCopilotModelStub.resolves({ + // Make each message count as 1 token for testing + countTokens: countTokensStub.resolves(1), + maxInputTokens: expectedMaxMessages, + } as unknown as vscode.LanguageModelChat); + chatContextStub = { + history: mockedMessages.map((messageText) => + createChatRequestTurn(undefined, messageText) + ), + }; + const chatRequestMock = { + prompt: 'find all docs by a name example', + }; + const { messages } = await Prompts.generic.buildMessages({ + context: chatContextStub, + request: chatRequestMock, + connectionNames: [], + }); + + expect(messages.length).equals(expectedMaxMessages); + + // Should consist of the assistant prompt (1 token), 8 history messages (8 tokens), + // and the new request (1 token) + expect( + messages.slice(1).map((message) => getMessageContent(message)) + ).deep.equal([ + ...mockedMessages.slice( + mockedMessages.length - (expectedMaxMessages - 
2) + ), + chatRequestMock.prompt, + ]); + }); + }); + test('generic', async function () { const chatRequestMock = { prompt: 'find all docs by a name example', @@ -2351,7 +2443,7 @@ Schema: vscode.LanguageModelChatMessageRole.Assistant ); - // We don't expect history because we're removing the askForConnect message as well + // We don't expect history because we're removing the askToConnect message as well // as the user response to it. Therefore the actual user prompt should be the first // message that we supplied in the history. expect(messages[1].role).to.equal( @@ -2382,7 +2474,7 @@ Schema: vscode.LanguageModelChatMessageRole.Assistant ); - // We don't expect history because we're removing the askForConnect message as well + // We don't expect history because we're removing the askToConnect message as well // as the user response to it. Therefore the actual user prompt should be the first // message that we supplied in the history. expect(messages[1].role).to.equal( @@ -2393,7 +2485,7 @@ Schema: }); }); - test('removes askForConnect messages from history', async function () { + test('removes askToConnect messages from history', async function () { // The user is responding to an `askToConnect` message, so the prompt is just the // name of the connection const chatRequestMock = { @@ -2448,7 +2540,7 @@ Schema: vscode.LanguageModelChatMessageRole.Assistant ); - // We don't expect history because we're removing the askForConnect message as well + // We don't expect history because we're removing the askToConnect message as well // as the user response to it. Therefore the actual user prompt should be the first // message that we supplied in the history. expect(messages[1].role).to.equal(
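
Note: the core of this patch is the reverse-chronological, token-budget filter that PromptHistory.getFilteredHistory now applies before chat history is sent to the model. The sketch below is a minimal, self-contained illustration of that idea only; it uses a generic message type and a plain countTokens callback instead of the real vscode.LanguageModelChat API, and it omits the patch's other filtering (connection-name turns, unknown-namespace turns). It is not the extension's actual implementation.

    // Minimal sketch (assumed names, not the extension's API): keep the most recent
    // history messages that fit within a token budget, preserving chronological order.
    async function filterHistoryByTokenBudget<T>(
      history: T[],
      countTokens: (message: T) => Promise<number>,
      tokenLimit?: number
    ): Promise<T[]> {
      const kept: T[] = [];
      let usedTokens = 0;

      // Walk the history from newest to oldest so the most recent turns survive.
      for (let i = history.length - 1; i >= 0; i--) {
        const message = history[i];
        if (tokenLimit !== undefined) {
          usedTokens += await countTokens(message);
          if (usedTokens > tokenLimit) {
            break;
          }
        }
        kept.push(message);
      }

      // Restore chronological order before handing the messages to the model.
      return kept.reverse();
    }

In the patch itself, the budget is model.maxInputTokens minus the tokens already consumed by the assistant prompt and the user's request (see _countRemainingTokens in promptBase.ts), and counting goes through model.countTokens.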