From 864dc4b2eb36772db5c599728c74455fbe5f1668 Mon Sep 17 00:00:00 2001 From: gagik Date: Tue, 3 Dec 2024 13:45:45 +0100 Subject: [PATCH 1/8] wip --- src/participant/participant.ts | 2 +- src/participant/prompts/promptBase.ts | 68 +++++++++++++------ src/participant/prompts/promptHistory.ts | 36 ++++++---- .../suite/participant/participant.test.ts | 37 ++++++++++ 4 files changed, 106 insertions(+), 37 deletions(-) diff --git a/src/participant/participant.ts b/src/participant/participant.ts index f39cec2bc..ac34707fc 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -1511,7 +1511,7 @@ export default class ParticipantController { log.info('Docs chatbot created for chatId', chatId); } - const history = PromptHistory.getFilteredHistoryForDocs({ + const history = await PromptHistory.getFilteredHistoryForDocs({ connectionNames: this._getConnectionNames(), context: context, }); diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 0df0d1b38..440cbd193 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -5,6 +5,7 @@ import type { ParticipantPromptProperties, } from '../../telemetry/telemetryService'; import { PromptHistory } from './promptHistory'; +import { getCopilotModel } from '../model'; export interface PromptArgsBase { request: { @@ -93,34 +94,64 @@ export function isContentEmpty( return true; } -export abstract class PromptBase { - protected abstract getAssistantPrompt(args: TArgs): string; +export abstract class PromptBase { + protected abstract getAssistantPrompt(args: PromptArgs): string; protected get internalPurposeForTelemetry(): InternalPromptPurpose { return undefined; } - protected getUserPrompt(args: TArgs): Promise { + protected getUserPrompt( + request: PromptArgsBase['request'] + ): Promise { return Promise.resolve({ - prompt: args.request.prompt, + prompt: request.prompt, hasSampleDocs: false, }); } - async buildMessages(args: TArgs): Promise { - let historyMessages = PromptHistory.getFilteredHistory({ - history: args.context?.history, - ...args, + async buildMessages(args: PromptArgs): Promise { + const { context, request, databaseName, collectionName, connectionNames } = + args; + + const model = await getCopilotModel(); + + // eslint-disable-next-line new-cap + const assistantPrompt = vscode.LanguageModelChatMessage.Assistant( + this.getAssistantPrompt(args) + ); + + const { prompt, hasSampleDocs } = await this.getUserPrompt(request); + // eslint-disable-next-line new-cap + const userPrompt = vscode.LanguageModelChatMessage.User(prompt); + + let tokenLimit: number | undefined; + if (model) { + const [assistantPromptTokens, userPromptTokens] = await Promise.all([ + model.countTokens(assistantPrompt), + model.countTokens(userPrompt), + ]); + tokenLimit = + model.maxInputTokens - (assistantPromptTokens + userPromptTokens); + } + + let historyMessages = await PromptHistory.getFilteredHistory({ + history: context?.history, + model, + tokenLimit, + namespaceIsKnown: + databaseName !== undefined && collectionName !== undefined, + connectionNames, }); // If the current user's prompt is a connection name, and the last // message was to connect. We want to use the last // message they sent before the connection name as their prompt. 
- if (args.connectionNames?.includes(args.request.prompt)) { - const history = args.context?.history; + if (connectionNames?.includes(request.prompt)) { + const history = context?.history; if (!history) { return { messages: [], - stats: this.getStats([], args, false), + stats: this.getStats([], { request, context }, false), }; } const previousResponse = history[ @@ -135,7 +166,7 @@ export abstract class PromptBase { args = { ...args, request: { - ...args.request, + ...request, prompt: (history[i] as vscode.ChatRequestTurn).prompt, }, }; @@ -148,24 +179,17 @@ export abstract class PromptBase { } } - const { prompt, hasSampleDocs } = await this.getUserPrompt(args); - const messages = [ - // eslint-disable-next-line new-cap - vscode.LanguageModelChatMessage.Assistant(this.getAssistantPrompt(args)), - ...historyMessages, - // eslint-disable-next-line new-cap - vscode.LanguageModelChatMessage.User(prompt), - ]; + const messages = [assistantPrompt, ...historyMessages, userPrompt]; return { messages, - stats: this.getStats(messages, args, hasSampleDocs), + stats: this.getStats(messages, { request, context }, hasSampleDocs), }; } protected getStats( messages: vscode.LanguageModelChatMessage[], - { request, context }: TArgs, + { request, context }: Pick, hasSampleDocs: boolean ): ParticipantPromptProperties { return { diff --git a/src/participant/prompts/promptHistory.ts b/src/participant/prompts/promptHistory.ts index 6f55e577a..76c6741c3 100644 --- a/src/participant/prompts/promptHistory.ts +++ b/src/participant/prompts/promptHistory.ts @@ -105,26 +105,28 @@ export class PromptHistory { /** When passing the history to the model we only want contextual messages to be passed. This function parses through the history and returns the messages that are valuable to keep. */ - static getFilteredHistory({ + static async getFilteredHistory({ + model, + tokenLimit, connectionNames, history, - databaseName, - collectionName, + namespaceIsKnown, }: { + model?: vscode.LanguageModelChat | undefined; + tokenLimit?: number; connectionNames?: string[]; // Used to scrape the connecting messages from the history. history?: vscode.ChatContext['history']; - databaseName?: string; - collectionName?: string; - }): vscode.LanguageModelChatMessage[] { + namespaceIsKnown: boolean; + }): Promise { const messages: vscode.LanguageModelChatMessage[] = []; if (!history) { return []; } - const namespaceIsKnown = - databaseName !== undefined && collectionName !== undefined; - for (let i = 0; i < history.length; i++) { + let totalUsedTokens = 0; + + for (let i = history.length - 1; i >= 0; i--) { const currentTurn = history[i]; let addedMessage: vscode.LanguageModelChatMessage | undefined; @@ -146,16 +148,22 @@ export class PromptHistory { }); } if (addedMessage) { + if (model && tokenLimit) { + totalUsedTokens += await model.countTokens(addedMessage); + if (totalUsedTokens > tokenLimit) { + break; + } + } messages.push(addedMessage); } } - return messages; + return messages.reverse(); } /** The docs chatbot keeps its own history so we avoid any * we need to include history only since last docs message. 
*/ - static getFilteredHistoryForDocs({ + static async getFilteredHistoryForDocs({ connectionNames, context, databaseName, @@ -165,7 +173,7 @@ export class PromptHistory { context?: vscode.ChatContext; databaseName?: string; collectionName?: string; - }): vscode.LanguageModelChatMessage[] { + }): Promise { if (!context) { return []; } @@ -191,8 +199,8 @@ export class PromptHistory { return this.getFilteredHistory({ connectionNames, history: historySinceLastDocs.reverse(), - databaseName, - collectionName, + namespaceIsKnown: + databaseName !== undefined && collectionName !== undefined, }); } } diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index ce89bf105..cf98ba10f 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -2036,6 +2036,43 @@ Schema: }); suite('prompt builders', function () { + suite('prompt history', function () { + test('gets filtered once history goes over maxInputTokens', async function () { + const expectedMaxMessages = 8; + + const mockedMessages = Array.from( + { length: 20 }, + (_, index) => `Message ${index}` + ); + + sinon.stub(model, 'getCopilotModel').resolves({ + maxInputTokens: expectedMaxMessages, + // Make each message count as 1 token + countTokens: () => 1, + } as unknown as vscode.LanguageModelChat); + + chatContextStub = { + history: mockedMessages.map((messageText) => + createChatRequestTurn(undefined, messageText) + ), + }; + const chatRequestMock = { + prompt: 'find all docs by a name example', + }; + const { messages } = await Prompts.generic.buildMessages({ + context: chatContextStub, + request: chatRequestMock, + connectionNames: [], + }); + + // Should include the limit and the initial generic prompt and the newly sent request + expect(messages.length + 2).equals(expectedMaxMessages); + expect( + messages.slice(1).map((message) => getMessageContent(message)) + ).deep.equal([...mockedMessages, chatRequestMock.prompt]); + }); + }); + test('generic', async function () { const chatRequestMock = { prompt: 'find all docs by a name example', From d0d59bc1b7890ac12e7affe7a0e8f593af3a3936 Mon Sep 17 00:00:00 2001 From: gagik Date: Tue, 3 Dec 2024 21:48:02 +0100 Subject: [PATCH 2/8] adjust the tests --- src/participant/prompts/promptBase.ts | 8 ++++---- src/test/suite/participant/participant.test.ts | 15 +++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 440cbd193..58d36c92f 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -101,9 +101,9 @@ export abstract class PromptBase { return undefined; } - protected getUserPrompt( - request: PromptArgsBase['request'] - ): Promise { + protected getUserPrompt({ + request, + }: PromptArgs): Promise { return Promise.resolve({ prompt: request.prompt, hasSampleDocs: false, @@ -121,7 +121,7 @@ export abstract class PromptBase { this.getAssistantPrompt(args) ); - const { prompt, hasSampleDocs } = await this.getUserPrompt(request); + const { prompt, hasSampleDocs } = await this.getUserPrompt(args); // eslint-disable-next-line new-cap const userPrompt = vscode.LanguageModelChatMessage.User(prompt); diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index cf98ba10f..f36b9db4c 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ 
-2038,7 +2038,7 @@ Schema: suite('prompt builders', function () { suite('prompt history', function () { test('gets filtered once history goes over maxInputTokens', async function () { - const expectedMaxMessages = 8; + const expectedMaxMessages = 10; const mockedMessages = Array.from( { length: 20 }, @@ -2065,11 +2065,18 @@ Schema: connectionNames: [], }); - // Should include the limit and the initial generic prompt and the newly sent request - expect(messages.length + 2).equals(expectedMaxMessages); + expect(messages.length).equals(expectedMaxMessages); + + // Should consist of the assistant prompt (1 token), 8 history messages (8 tokens), + // and the new request (1 token) expect( messages.slice(1).map((message) => getMessageContent(message)) - ).deep.equal([...mockedMessages, chatRequestMock.prompt]); + ).deep.equal([ + ...mockedMessages.slice( + mockedMessages.length - (expectedMaxMessages - 2) + ), + chatRequestMock.prompt, + ]); }); }); From cd06aa0e2a07ba1510e133f1b294a81b11d2272c Mon Sep 17 00:00:00 2001 From: gagik Date: Wed, 4 Dec 2024 10:19:28 +0100 Subject: [PATCH 3/8] WIP --- src/participant/prompts/promptHistory.ts | 5 +- .../suite/participant/participant.test.ts | 100 ++++++++++++------ 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/src/participant/prompts/promptHistory.ts b/src/participant/prompts/promptHistory.ts index 76c6741c3..69cb7b7f7 100644 --- a/src/participant/prompts/promptHistory.ts +++ b/src/participant/prompts/promptHistory.ts @@ -148,12 +148,13 @@ export class PromptHistory { }); } if (addedMessage) { - if (model && tokenLimit) { - totalUsedTokens += await model.countTokens(addedMessage); + if (tokenLimit) { + totalUsedTokens += (await model?.countTokens(addedMessage)) || 0; if (totalUsedTokens > tokenLimit) { break; } } + messages.push(addedMessage); } } diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index f36b9db4c..2837ee9a7 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -103,8 +103,12 @@ suite('Participant Controller Test Suite', function () { button: sinon.SinonSpy; }; let chatTokenStub; - let countTokensStub; + let countTokensStub: sinon.SinonStub; let sendRequestStub: sinon.SinonStub; + let getCopilotModelStub: SinonStub< + [], + Promise + >; let telemetryTrackStub: SinonSpy; const invokeChatHandler = async ( @@ -231,18 +235,18 @@ suite('Participant Controller Test Suite', function () { countTokensStub = sinon.stub(); // The model returned by vscode.lm.selectChatModels is always undefined in tests. 
sendRequestStub = sinon.stub(); - sinon.replace(model, 'getCopilotModel', () => - Promise.resolve({ - id: 'modelId', - vendor: 'copilot', - family: 'gpt-4o', - version: 'gpt-4o-date', - name: 'GPT 4o (date)', - maxInputTokens: MAX_TOTAL_PROMPT_LENGTH_MOCK, - countTokens: countTokensStub, - sendRequest: sendRequestStub, - }) - ); + getCopilotModelStub = sinon.stub(model, 'getCopilotModel'); + + getCopilotModelStub.resolves({ + id: 'modelId', + vendor: 'copilot', + family: 'gpt-4o', + version: 'gpt-4o-date', + name: 'GPT 4o (date)', + maxInputTokens: MAX_TOTAL_PROMPT_LENGTH_MOCK, + countTokens: countTokensStub, + sendRequest: sendRequestStub, + }); sinon.replace(testTelemetryService, 'track', telemetryTrackStub); }); @@ -829,8 +833,7 @@ suite('Participant Controller Test Suite', function () { }); test('includes 1 sample document as an object', async function () { - countTokensStub.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); - sampleStub.resolves([ + const sampleDocs = [ { _id: new ObjectId('63ed1d522d8573fa5c203660'), field: { @@ -849,13 +852,27 @@ suite('Participant Controller Test Suite', function () { ], }, }, - ]); + ]; + + // This is the offset of the history token calculation calls + const callsOffset = 5; + + // Called when including sample documents + countTokensStub + .onCall(callsOffset) + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); + + sampleStub.resolves(sampleDocs); + const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', references: [], }; await invokeChatHandler(chatRequestMock); + + expect(countTokensStub).callCount(callsOffset + 1); + const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.include( @@ -890,11 +907,7 @@ suite('Participant Controller Test Suite', function () { }); test('includes 1 sample documents when 3 make prompt too long', async function () { - countTokensStub - .onCall(0) - .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); - countTokensStub.onCall(1).resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); - sampleStub.resolves([ + const sampleDocs = [ { _id: new ObjectId('63ed1d522d8573fa5c203661'), field: { @@ -913,13 +926,30 @@ suite('Participant Controller Test Suite', function () { stringField: 'Text 3', }, }, - ]); + ]; + + // This is the offset of the history token calculation calls + const callsOffset = 5; + + // Called when including sample documents + countTokensStub + .onCall(callsOffset) + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); + countTokensStub + .onCall(callsOffset + 1) + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); + + sampleStub.resolves(sampleDocs); + const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', references: [], }; await invokeChatHandler(chatRequestMock); + + expect(countTokensStub).callCount(callsOffset + 1); + const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.include( @@ -949,13 +979,18 @@ suite('Participant Controller Test Suite', function () { }); test('does not include sample documents when even 1 makes prompt too long', async function () { + // This is the offset of the history token calculation calls + const callsOffset = 5; + + // Called when including sample documents countTokensStub - .onCall(0) + .onCall(callsOffset) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); countTokensStub - .onCall(1) + .onCall(callsOffset + 1) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); - sampleStub.resolves([ + + const sampleDocs = [ { _id: new 
ObjectId('63ed1d522d8573fa5c203661'), field: { @@ -974,13 +1009,19 @@ suite('Participant Controller Test Suite', function () { stringField: 'Text 3', }, }, - ]); + ]; + + sampleStub.resolves(sampleDocs); + const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', references: [], }; await invokeChatHandler(chatRequestMock); + + expect(countTokensStub).callCount(callsOffset + 1); + const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.not.include( @@ -2045,12 +2086,11 @@ Schema: (_, index) => `Message ${index}` ); - sinon.stub(model, 'getCopilotModel').resolves({ + getCopilotModelStub.resolves({ + // Make each message count as 1 token for testing + countTokens: countTokensStub.resolves(1), maxInputTokens: expectedMaxMessages, - // Make each message count as 1 token - countTokens: () => 1, } as unknown as vscode.LanguageModelChat); - chatContextStub = { history: mockedMessages.map((messageText) => createChatRequestTurn(undefined, messageText) From 9fb4610eebdfa53376e088ac13ecca15c0a9bf7a Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 5 Dec 2024 10:38:12 +0100 Subject: [PATCH 4/8] Fix tests --- src/participant/sampleDocuments.ts | 8 ++-- .../suite/participant/participant.test.ts | 39 ++++++++++++++----- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/participant/sampleDocuments.ts b/src/participant/sampleDocuments.ts index 4945839c0..56c4730b6 100644 --- a/src/participant/sampleDocuments.ts +++ b/src/participant/sampleDocuments.ts @@ -59,9 +59,11 @@ export async function getStringifiedSampleDocuments({ const stringifiedDocuments = toJSString(additionToPrompt); - // TODO: model.countTokens will sometimes return undefined - at least in tests. We should investigate why. - promptInputTokens = - (await model.countTokens(prompt + stringifiedDocuments)) || 0; + // Re-evaluate promptInputTokens with less documents if necessary. + if (promptInputTokens > model.maxInputTokens) { + promptInputTokens = + (await model.countTokens(prompt + stringifiedDocuments)) || 0; + } // Add sample documents to the prompt only when it fits in the context window. if (promptInputTokens <= model.maxInputTokens) { diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 2837ee9a7..452288073 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -854,8 +854,9 @@ suite('Participant Controller Test Suite', function () { }, ]; - // This is the offset of the history token calculation calls - const callsOffset = 5; + // This is to offset the previous countTokens calls + // (1 for user prompt and 1 for assistant prompt calculation) + const callsOffset = 2; // Called when including sample documents countTokensStub @@ -871,7 +872,10 @@ suite('Participant Controller Test Suite', function () { }; await invokeChatHandler(chatRequestMock); - expect(countTokensStub).callCount(callsOffset + 1); + // +1 call when counting tokens of 3 sample documents + // +1 call when counting tokens in the history. 
+ // +1 to account for zero-based indexing of the offset) + expect(countTokensStub).callCount(callsOffset + 3); const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; @@ -928,15 +932,21 @@ suite('Participant Controller Test Suite', function () { }, ]; - // This is the offset of the history token calculation calls - const callsOffset = 5; + // This is to offset the previous countTokens calls + // (1 for user prompt and 1 for assistant prompt calculation) + const callsOffset = 2; // Called when including sample documents countTokensStub .onCall(callsOffset) - .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); + .returns(Promise.resolve(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1)); countTokensStub .onCall(callsOffset + 1) + .returns(Promise.resolve(MAX_TOTAL_PROMPT_LENGTH_MOCK)); + + // Called when calculating the added finalized user prompt + countTokensStub + .onCall(callsOffset + 2) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); sampleStub.resolves(sampleDocs); @@ -948,7 +958,11 @@ suite('Participant Controller Test Suite', function () { }; await invokeChatHandler(chatRequestMock); - expect(countTokensStub).callCount(callsOffset + 1); + // +1 call when counting tokens of 3 sample documents + // +1 call for the retry with 1 sample documents + // +1 call when counting tokens in the history. + // +1 to account for zero-based indexing of the offset) + expect(countTokensStub).callCount(callsOffset + 4); const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; @@ -979,8 +993,9 @@ suite('Participant Controller Test Suite', function () { }); test('does not include sample documents when even 1 makes prompt too long', async function () { - // This is the offset of the history token calculation calls - const callsOffset = 5; + // This is to offset the previous countTokens calls + // (1 for user prompt and 1 for assistant prompt calculation) + const callsOffset = 2; // Called when including sample documents countTokensStub @@ -1020,7 +1035,11 @@ suite('Participant Controller Test Suite', function () { }; await invokeChatHandler(chatRequestMock); - expect(countTokensStub).callCount(callsOffset + 1); + // +1 call when counting tokens of 3 sample documents + // +1 call for the retry with 1 sample documents + // +1 call when counting tokens in the history. 
+ // +1 to account for zero-based indexing of the offset) + expect(countTokensStub).callCount(callsOffset + 4); const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; From 48c48274d0334efef97e11f745e9af09b4ae5f81 Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 5 Dec 2024 12:04:16 +0100 Subject: [PATCH 5/8] adjust order and fix tests --- src/participant/prompts/promptBase.ts | 42 +++++++++++++------ .../suite/participant/participant.test.ts | 16 +++---- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 58d36c92f..cf58290fc 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -110,6 +110,25 @@ export abstract class PromptBase { }); } + private async _countRemainingTokens({ + model, + assistantPrompt, + requestPrompt, + }: { + model: vscode.LanguageModelChat | undefined; + assistantPrompt: vscode.LanguageModelChatMessage; + requestPrompt: string; + }): Promise { + if (model) { + const [assistantPromptTokens, userPromptTokens] = await Promise.all([ + model.countTokens(assistantPrompt), + model.countTokens(requestPrompt), + ]); + return model.maxInputTokens - (assistantPromptTokens + userPromptTokens); + } + return undefined; + } + async buildMessages(args: PromptArgs): Promise { const { context, request, databaseName, collectionName, connectionNames } = args; @@ -121,19 +140,11 @@ export abstract class PromptBase { this.getAssistantPrompt(args) ); - const { prompt, hasSampleDocs } = await this.getUserPrompt(args); - // eslint-disable-next-line new-cap - const userPrompt = vscode.LanguageModelChatMessage.User(prompt); - - let tokenLimit: number | undefined; - if (model) { - const [assistantPromptTokens, userPromptTokens] = await Promise.all([ - model.countTokens(assistantPrompt), - model.countTokens(userPrompt), - ]); - tokenLimit = - model.maxInputTokens - (assistantPromptTokens + userPromptTokens); - } + const tokenLimit = await this._countRemainingTokens({ + model, + assistantPrompt, + requestPrompt: request.prompt, + }); let historyMessages = await PromptHistory.getFilteredHistory({ history: context?.history, @@ -143,6 +154,7 @@ export abstract class PromptBase { databaseName !== undefined && collectionName !== undefined, connectionNames, }); + // If the current user's prompt is a connection name, and the last // message was to connect. We want to use the last // message they sent before the connection name as their prompt. @@ -179,6 +191,10 @@ export abstract class PromptBase { } } + const { prompt, hasSampleDocs } = await this.getUserPrompt(args); + // eslint-disable-next-line new-cap + const userPrompt = vscode.LanguageModelChatMessage.User(prompt); + const messages = [assistantPrompt, ...historyMessages, userPrompt]; return { diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 452288073..57cd7eb7a 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -232,7 +232,9 @@ suite('Participant Controller Test Suite', function () { chatTokenStub = { onCancellationRequested: sinon.fake(), }; - countTokensStub = sinon.stub(); + // Resolve to 0 to prevent undefined being returned + // override to other values to test different count limits. + countTokensStub = sinon.stub().resolves(0); // The model returned by vscode.lm.selectChatModels is always undefined in tests. 
sendRequestStub = sinon.stub(); getCopilotModelStub = sinon.stub(model, 'getCopilotModel'); @@ -939,10 +941,10 @@ suite('Participant Controller Test Suite', function () { // Called when including sample documents countTokensStub .onCall(callsOffset) - .returns(Promise.resolve(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1)); + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); countTokensStub .onCall(callsOffset + 1) - .returns(Promise.resolve(MAX_TOTAL_PROMPT_LENGTH_MOCK)); + .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); // Called when calculating the added finalized user prompt countTokensStub @@ -2376,7 +2378,7 @@ Schema: vscode.LanguageModelChatMessageRole.Assistant ); - // We don't expect history because we're removing the askForConnect message as well + // We don't expect history because we're removing the askToConnect message as well // as the user response to it. Therefore the actual user prompt should be the first // message that we supplied in the history. expect(messages[1].role).to.equal( @@ -2407,7 +2409,7 @@ Schema: vscode.LanguageModelChatMessageRole.Assistant ); - // We don't expect history because we're removing the askForConnect message as well + // We don't expect history because we're removing the askToConnect message as well // as the user response to it. Therefore the actual user prompt should be the first // message that we supplied in the history. expect(messages[1].role).to.equal( @@ -2418,7 +2420,7 @@ Schema: }); }); - test('removes askForConnect messages from history', async function () { + test('removes askToConnect messages from history', async function () { // The user is responding to an `askToConnect` message, so the prompt is just the // name of the connection const chatRequestMock = { @@ -2473,7 +2475,7 @@ Schema: vscode.LanguageModelChatMessageRole.Assistant ); - // We don't expect history because we're removing the askForConnect message as well + // We don't expect history because we're removing the askToConnect message as well // as the user response to it. Therefore the actual user prompt should be the first // message that we supplied in the history. expect(messages[1].role).to.equal( From 056970ea059f1b3a07d7e33734106f939d60b022 Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 5 Dec 2024 13:28:54 +0100 Subject: [PATCH 6/8] adjust order and fix tests --- src/participant/prompts/promptBase.ts | 6 ++-- .../suite/participant/participant.test.ts | 36 +++++-------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index cf58290fc..19b03f7b7 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -174,13 +174,11 @@ export abstract class PromptBase { // Go through the history in reverse order to find the last user message. for (let i = history.length - 1; i >= 0; i--) { if (history[i] instanceof vscode.ChatRequestTurn) { + request.prompt = (history[i] as vscode.ChatRequestTurn).prompt; // Rewrite the arguments so that the prompt is the last user message from history args = { ...args, - request: { - ...request, - prompt: (history[i] as vscode.ChatRequestTurn).prompt, - }, + request, }; // Remove the item from the history messages array. 
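
For reference, the budget arithmetic that the _countRemainingTokens helper introduced above centralizes is: whatever maxInputTokens the model reports, minus the tokens of the assistant prompt and the user's request, is what history may consume. A minimal standalone sketch of that arithmetic, assuming only the maxInputTokens/countTokens shape that vscode.LanguageModelChat exposes; the LimitedModel interface and remainingTokenBudget name are illustrative, not part of the patches:

// Compute how many input tokens remain for history once the fixed
// assistant and request prompts are accounted for; undefined when no
// model is available, matching the helper's behavior.
interface LimitedModel {
  maxInputTokens: number;
  countTokens(text: string): Promise<number>;
}

async function remainingTokenBudget(
  model: LimitedModel | undefined,
  assistantPrompt: string,
  requestPrompt: string
): Promise<number | undefined> {
  if (!model) {
    return undefined;
  }
  // Count both fixed prompts concurrently, as the helper does.
  const [assistantTokens, requestTokens] = await Promise.all([
    model.countTokens(assistantPrompt),
    model.countTokens(requestPrompt),
  ]);
  return model.maxInputTokens - (assistantTokens + requestTokens);
}
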
diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 57cd7eb7a..10d043155 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -232,8 +232,8 @@ suite('Participant Controller Test Suite', function () { chatTokenStub = { onCancellationRequested: sinon.fake(), }; - // Resolve to 0 to prevent undefined being returned - // override to other values to test different count limits. + /** Resolves to 0 by default to prevent undefined being returned. + Resolve to other values to test different count limits. */ countTokensStub = sinon.stub().resolves(0); // The model returned by vscode.lm.selectChatModels is always undefined in tests. sendRequestStub = sinon.stub(); @@ -874,11 +874,6 @@ suite('Participant Controller Test Suite', function () { }; await invokeChatHandler(chatRequestMock); - // +1 call when counting tokens of 3 sample documents - // +1 call when counting tokens in the history. - // +1 to account for zero-based indexing of the offset) - expect(countTokensStub).callCount(callsOffset + 3); - const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.include( @@ -935,8 +930,9 @@ suite('Participant Controller Test Suite', function () { ]; // This is to offset the previous countTokens calls - // (1 for user prompt and 1 for assistant prompt calculation) - const callsOffset = 2; + // buildMessages gets called twice for namespace so it is adjusted accordingly + // (1 for request prompt and 1 for assistant prompt calculation) + const callsOffset = 4; // Called when including sample documents countTokensStub @@ -946,11 +942,6 @@ suite('Participant Controller Test Suite', function () { .onCall(callsOffset + 1) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); - // Called when calculating the added finalized user prompt - countTokensStub - .onCall(callsOffset + 2) - .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); - sampleStub.resolves(sampleDocs); const chatRequestMock = { @@ -960,12 +951,6 @@ suite('Participant Controller Test Suite', function () { }; await invokeChatHandler(chatRequestMock); - // +1 call when counting tokens of 3 sample documents - // +1 call for the retry with 1 sample documents - // +1 call when counting tokens in the history. - // +1 to account for zero-based indexing of the offset) - expect(countTokensStub).callCount(callsOffset + 4); - const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.include( @@ -996,8 +981,9 @@ suite('Participant Controller Test Suite', function () { test('does not include sample documents when even 1 makes prompt too long', async function () { // This is to offset the previous countTokens calls - // (1 for user prompt and 1 for assistant prompt calculation) - const callsOffset = 2; + // buildMessages gets called twice for namespace so it is adjusted accordingly + // (1 for request prompt and 1 for assistant prompt calculation) + const callsOffset = 4; // Called when including sample documents countTokensStub @@ -1037,12 +1023,6 @@ suite('Participant Controller Test Suite', function () { }; await invokeChatHandler(chatRequestMock); - // +1 call when counting tokens of 3 sample documents - // +1 call for the retry with 1 sample documents - // +1 call when counting tokens in the history. 
- // +1 to account for zero-based indexing of the offset) - expect(countTokensStub).callCount(callsOffset + 4); - const messages = sendRequestStub.secondCall .args[0] as vscode.LanguageModelChatMessage[]; expect(getMessageContent(messages[1])).to.not.include( From 578bdafb4e0544e07b0926f286e2c244d5bf43c1 Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 5 Dec 2024 13:30:41 +0100 Subject: [PATCH 7/8] more consistent naming and comments --- .../suite/participant/participant.test.ts | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 10d043155..3366c9387 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -857,12 +857,13 @@ suite('Participant Controller Test Suite', function () { ]; // This is to offset the previous countTokens calls - // (1 for user prompt and 1 for assistant prompt calculation) - const callsOffset = 2; + // buildMessages gets called twice for namespace so it is adjusted accordingly + // (1 for request prompt and 1 for assistant prompt calculation) + const countTokenCallsOffset = 4; // Called when including sample documents countTokensStub - .onCall(callsOffset) + .onCall(countTokenCallsOffset) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); sampleStub.resolves(sampleDocs); @@ -932,14 +933,14 @@ suite('Participant Controller Test Suite', function () { // This is to offset the previous countTokens calls // buildMessages gets called twice for namespace so it is adjusted accordingly // (1 for request prompt and 1 for assistant prompt calculation) - const callsOffset = 4; + const countTokenCallsOffset = 4; // Called when including sample documents countTokensStub - .onCall(callsOffset) + .onCall(countTokenCallsOffset) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); countTokensStub - .onCall(callsOffset + 1) + .onCall(countTokenCallsOffset + 1) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK); sampleStub.resolves(sampleDocs); @@ -983,14 +984,14 @@ suite('Participant Controller Test Suite', function () { // This is to offset the previous countTokens calls // buildMessages gets called twice for namespace so it is adjusted accordingly // (1 for request prompt and 1 for assistant prompt calculation) - const callsOffset = 4; + const countTokenCallsOffset = 4; // Called when including sample documents countTokensStub - .onCall(callsOffset) + .onCall(countTokenCallsOffset) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); countTokensStub - .onCall(callsOffset + 1) + .onCall(countTokenCallsOffset + 1) .resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1); const sampleDocs = [ From a76d24cf2add30c6e4f5163d87121044409ea1ca Mon Sep 17 00:00:00 2001 From: gagik Date: Mon, 9 Dec 2024 09:27:18 +0100 Subject: [PATCH 8/8] explain individual calls --- .../suite/participant/participant.test.ts | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 3366c9387..15d4dea48 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -856,9 +856,11 @@ suite('Participant Controller Test Suite', function () { }, ]; - // This is to offset the previous countTokens calls - // buildMessages gets called twice for namespace so it is adjusted accordingly - // (1 for request prompt and 1 for assistant prompt calculation) + // This is to offset the previous 
countTokens calls; buildMessages gets called twice because of the namespace step, so the offset is adjusted accordingly:
+      // 1. called for the user's request prompt when buildMessages is called in _getNamespaceFromChat
+      // 2. called for the assistant prompt when buildMessages is called in _getNamespaceFromChat
+      // 3. called for the user's request prompt when buildMessages is called as part of the query request handling
+      // 4. called for the assistant prompt when buildMessages is called as part of the query request handling
       const countTokenCallsOffset = 4;
 
       // Called when including sample documents
@@ -930,9 +932,11 @@ suite('Participant Controller Test Suite', function () {
         },
       ];
 
-      // This is to offset the previous countTokens calls
-      // buildMessages gets called twice for namespace so it is adjusted accordingly
-      // (1 for request prompt and 1 for assistant prompt calculation)
+      // This is to offset the previous countTokens calls; buildMessages gets called twice because of the namespace step, so the offset is adjusted accordingly:
+      // 1. called for the user's request prompt when buildMessages is called in _getNamespaceFromChat
+      // 2. called for the assistant prompt when buildMessages is called in _getNamespaceFromChat
+      // 3. called for the user's request prompt when buildMessages is called as part of the query request handling
+      // 4. called for the assistant prompt when buildMessages is called as part of the query request handling
       const countTokenCallsOffset = 4;
 
       // Called when including sample documents
@@ -981,9 +985,11 @@ suite('Participant Controller Test Suite', function () {
     });
 
     test('does not include sample documents when even 1 makes prompt too long', async function () {
-      // This is to offset the previous countTokens calls
-      // buildMessages gets called twice for namespace so it is adjusted accordingly
-      // (1 for request prompt and 1 for assistant prompt calculation)
+      // This is to offset the previous countTokens calls; buildMessages gets called twice because of the namespace step, so the offset is adjusted accordingly:
+      // 1. called for the user's request prompt when buildMessages is called in _getNamespaceFromChat
+      // 2. called for the assistant prompt when buildMessages is called in _getNamespaceFromChat
+      // 3. called for the user's request prompt when buildMessages is called as part of the query request handling
+      // 4. called for the assistant prompt when buildMessages is called as part of the query request handling
       const countTokenCallsOffset = 4;
 
       // Called when including sample documents
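
For reference, the history-filtering loop this series converges on (patch 1's reverse iteration plus patch 3's tolerant token counting) can be sketched standalone. A minimal sketch, where the LimitedModel interface and filterHistoryByTokenLimit name are illustrative stand-ins, with only countTokens mirroring the vscode.LanguageModelChat API the patches rely on:

interface LimitedModel {
  countTokens(text: string): Promise<number>;
}

// Keep the newest history entries that fit within tokenLimit, in the
// spirit of PromptHistory.getFilteredHistory.
async function filterHistoryByTokenLimit(
  history: string[],
  model: LimitedModel | undefined,
  tokenLimit: number | undefined
): Promise<string[]> {
  const kept: string[] = [];
  let totalUsedTokens = 0;

  // Walk newest-to-oldest so the most recent turns survive truncation.
  for (let i = history.length - 1; i >= 0; i--) {
    const message = history[i];
    if (tokenLimit !== undefined) {
      // A missing model contributes 0 tokens, mirroring the patches.
      totalUsedTokens += (await model?.countTokens(message)) ?? 0;
      if (totalUsedTokens > tokenLimit) {
        break; // Budget exhausted: drop this turn and everything older.
      }
    }
    kept.push(message);
  }

  // Restore chronological order before handing messages to the model.
  return kept.reverse();
}

Paired with a tokenLimit computed as in the earlier budget sketch, this reproduces the end state of the series: history only ever fills the input-token space left over by the assistant prompt and the user's request.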